Skip to content

Commit 2107637

Browse files
committed
Update coc_plugin.py
1 parent 1d4bad4 commit 2107637

File tree

1 file changed

+71
-132
lines changed

1 file changed

+71
-132
lines changed

optillm/plugins/coc_plugin.py

Lines changed: 71 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import traceback
66
import math
77
import importlib
8+
import json
89

910
logger = logging.getLogger(__name__)
1011

@@ -24,194 +25,132 @@
2425
2. Each step should either be:
2526
- Executable Python code that performs computations
2627
- Pseudocode that you will simulate with natural language understanding
27-
3. Track program state after each line execution
28+
3. Track final result in an 'answer' variable
2829
4. Return the final answer within the <output> tags
2930
3031
Format your response using:
3132
```python
32-
[Your code here]
33+
[Your complete Python program here]
3334
```
3435
35-
And track state after each line with:
36-
delta_state: {...}
37-
3836
Finally provide output as:
3937
<output>
4038
[Your final answer]
4139
</output>
4240
'''
4341

44-
STATE_SIMULATION_PROMPT = '''You are simulating the execution of Python code.
45-
Given the current program state and a line of code, return ONLY a Python dictionary representing the new state variables.
46-
For module imports and references, return string representations.
47-
48-
Return ONLY primitive types (numbers, strings, lists, dicts) - no module references or complex objects.
42+
STATE_SIMULATION_PROMPT = '''You are simulating the execution of a Python program.
43+
Given the code below, simulate its execution and return the final value that would be in the 'answer' variable.
44+
Return ONLY the final value, no explanations or additional text.
4945
50-
For example:
51-
state = {'x': 5}
52-
code = "y = x + 3"
53-
Return: {'y': 8}
54-
55-
For modules:
56-
code = "import math"
57-
Return: {'math': 'module:math'}
46+
Code to simulate:
47+
{code}
5848
'''
5949

6050
def extract_code_blocks(text: str) -> List[str]:
6151
"""Extract Python code blocks from text."""
6252
pattern = r'```python\s*(.*?)\s*```'
6353
matches = re.findall(pattern, text, re.DOTALL)
64-
return [m.strip() for m in matches]
54+
blocks = [m.strip() for m in matches]
55+
logger.info(f"Extracted {len(blocks)} code blocks")
56+
for i, block in enumerate(blocks):
57+
logger.info(f"Code block {i+1}:\n{block}")
58+
return blocks
6559

6660
def extract_output(text: str) -> str:
6761
"""Extract content from output tags."""
6862
pattern = r'<output>(.*?)</output>'
6963
match = re.search(pattern, text, re.DOTALL)
70-
return match.group(1).strip() if match else text.strip()
71-
72-
def extract_state_updates(text: str) -> List[Dict[str, Any]]:
73-
"""Extract state updates from delta_state markers."""
74-
pattern = r'delta_state:\s*({.*?})'
75-
matches = re.findall(pattern, text, re.DOTALL)
76-
states = []
77-
for m in matches:
78-
try:
79-
cleaned = re.sub(r'```python\s*|\s*```', '', m)
80-
state = parse_state_dict(cleaned)
81-
states.append(state)
82-
except:
83-
logger.warning(f"Could not parse state update: {m}")
84-
return states
85-
86-
def clean_state_response(response: str) -> str:
87-
"""Clean up LM state response to get just the dictionary."""
88-
# Remove any code blocks
89-
response = re.sub(r'```python\s*|\s*```', '', response)
90-
# Remove any natural language before or after the dict
91-
response = re.sub(r'^[^{]*', '', response)
92-
response = re.sub(r'[^}]*$', '', response)
93-
return response.strip()
94-
95-
def parse_state_dict(state_str: str) -> Dict[str, Any]:
96-
"""Safely parse state dictionary, handling module references."""
97-
try:
98-
# First try direct evaluation
99-
return ast.literal_eval(state_str)
100-
except:
101-
# If that fails, try to parse manually
102-
state_dict = {}
103-
try:
104-
# Use a custom safe eval that handles module references
105-
# Remove brackets and split by commas
106-
items = state_str.strip('{}').split(',')
107-
for item in items:
108-
if ':' not in item:
109-
continue
110-
key, value = item.split(':', 1)
111-
key = key.strip().strip("'").strip('"')
112-
value = value.strip()
113-
114-
# Handle module references
115-
if 'module' in value:
116-
module_name = value.split("'")[1] if "'" in value else value.split(':')[1].strip()
117-
if module_name in ALLOWED_MODULES:
118-
state_dict[key] = ALLOWED_MODULES[module_name]
119-
# Handle normal values
120-
else:
121-
try:
122-
state_dict[key] = ast.literal_eval(value)
123-
except:
124-
state_dict[key] = value.strip("'").strip('"')
125-
return state_dict
126-
except Exception as e:
127-
logger.error(f"Failed to parse state dict: {state_str}")
128-
logger.error(f"Error: {str(e)}")
129-
return {}
130-
131-
def execute_line(line: str, state: Dict[str, Any], client, model: str) -> Tuple[Any, Dict[str, Any]]:
132-
"""Execute a single line of code, either with Python or LM simulation."""
64+
result = match.group(1).strip() if match else text.strip()
65+
logger.info(f"Extracted output: {result}")
66+
return result
67+
68+
def execute_code(code: str, client, model: str) -> Tuple[Any, int]:
69+
"""Execute full code block either with Python or LM simulation."""
70+
logger.info("Attempting to execute complete code block")
71+
logger.info(f"Code:\n{code}")
72+
73+
# Add imports
74+
execution_env = {}
75+
for mod_name, mod in ALLOWED_MODULES.items():
76+
execution_env[mod_name] = mod
77+
13378
try:
134-
# Handle imports specially
135-
if line.startswith('import '):
136-
module_name = line.split()[1]
137-
if module_name in ALLOWED_MODULES:
138-
return None, {module_name: ALLOWED_MODULES[module_name]}
139-
else:
140-
logger.warning(f"Skipping import of unauthorized module: {module_name}")
141-
return None, {}
142-
143-
# Try executing with Python
144-
local_state = state.copy()
145-
# Add allowed modules to local state
146-
for mod_name, mod in ALLOWED_MODULES.items():
147-
if mod_name in state:
148-
local_state[mod_name] = mod
149-
150-
exec(line, globals(), local_state)
151-
# Extract any new/modified variables
152-
new_state = {k:v for k,v in local_state.items()
153-
if k not in state or state[k] != v}
154-
return None, new_state
79+
# Try executing the complete code block with Python
80+
logger.info("Attempting Python execution")
81+
exec(code, execution_env)
82+
answer = execution_env.get('answer')
83+
logger.info(f"Python execution successful. Answer: {answer}")
84+
return answer, 0
15585

15686
except Exception as e:
87+
logger.info(f"Python execution failed: {str(e)}")
88+
logger.info("Falling back to LM simulation")
89+
15790
# If Python execution fails, simulate with LM
158-
context = f"Current program state: {state}\nExecute line: {line}"
15991
response = client.chat.completions.create(
16092
model=model,
16193
messages=[
162-
{"role": "system", "content": STATE_SIMULATION_PROMPT},
163-
{"role": "user", "content": context}
94+
{"role": "system", "content": STATE_SIMULATION_PROMPT.format(code=code)},
95+
{"role": "user", "content": "Simulate this code and return the final value of 'answer'."}
16496
],
16597
temperature=0.2
16698
)
99+
167100
try:
168-
cleaned_response = clean_state_response(response.choices[0].message.content)
169-
new_state = parse_state_dict(cleaned_response)
170-
return response.usage.completion_tokens, new_state
101+
answer = response.choices[0].message.content.strip()
102+
logger.info(f"LM simulation successful. Answer: {answer}")
103+
104+
# Try to convert to number if possible
105+
try:
106+
answer = ast.literal_eval(answer)
107+
except:
108+
pass
109+
110+
return answer, response.usage.completion_tokens
111+
171112
except Exception as e:
172-
logger.error(f"Could not parse LM state response: {response.choices[0].message.content}")
173-
logger.error(f"Error: {str(e)}")
174-
logger.error(f"Cleaned response: {cleaned_response}")
175-
return response.usage.completion_tokens, {}
113+
logger.error(f"Could not parse LM simulation response: {str(e)}")
114+
return None, response.usage.completion_tokens
176115

177116
def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
178117
"""Main Chain of Code execution function."""
118+
logger.info("Starting Chain of Code execution")
119+
logger.info(f"Query: {initial_query}")
120+
179121
messages = [
180122
{"role": "system", "content": system_prompt + "\n" + CHAIN_OF_CODE_PROMPT},
181123
{"role": "user", "content": initial_query}
182124
]
183125

126+
logger.info("Generating code solution")
184127
response = client.chat.completions.create(
185128
model=model,
186129
messages=messages,
187130
temperature=0.7
188131
)
189132
initial_response = response.choices[0].message.content
190133
total_tokens = response.usage.completion_tokens
134+
135+
logger.info("Initial response from LM:")
136+
logger.info(initial_response)
191137

192138
code_blocks = extract_code_blocks(initial_response)
193139
if not code_blocks:
194140
logger.warning("No code blocks found in response")
195141
return initial_response, total_tokens
196142

197-
final_state = {}
198-
code = code_blocks[0] # Take first code block
199-
200-
lines = [line.strip() for line in code.split('\n') if line.strip()]
201-
202-
for line in lines:
203-
if not line or line.startswith('#'):
204-
continue
143+
# Execute the complete code block
144+
code = code_blocks[0]
145+
answer, execution_tokens = execute_code(code, client, model)
146+
total_tokens += execution_tokens
147+
148+
# If we got an answer from code execution, use it
149+
if answer is not None:
150+
final_answer = str(answer)
151+
else:
152+
# Fall back to output tags if code execution failed
153+
final_answer = extract_output(initial_response)
205154

206-
tokens, new_state = execute_line(line, final_state, client, model)
207-
if tokens:
208-
total_tokens += tokens
209-
final_state.update(new_state)
210-
logger.debug(f"Executed line: {line}")
211-
logger.debug(f"New state: {new_state}")
212-
213-
final_answer = extract_output(initial_response)
214-
if not final_answer and 'answer' in final_state:
215-
final_answer = str(final_state['answer'])
216-
155+
logger.info(f"Chain of Code execution completed. Final answer: {final_answer}")
217156
return final_answer, total_tokens

0 commit comments

Comments
 (0)