diff --git "a/src/codeas/ui/pages/5_\360\237\222\254_Chat.py" "b/src/codeas/ui/pages/5_\360\237\222\254_Chat.py" index 566cfe8..68c99a8 100644 --- "a/src/codeas/ui/pages/5_\360\237\222\254_Chat.py" +++ "b/src/codeas/ui/pages/5_\360\237\222\254_Chat.py" @@ -12,6 +12,10 @@ from codeas.ui.utils import read_prompts +ALL_FILES = "All files" +FULL_CONTENT = "Full content" + + def chat(): st.subheader("💬 Chat") state.update_current_page("Chat") @@ -51,8 +55,8 @@ def display_config_section(): retriever = ContextRetriever(**get_retriever_args()) if ( - st.session_state.get("file_types") != "All files" - or st.session_state.get("content_types") != "Full content" + st.session_state.get("file_types") != ALL_FILES + or st.session_state.get("content_types") != FULL_CONTENT ): files_missing_metadata = metadata_ui.display() if not any(files_missing_metadata): @@ -76,17 +80,13 @@ def display_config_section(): st.caption(f"{num_selected_files:,} files | {selected_tokens:,} tokens") repo_ui.display_files_editor() - if not any(files_missing_metadata): - if st.button("Show context"): - context = retriever.retrieve( - files_paths=state.repo.included_files_paths, - files_tokens=state.repo.included_files_tokens, - metadata=state.repo_metadata, - ) - st.text_area("Context", context, height=300) - - if not any(files_missing_metadata): - st.caption(f"{num_selected_files:,} files | {selected_tokens:,} tokens") + if not any(files_missing_metadata) and st.button("Show context"): + context = retriever.retrieve( + files_paths=state.repo.included_files_paths, + files_tokens=state.repo.included_files_tokens, + metadata=state.repo_metadata, + ) + st.text_area("Context", context, height=300) def display_file_options(): @@ -95,7 +95,7 @@ def display_file_options(): st.selectbox( "File types", options=[ - "All files", + ALL_FILES, "Code files", "Testing files", "Config files", @@ -109,7 +109,7 @@ def display_file_options(): with col2: st.selectbox( "Content types", - options=["Full content", "Descriptions", "Details"], + options=[FULL_CONTENT, "Descriptions", "Details"], key="content_types", ) @@ -166,14 +166,13 @@ def display_chat_history(): icon="🤖", ): if entry.get("content") is None: - with st.spinner("Running agent..."): - content, cost = run_agent(entry["model"]) - st.write(f"💰 ${cost['total_cost']:.4f}") - st.session_state.chat_history[i]["content"] = content - st.session_state.chat_history[i]["cost"] = cost - else: + with st.spinner("Running agent..."): + pass + + if entry.get("content") is not None: st.write(entry["content"]) - st.write(f"💰 ${entry['cost']['total_cost']:.4f}") + if entry.get("cost"): + st.write(f"💰 ${entry['cost']['total_cost']:.4f}") def display_user_input(): @@ -200,30 +199,6 @@ def display_template_options(): index=0 if st.session_state.input_reset else None, ) - # remaining_options = [ - # opt for opt in prompt_options if opt != st.session_state.template1 - # ] - # with col2: - # st.selectbox( - # "Template 2", - # options=remaining_options, - # key="template2", - # index=0 if st.session_state.input_reset else None, - # disabled=not st.session_state.template1, - # ) - - # final_options = [ - # opt for opt in remaining_options if opt != st.session_state.template2 - # ] - # with col3: - # st.selectbox( - # "Template 3", - # options=final_options, - # key="template3", - # index=0 if st.session_state.input_reset else None, - # disabled=not st.session_state.template2, - # ) - def display_input_areas(): prompts = read_prompts() @@ -233,26 +208,34 @@ def display_input_areas(): if st.session_state.get(f"template{i}") ] - if len(selected_templates) > 1: - for i, template in enumerate(selected_templates, 1): + valid_selected_templates = [t for t in selected_templates if t] + + if len(valid_selected_templates) > 1: + for i, template in enumerate(valid_selected_templates, 1): instruction_key = f"instructions{i}" if st.session_state.input_reset: - st.session_state[instruction_key] = "" + st.session_state[instruction_key] = "" prompt_content = prompts.get(template, "") with st.expander(f"Template {i}: {template}", expanded=True): st.text_area( "Instructions", - value=prompt_content, + value=st.session_state.get(instruction_key, prompt_content), height=200, key=instruction_key, ) else: + instruction_key = "instructions" if st.session_state.input_reset: - st.session_state.instructions = "" - template = selected_templates[0] if selected_templates else "" + st.session_state[instruction_key] = "" + + template = valid_selected_templates[0] if valid_selected_templates else "" prompt_content = prompts.get(template, "") + st.text_area( - "Instructions", value=prompt_content, key="instructions", height=200 + "Instructions", + value=st.session_state.get(instruction_key, prompt_content), + key=instruction_key, + height=200, ) @@ -264,8 +247,7 @@ def initialize_input_reset(): def reset_input_flag(): # used to empty user input and templates after user sends a message - if st.session_state.input_reset: - st.session_state.input_reset = False + pass def display_action_buttons(): @@ -276,94 +258,100 @@ def display_action_buttons(): handle_preview_button() +def _get_user_message_data(): + """Helper to extract user inputs and associated templates based on UI keys.""" + user_message_data = [] + selected_templates_by_slot = [st.session_state.get(f"template{i}") for i in range(1, 4)] + valid_selected_templates = [t for t in selected_templates_by_slot if t] + + if len(valid_selected_templates) > 1: + for i in range(1, 4): + template = st.session_state.get(f"template{i}") + if template: + instruction_key = f"instructions{i}" + user_input = st.session_state.get(instruction_key, "").strip() + if user_input: + user_message_data.append((user_input, template)) + else: + instruction_key = "instructions" + user_input = st.session_state.get(instruction_key, "").strip() + template = valid_selected_templates[0] if valid_selected_templates else "" + if user_input: + user_message_data.append((user_input, template)) + + return user_message_data + + def handle_send_button(): - selected_templates = [ - st.session_state.get(f"template{i}") - for i in range(1, 4) - if st.session_state.get(f"template{i}") - ] + user_message_data = _get_user_message_data() - if len(selected_templates) > 1: - user_inputs = [ - st.session_state.get(f"instructions{i}").strip() - for i in range(1, len(selected_templates) + 1) - ] - else: - user_inputs = [st.session_state.instructions.strip()] + if user_message_data: + for user_input, template in user_message_data: + st.session_state.chat_history.append( + {"role": "user", "content": user_input, "template": template} + ) - if any(user_inputs): - for i, user_input in enumerate(user_inputs): - if user_input: - template = selected_templates[i] if len(selected_templates) > 1 else "" - st.session_state.chat_history.append( - {"role": "user", "content": user_input, "template": template} - ) - for i, user_input in enumerate(user_inputs): - for model in get_selected_models(): + selected_models = get_selected_models() + using_multiple_models = len(selected_models) > 1 + + for user_input, template in user_message_data: + for model in selected_models: st.session_state.chat_history.append( { "role": "assistant", "model": model, "template": template, - "multiple_models": len(get_selected_models()) > 1, + "multiple_models": using_multiple_models, + "content": None, + "cost": None, } ) + st.session_state.input_reset = True st.rerun() def handle_preview_button(): - selected_templates = [ - st.session_state.get(f"template{i}") - for i in range(1, 4) - if st.session_state.get(f"template{i}") - ] + user_message_data = _get_user_message_data() - if len(selected_templates) > 1: - user_inputs = [ - st.session_state.get(f"instructions{i}").strip() - for i in range(1, len(selected_templates) + 1) - ] - else: - user_inputs = [st.session_state.instructions.strip()] - - for i, user_input in enumerate(user_inputs): - if user_input: - template_label = ( - f"[{selected_templates[i]}]" if len(selected_templates) > 1 else "" - ) - for model in get_selected_models(): - with st.expander( - f"🤖 PREVIEW [{model}] {template_label}", expanded=True - ): - with st.spinner("Previewing..."): - messages = get_history_messages(model) - messages.append({"role": "user", "content": user_input}) - st.json(messages, expanded=False) - llm_client = LLMClients(model=model) - cost = llm_client.calculate_cost(messages) - st.write( - f"💰 ${cost['input_cost']:.4f} [input] ({cost['input_tokens']:,} tokens) " - ) + for user_input, template in user_message_data: + template_label = f"[{template}]" if template else "" + for model in get_selected_models(): + with st.expander( + f"🤖 PREVIEW [{model}] {template_label}", expanded=True + ): + with st.spinner("Previewing..."): + messages = get_history_messages(model) + messages.append({"role": "user", "content": user_input}) + st.json(messages, expanded=False) + llm_client = LLMClients(model=model) + cost = llm_client.calculate_cost(messages) + st.write( + f"💰 ${cost['input_cost']:.4f} [input] ({cost['input_tokens']:,} tokens) " + ) def run_agent(model): llm_client = LLMClients(model=model) messages = get_history_messages(model) - if model == "claude-3-5-sonnet" or model == "claude-3-haiku": - if ( - tokencost.count_string_tokens(llm_client.extract_strings(messages), model) - > 10000 - ): - st.warning( - "Anthropic API is limited to 80k tokens per minute. Using it with large context may result in errors." - ) + + is_anthropic_claude = (model == "claude-3-5-sonnet" or model == "claude-3-haiku") + has_large_context = tokencost.count_string_tokens( + llm_client.extract_strings(messages), model + ) > 10000 + + if is_anthropic_claude and has_large_context: + st.warning( + "Anthropic API is limited to 80k tokens per minute. Using it with large context may result in errors." + ) + if model == "o1-preview" or model == "o1-mini": st.caption("Streaming is not supported for o1 models.") completion = llm_client.run(messages) st.markdown(completion) else: completion = st.write_stream(llm_client.stream(messages)) + cost = llm_client.calculate_cost(messages, completion) log_agent_execution(model, messages, cost) return completion, cost @@ -377,20 +365,22 @@ def get_history_messages(model): metadata=state.repo_metadata, ) messages = [{"role": "user", "content": context}] + for entry in st.session_state.chat_history: if entry["role"] == "user": messages.append({"role": entry["role"], "content": entry["content"]}) elif entry["role"] == "assistant" and entry.get("content") is not None: - if entry.get("multiple_models") is False or entry.get("model") == model: + if entry.get("multiple_models", False) is False or entry.get("model") == model: messages.append({"role": entry["role"], "content": entry["content"]}) + return messages def get_retriever_args(): - file_types = st.session_state.get("file_types", "All files") - content_types = st.session_state.get("content_types", "Full content") + file_types = st.session_state.get("file_types", ALL_FILES) + content_types = st.session_state.get("content_types", FULL_CONTENT) return { - "include_all_files": file_types == "All files", + "include_all_files": file_types == ALL_FILES, "include_code_files": file_types == "Code files", "include_testing_files": file_types == "Testing files", "include_config_files": file_types == "Config files", @@ -404,25 +394,40 @@ def get_retriever_args(): def log_agent_execution(model, messages, cost): - # Get or create a conversation ID if "conversation_id" not in st.session_state: st.session_state.conversation_id = str(uuid.uuid4()) - # Get the content of the last message - prompt = messages[-1]["content"] if messages else "" - # Get template information - selected_templates = [ - st.session_state.get(f"template{i}") - for i in range(1, 4) - if st.session_state.get(f"template{i}") - ] - using_template = any(selected_templates) - using_multiple_templates = len(selected_templates) > 1 - # Check if multiple models are being used + + raw_user_prompt_content = messages[-1]["content"] if messages else "" + + corresponding_entry = None + for entry in reversed(st.session_state.chat_history): + if ( + entry.get("role") == "assistant" + and entry.get("model") == model + and entry.get("content") is not None + and entry.get("cost") == cost + ): + corresponding_entry = entry + break + + used_template = corresponding_entry.get("template") if corresponding_entry else "" + using_template = bool(used_template) + + prompt_summary_for_logging = f"User input length: {len(raw_user_prompt_content)} characters" + if using_template: + prompt_summary_for_logging += f", Template: [{used_template}]" + else: + prompt_summary_for_logging += ", Template: [None]" + + + all_selected_templates_in_ui = [t for i in range(1, 4) if (t := st.session_state.get(f"template{i}"))] + using_multiple_templates = len(all_selected_templates_in_ui) > 1 + using_multiple_models = len(get_selected_models()) > 1 - # Log the agent execution using the UsageTracker + usage_tracker.log_agent_execution( model=model, - prompt=prompt, + prompt=prompt_summary_for_logging, cost=cost, conversation_id=st.session_state.conversation_id, using_template=using_template, @@ -431,4 +436,4 @@ def log_agent_execution(model, messages, cost): ) -chat() +chat() \ No newline at end of file