Skip to content

Commit 1d4bad4

Browse files
committed
Update coc_plugin.py
fix coc
1 parent f18b921 commit 1d4bad4

File tree

1 file changed

+68
-13
lines changed

1 file changed

+68
-13
lines changed

optillm/plugins/coc_plugin.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@
33
from typing import Tuple, Dict, Any, List
44
import ast
55
import traceback
6+
import math
7+
import importlib
68

79
logger = logging.getLogger(__name__)
810

911
# Plugin identifier
1012
SLUG = "coc"
1113

14+
# List of allowed modules for execution
15+
ALLOWED_MODULES = {
16+
'math': math,
17+
}
18+
1219
# Prompts
1320
CHAIN_OF_CODE_PROMPT = '''
1421
You are an AI assistant that uses Chain of Code (CoC) approach to solve problems. Follow these steps:
@@ -36,13 +43,18 @@
3643

3744
STATE_SIMULATION_PROMPT = '''You are simulating the execution of Python code.
3845
Given the current program state and a line of code, return ONLY a Python dictionary representing the new state variables.
39-
Do not include any other text, code blocks, or formatting - just the Python dict.
46+
For module imports and references, return string representations.
47+
48+
Return ONLY primitive types (numbers, strings, lists, dicts) - no module references or complex objects.
4049
4150
For example:
4251
state = {'x': 5}
4352
code = "y = x + 3"
44-
You should return:
45-
{'y': 8}
53+
Return: {'y': 8}
54+
55+
For modules:
56+
code = "import math"
57+
Return: {'math': 'module:math'}
4658
'''
4759

4860
def extract_code_blocks(text: str) -> List[str]:
@@ -64,9 +76,8 @@ def extract_state_updates(text: str) -> List[Dict[str, Any]]:
6476
states = []
6577
for m in matches:
6678
try:
67-
# Clean up the state string before evaluation
6879
cleaned = re.sub(r'```python\s*|\s*```', '', m)
69-
state = ast.literal_eval(cleaned)
80+
state = parse_state_dict(cleaned)
7081
states.append(state)
7182
except:
7283
logger.warning(f"Could not parse state update: {m}")
@@ -81,17 +92,67 @@ def clean_state_response(response: str) -> str:
8192
response = re.sub(r'[^}]*$', '', response)
8293
return response.strip()
8394

95+
def parse_state_dict(state_str: str) -> Dict[str, Any]:
96+
"""Safely parse state dictionary, handling module references."""
97+
try:
98+
# First try direct evaluation
99+
return ast.literal_eval(state_str)
100+
except:
101+
# If that fails, try to parse manually
102+
state_dict = {}
103+
try:
104+
# Use a custom safe eval that handles module references
105+
# Remove brackets and split by commas
106+
items = state_str.strip('{}').split(',')
107+
for item in items:
108+
if ':' not in item:
109+
continue
110+
key, value = item.split(':', 1)
111+
key = key.strip().strip("'").strip('"')
112+
value = value.strip()
113+
114+
# Handle module references
115+
if 'module' in value:
116+
module_name = value.split("'")[1] if "'" in value else value.split(':')[1].strip()
117+
if module_name in ALLOWED_MODULES:
118+
state_dict[key] = ALLOWED_MODULES[module_name]
119+
# Handle normal values
120+
else:
121+
try:
122+
state_dict[key] = ast.literal_eval(value)
123+
except:
124+
state_dict[key] = value.strip("'").strip('"')
125+
return state_dict
126+
except Exception as e:
127+
logger.error(f"Failed to parse state dict: {state_str}")
128+
logger.error(f"Error: {str(e)}")
129+
return {}
130+
84131
def execute_line(line: str, state: Dict[str, Any], client, model: str) -> Tuple[Any, Dict[str, Any]]:
85132
"""Execute a single line of code, either with Python or LM simulation."""
86133
try:
134+
# Handle imports specially
135+
if line.startswith('import '):
136+
module_name = line.split()[1]
137+
if module_name in ALLOWED_MODULES:
138+
return None, {module_name: ALLOWED_MODULES[module_name]}
139+
else:
140+
logger.warning(f"Skipping import of unauthorized module: {module_name}")
141+
return None, {}
142+
87143
# Try executing with Python
88-
# Create a copy of state for local execution
89144
local_state = state.copy()
145+
# Add allowed modules to local state
146+
for mod_name, mod in ALLOWED_MODULES.items():
147+
if mod_name in state:
148+
local_state[mod_name] = mod
149+
90150
exec(line, globals(), local_state)
91151
# Extract any new/modified variables
92152
new_state = {k:v for k,v in local_state.items()
93153
if k not in state or state[k] != v}
94154
return None, new_state
155+
95156
except Exception as e:
96157
# If Python execution fails, simulate with LM
97158
context = f"Current program state: {state}\nExecute line: {line}"
@@ -104,9 +165,8 @@ def execute_line(line: str, state: Dict[str, Any], client, model: str) -> Tuple[
104165
temperature=0.2
105166
)
106167
try:
107-
# Clean and parse LM response
108168
cleaned_response = clean_state_response(response.choices[0].message.content)
109-
new_state = ast.literal_eval(cleaned_response)
169+
new_state = parse_state_dict(cleaned_response)
110170
return response.usage.completion_tokens, new_state
111171
except Exception as e:
112172
logger.error(f"Could not parse LM state response: {response.choices[0].message.content}")
@@ -116,7 +176,6 @@ def execute_line(line: str, state: Dict[str, Any], client, model: str) -> Tuple[
116176

117177
def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
118178
"""Main Chain of Code execution function."""
119-
# Generate initial code solution
120179
messages = [
121180
{"role": "system", "content": system_prompt + "\n" + CHAIN_OF_CODE_PROMPT},
122181
{"role": "user", "content": initial_query}
@@ -130,17 +189,14 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
130189
initial_response = response.choices[0].message.content
131190
total_tokens = response.usage.completion_tokens
132191

133-
# Extract code blocks
134192
code_blocks = extract_code_blocks(initial_response)
135193
if not code_blocks:
136194
logger.warning("No code blocks found in response")
137195
return initial_response, total_tokens
138196

139-
# Execute code blocks line by line
140197
final_state = {}
141198
code = code_blocks[0] # Take first code block
142199

143-
# Split into lines and filter empty lines
144200
lines = [line.strip() for line in code.split('\n') if line.strip()]
145201

146202
for line in lines:
@@ -154,7 +210,6 @@ def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str
154210
logger.debug(f"Executed line: {line}")
155211
logger.debug(f"New state: {new_state}")
156212

157-
# Extract output tags from the initial response, or use answer from state
158213
final_answer = extract_output(initial_response)
159214
if not final_answer and 'answer' in final_state:
160215
final_answer = str(final_state['answer'])

0 commit comments

Comments
 (0)