diff --git a/src/codeas/ui/components/refactoring_ui.py b/src/codeas/ui/components/refactoring_ui.py index 6c9258a..ff5f12e 100644 --- a/src/codeas/ui/components/refactoring_ui.py +++ b/src/codeas/ui/components/refactoring_ui.py @@ -13,10 +13,22 @@ generate_proposed_changes, ) +USE_PREVIOUS_OUTPUTS_LABEL = "Use previous outputs" + + +def is_safe_path(base_dir, target_path): + try: + base_dir_realpath = os.path.realpath(base_dir) + target_path_realpath = os.path.realpath(target_path) + + return target_path_realpath.startswith(base_dir_realpath + os.sep) or target_path_realpath == base_dir_realpath + except OSError: + return False + def display(): use_previous_outputs_groups = st.toggle( - "Use previous outputs", value=True, key="use_previous_outputs_groups" + USE_PREVIOUS_OUTPUTS_LABEL, value=True, key="use_previous_outputs_groups" ) if st.button( @@ -55,13 +67,10 @@ def display(): ), "cost": previous_output["cost"], "tokens": previous_output["tokens"], - "messages": previous_output["messages"], # Add this line + "messages": previous_output["messages"], }, ) except FileNotFoundError: - # st.warning( - # "No previous output found for refactoring groups. Running generation..." - # ) st.session_state.outputs[ "refactoring_groups" ] = define_refactoring_files() @@ -76,7 +85,7 @@ def display(): ].tokens, "messages": st.session_state.outputs[ "refactoring_groups" - ].messages, # Add this line + ].messages, }, "refactoring_groups.json", ) @@ -93,7 +102,7 @@ def display(): "tokens": st.session_state.outputs["refactoring_groups"].tokens, "messages": st.session_state.outputs[ "refactoring_groups" - ].messages, # Add this line + ].messages, }, "refactoring_groups.json", ) @@ -119,7 +128,6 @@ def display(): ) groups = output.response.choices[0].message.parsed - # Create a DataFrame for the data editor data = [ { "selected": True, @@ -154,7 +162,7 @@ def display(): def display_generate_proposed_changes(): use_previous_outputs_changes = st.toggle( - "Use previous outputs", value=True, key="use_previous_outputs_changes" + USE_PREVIOUS_OUTPUTS_LABEL, value=True, key="use_previous_outputs_changes" ) groups = ( @@ -203,13 +211,10 @@ def display_generate_proposed_changes(): }, "cost": previous_output["cost"], "tokens": previous_output["tokens"], - "messages": previous_output["messages"], # Add this line + "messages": previous_output["messages"], }, ) except FileNotFoundError: - # st.warning( - # "No previous output found for proposed changes. Running generation..." - # ) st.session_state.outputs[ "proposed_changes" ] = generate_proposed_changes(groups) @@ -229,7 +234,7 @@ def display_generate_proposed_changes(): ].tokens, "messages": st.session_state.outputs[ "proposed_changes" - ].messages, # Add this line + ].messages, }, "proposed_changes.json", ) @@ -249,7 +254,7 @@ def display_generate_proposed_changes(): "tokens": st.session_state.outputs["proposed_changes"].tokens, "messages": st.session_state.outputs[ "proposed_changes" - ].messages, # Add this line + ].messages, }, "proposed_changes.json", ) @@ -285,7 +290,7 @@ def display_generate_proposed_changes(): def display_apply_changes(): use_previous_outputs_diffs = st.toggle( - "Use previous outputs", value=True, key="use_previous_outputs_diffs" + USE_PREVIOUS_OUTPUTS_LABEL, value=True, key="use_previous_outputs_diffs" ) if st.button("Apply changes", type="primary", key="apply_changes"): @@ -296,7 +301,6 @@ def display_apply_changes(): ].response.values() ] with st.spinner("Generating and applying changes..."): - # Generate diffs if use_previous_outputs_diffs: try: previous_output = state.read_output("generated_diffs.json") @@ -311,9 +315,6 @@ def display_apply_changes(): }, ) except FileNotFoundError: - # st.warning( - # "No previous output found for generated diffs. Running generation..." - # ) st.session_state.outputs["generated_diffs"] = generate_diffs( groups_changes ) @@ -348,16 +349,44 @@ def display_apply_changes(): "generated_diffs.json", ) - # Apply diffs generated_diffs_output = st.session_state.outputs["generated_diffs"] + project_root = os.getcwd() + for file_path, response in generated_diffs_output.response.items(): + if not is_safe_path(project_root, file_path): + st.error(f"Error: Original file path '{file_path}' from generated changes is outside project root. Skipping.") + continue + directory, filename = os.path.split(file_path) name, ext = os.path.splitext(filename) new_file_path = os.path.join(directory, f"{name}_refactored{ext}") - with open(file_path, "r") as f: - original_content = f.read() + if not is_safe_path(project_root, new_file_path): + st.error(f"Error: Generated refactored path '{new_file_path}' is outside project root. Skipping.") + continue + + new_file_dir = os.path.dirname(new_file_path) + if not is_safe_path(project_root, new_file_dir): + st.error(f"Error: Directory for refactored file '{new_file_dir}' is outside project root. Skipping.") + continue + if not os.path.exists(new_file_dir): + try: + os.makedirs(new_file_dir, exist_ok=True) + except OSError as e: + st.error(f"Error creating directory '{new_file_dir}': {e}. Skipping.") + continue + + original_content = None + try: + with open(file_path, "r") as f: + original_content = f.read() + except OSError as e: + st.error(f"Error reading original file '{file_path}': {e}. Skipping.") + continue + + if original_content is None: + continue diff = ( f"```diff\n{response['content']}\n```" @@ -368,14 +397,17 @@ def display_apply_changes(): try: patched_content = apply_diffs(original_content, diff) except Exception: - # st.error(f"Error applying diff to {file_path}") + st.error(f"Error applying diff to {file_path}. Skipping.") continue - if not os.path.exists(os.path.dirname(new_file_path)): - os.makedirs(os.path.dirname(new_file_path), exist_ok=True) - with open(new_file_path, "w") as f: - f.write(patched_content) + try: + with open(new_file_path, "w") as f: + f.write(patched_content) + except OSError as e: + st.error(f"Error writing refactored file '{new_file_path}': {e}. Skipping.") + continue + st.success(f"{new_file_path} successfully written!") with st.expander(f"Generated changes [{file_path}]"): - st.code(diff) + st.code(diff) \ No newline at end of file