diff --git a/docs/docs/installation/locally.md b/docs/docs/installation/locally.md
index 9543fe3274..b2717c5ce5 100644
--- a/docs/docs/installation/locally.md
+++ b/docs/docs/installation/locally.md
@@ -1,7 +1,7 @@
To run PR-Agent locally, you first need to acquire two keys:
1. An OpenAI key from [here](https://platform.openai.com/api-keys){:target="_blank"}, with access to GPT-4 and o4-mini (or a key for other [language models](../usage-guide/changing_a_model.md), if you prefer).
-2. A personal access token from your Git platform (GitHub, GitLab, BitBucket,Gitea) with repo scope. GitHub token, for example, can be issued from [here](https://github.com/settings/tokens){:target="_blank"}
+2. A personal access token from your Git platform (GitHub, GitLab, BitBucket, Gitea) with repo scope. GitHub token, for example, can be issued from [here](https://github.com/settings/tokens){:target="_blank"}
## Using Docker image
diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index 9fb9d8add3..81dd4a2aec 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -5,10 +5,20 @@
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt
-from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS, STREAMING_REQUIRED_MODELS
+from pr_agent.algo import (
+ CLAUDE_EXTENDED_THINKING_MODELS,
+ NO_SUPPORT_TEMPERATURE_MODELS,
+ SUPPORT_REASONING_EFFORT_MODELS,
+ USER_MESSAGE_ONLY_MODELS,
+ STREAMING_REQUIRED_MODELS,
+)
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
-from pr_agent.algo.ai_handlers.litellm_helpers import _handle_streaming_response, MockResponse, _get_azure_ad_token, \
- _process_litellm_extra_body
+from pr_agent.algo.ai_handlers.litellm_helpers import (
+ _handle_streaming_response,
+ MockResponse,
+ _get_azure_ad_token,
+ _process_litellm_extra_body,
+)
from pr_agent.algo.utils import ReasoningEffort, get_version
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
@@ -31,6 +41,7 @@ def __init__(self):
"""
self.azure = False
self.api_base = None
+ self.api_key = None
self.repetition_penalty = None
if get_settings().get("LITELLM.DISABLE_AIOHTTP", False):
@@ -38,10 +49,12 @@ def __init__(self):
if get_settings().get("OPENAI.KEY", None):
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
- elif 'OPENAI_API_KEY' not in os.environ:
+ elif "OPENAI_API_KEY" not in os.environ:
litellm.api_key = "dummy_key"
if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
- assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
+ assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, (
+ "AWS credentials are incomplete"
+ )
os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
@@ -76,36 +89,36 @@ def __init__(self):
litellm.api_key = get_settings().xai.key
if get_settings().get("HUGGINGFACE.KEY", None):
litellm.huggingface_key = get_settings().huggingface.key
- if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
+ if get_settings().get("HUGGINGFACE.API_BASE", None) and "huggingface" in get_settings().config.model:
litellm.api_base = get_settings().huggingface.api_base
self.api_base = get_settings().huggingface.api_base
if get_settings().get("OLLAMA.API_BASE", None):
litellm.api_base = get_settings().ollama.api_base
self.api_base = get_settings().ollama.api_base
+ if get_settings().get("OLLAMA.API_KEY", None):
+ self.api_key = get_settings().ollama.api_key
if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
litellm.vertex_project = get_settings().vertexai.vertex_project
- litellm.vertex_location = get_settings().get(
- "VERTEXAI.VERTEX_LOCATION", None
- )
+ litellm.vertex_location = get_settings().get("VERTEXAI.VERTEX_LOCATION", None)
# Google AI Studio
# SEE https://docs.litellm.ai/docs/providers/gemini
if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
- os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key
+ os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key
# Support deepseek models
if get_settings().get("DEEPSEEK.KEY", None):
- os.environ['DEEPSEEK_API_KEY'] = get_settings().get("DEEPSEEK.KEY")
+ os.environ["DEEPSEEK_API_KEY"] = get_settings().get("DEEPSEEK.KEY")
# Support deepinfra models
if get_settings().get("DEEPINFRA.KEY", None):
- os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY")
+ os.environ["DEEPINFRA_API_KEY"] = get_settings().get("DEEPINFRA.KEY")
# Support mistral models
if get_settings().get("MISTRAL.KEY", None):
os.environ["MISTRAL_API_KEY"] = get_settings().get("MISTRAL.KEY")
-
+
# Support codestral models
if get_settings().get("CODESTRAL.KEY", None):
os.environ["CODESTRAL_API_KEY"] = get_settings().get("CODESTRAL.KEY")
@@ -117,7 +130,7 @@ def __init__(self):
access_token = _get_azure_ad_token()
litellm.api_key = access_token
openai.api_key = access_token
-
+
# Set API base from settings
self.api_base = get_settings().azure_ad.api_base
litellm.api_base = self.api_base
@@ -152,14 +165,14 @@ def __init__(self):
def prepare_logs(self, response, system, user, resp, finish_reason):
response_log = response.dict().copy()
- response_log['system'] = system
- response_log['user'] = user
- response_log['output'] = resp
- response_log['finish_reason'] = finish_reason
- if hasattr(self, 'main_pr_language'):
- response_log['main_pr_language'] = self.main_pr_language
+ response_log["system"] = system
+ response_log["user"] = user
+ response_log["output"] = resp
+ response_log["finish_reason"] = finish_reason
+ if hasattr(self, "main_pr_language"):
+ response_log["main_pr_language"] = self.main_pr_language
else:
- response_log['main_pr_language'] = 'unknown'
+ response_log["main_pr_language"] = "unknown"
return response_log
def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict:
@@ -178,18 +191,23 @@ def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict:
# Validate extended thinking parameters
if not isinstance(extended_thinking_budget_tokens, int) or extended_thinking_budget_tokens <= 0:
- raise ValueError(f"extended_thinking_budget_tokens must be a positive integer, got {extended_thinking_budget_tokens}")
+ raise ValueError(
+ f"extended_thinking_budget_tokens must be a positive integer, got {extended_thinking_budget_tokens}"
+ )
if not isinstance(extended_thinking_max_output_tokens, int) or extended_thinking_max_output_tokens <= 0:
- raise ValueError(f"extended_thinking_max_output_tokens must be a positive integer, got {extended_thinking_max_output_tokens}")
+ raise ValueError(
+ f"extended_thinking_max_output_tokens must be a positive integer, got {extended_thinking_max_output_tokens}"
+ )
if extended_thinking_max_output_tokens < extended_thinking_budget_tokens:
- raise ValueError(f"extended_thinking_max_output_tokens ({extended_thinking_max_output_tokens}) must be greater than or equal to extended_thinking_budget_tokens ({extended_thinking_budget_tokens})")
+ raise ValueError(
+ f"extended_thinking_max_output_tokens ({extended_thinking_max_output_tokens}) must be greater than or equal to extended_thinking_budget_tokens ({extended_thinking_budget_tokens})"
+ )
- kwargs["thinking"] = {
- "type": "enabled",
- "budget_tokens": extended_thinking_budget_tokens
- }
+ kwargs["thinking"] = {"type": "enabled", "budget_tokens": extended_thinking_budget_tokens}
if get_settings().config.verbosity_level >= 2:
- get_logger().info(f"Adding max output tokens {extended_thinking_max_output_tokens} to model {model}, extended thinking budget tokens: {extended_thinking_budget_tokens}")
+ get_logger().info(
+ f"Adding max output tokens {extended_thinking_max_output_tokens} to model {model}, extended thinking budget tokens: {extended_thinking_budget_tokens}"
+ )
kwargs["max_tokens"] = extended_thinking_max_output_tokens
# temperature may only be set to 1 when thinking is enabled
@@ -206,10 +224,10 @@ def capture_logs(message):
# Parsing the log message and context
record = message.record
log_entry = {}
- if record.get('extra', None).get('command', None) is not None:
- log_entry.update({"command": record['extra']["command"]})
- if record.get('extra', {}).get('pr_url', None) is not None:
- log_entry.update({"pr_url": record['extra']["pr_url"]})
+ if record.get("extra", None).get("command", None) is not None:
+ log_entry.update({"command": record["extra"]["command"]})
+ if record.get("extra", {}).get("pr_url", None) is not None:
+ log_entry.update({"pr_url": record["extra"]["pr_url"]})
# Append the log entry to the captured_logs list
captured_extra.append(log_entry)
@@ -228,25 +246,29 @@ def capture_logs(message):
metadata = dict()
callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
if "langfuse" in callbacks:
- metadata.update({
- "trace_name": command,
- "tags": [git_provider, command, f'version:{get_version()}'],
- "trace_metadata": {
- "command": command,
- "pr_url": pr_url,
- },
- })
- if "langsmith" in callbacks:
- metadata.update({
- "run_name": command,
- "tags": [git_provider, command, f'version:{get_version()}'],
- "extra": {
- "metadata": {
+ metadata.update(
+ {
+ "trace_name": command,
+ "tags": [git_provider, command, f"version:{get_version()}"],
+ "trace_metadata": {
"command": command,
"pr_url": pr_url,
- }
- },
- })
+ },
+ }
+ )
+ if "langsmith" in callbacks:
+ metadata.update(
+ {
+ "run_name": command,
+ "tags": [git_provider, command, f"version:{get_version()}"],
+ "extra": {
+ "metadata": {
+ "command": command,
+ "pr_url": pr_url,
+ }
+ },
+ }
+ )
# Adding the captured logs to the kwargs
kwargs["metadata"] = metadata
@@ -269,11 +291,12 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
resp, finish_reason = None, None
deployment_id = self.deployment_id
if self.azure:
- model = 'azure/' + model
- if 'claude' in model and not system:
+ model = "azure/" + model
+ if "claude" in model and not system:
system = "No system prompt provided"
get_logger().warning(
- "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
+ "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error."
+ )
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
if img_path:
@@ -287,11 +310,13 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
except Exception as e:
get_logger().error(f"Error fetching image: {img_path}", e)
return f"Error fetching image: {img_path}", "error"
- messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
- {"type": "image_url", "image_url": {"url": img_path}}]
+ messages[1]["content"] = [
+ {"type": "text", "text": messages[1]["content"]},
+ {"type": "image_url", "image_url": {"url": img_path}},
+ ]
thinking_kwargs_gpt5 = None
- if model.startswith('gpt-5'):
+ if model.startswith("gpt-5"):
# Use configured reasoning_effort or default to MEDIUM
config_effort = get_settings().config.reasoning_effort
try:
@@ -310,8 +335,7 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
"allowed_openai_params": ["reasoning_effort"],
}
get_logger().info(f"Using reasoning_effort='{effort}' for GPT-5 model")
- model = 'openai/'+model.replace('_thinking', '') # remove _thinking suffix
-
+ model = "openai/" + model.replace("_thinking", "") # remove _thinking suffix
# Currently, some models do not support a separate system and user prompts
if model in self.user_message_only_models or get_settings().config.custom_reasoning_model:
@@ -342,8 +366,8 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
if thinking_kwargs_gpt5:
kwargs.update(thinking_kwargs_gpt5)
- if 'temperature' in kwargs:
- del kwargs['temperature']
+ if "temperature" in kwargs:
+ del kwargs["temperature"]
# Add reasoning_effort if model supports it
if model in self.support_reasoning_models:
@@ -363,7 +387,9 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
kwargs["reasoning_effort"] = reasoning_effort
# https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
- if (model in self.claude_extended_thinking_models) and get_settings().config.get("enable_claude_extended_thinking", False):
+ if (model in self.claude_extended_thinking_models) and get_settings().config.get(
+ "enable_claude_extended_thinking", False
+ ):
kwargs = self._configure_claude_extended_thinking(model, kwargs)
if get_settings().litellm.get("enable_callbacks", False):
@@ -379,7 +405,7 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
if self.repetition_penalty:
kwargs["repetition_penalty"] = self.repetition_penalty
- #Added support for extra_headers while using litellm to call underlying model, via a api management gateway, would allow for passing custom headers for security and authorization
+ # Added support for extra_headers while using litellm to call underlying model, via a api management gateway, would allow for passing custom headers for security and authorization
if get_settings().get("LITELLM.EXTRA_HEADERS", None):
try:
litellm_extra_headers = json.loads(get_settings().litellm.extra_headers)
@@ -394,7 +420,7 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
# Support for Bedrock custom inference profile via model_id
model_id = get_settings().get("litellm.model_id")
- if model_id and 'bedrock/' in model:
+ if model_id and "bedrock/" in model:
kwargs["model_id"] = model_id
get_logger().info(f"Using Bedrock custom inference profile: {model_id}")
@@ -404,6 +430,9 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
get_logger().info(f"\nSystem prompt:\n{system}")
get_logger().info(f"\nUser prompt:\n{user}")
+ if self.api_key:
+ kwargs["api_key"] = self.api_key
+
# Get completion with automatic streaming detection
resp, finish_reason, response_obj = await self._get_completion(**kwargs)
@@ -446,6 +475,4 @@ async def _get_completion(self, **kwargs):
response = await acompletion(**kwargs)
if response is None or len(response["choices"]) == 0:
raise openai.APIError
- return (response["choices"][0]['message']['content'],
- response["choices"][0]["finish_reason"],
- response)
+ return (response["choices"][0]["message"]["content"], response["choices"][0]["finish_reason"], response)
diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py
index b18826e25c..7cee1bd328 100644
--- a/pr_agent/algo/git_patch_processing.py
+++ b/pr_agent/algo/git_patch_processing.py
@@ -7,6 +7,11 @@
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
+# Optimized: Pre-compile the hunk header regex at the module level to avoid redundant compilation
+# in performance-critical patch processing functions.
+RE_HUNK_HEADER = re.compile(
+ r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
+
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
patch_extra_lines_after=0, filename: str = "", new_file_str="") -> str:
@@ -65,8 +70,6 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
is_valid_hunk = True
start1, size1, start2, size2 = -1, -1, -1, -1
- RE_HUNK_HEADER = re.compile(
- r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
try:
for i,line in enumerate(patch_lines):
if line.startswith('@@'):
@@ -114,7 +117,7 @@ def _calc_context_limits(patch_lines_before):
found_header = True
section_header = ''
else:
- pass # its ok to be here. We cant apply dynamic context if the lines are different if 'old' and 'new' hunks
+ pass # its ok to be here. We can't apply dynamic context if the lines are different if 'old' and 'new' hunks
break
if not found_header:
@@ -233,8 +236,6 @@ def omit_deletion_hunks(patch_lines) -> str:
added_patched = []
add_hunk = False
inside_hunk = False
- RE_HUNK_HEADER = re.compile(
- r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))?\ @@[ ]?(.*)")
for line in patch_lines:
if line.startswith('@@'):
@@ -336,8 +337,6 @@ def decouple_and_convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
patch_with_lines_str = ""
patch_lines = patch.splitlines()
- RE_HUNK_HEADER = re.compile(
- r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
new_content_lines = []
old_content_lines = []
match = None
@@ -412,8 +411,6 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
selected_lines = ""
patch_lines = patch.splitlines()
- RE_HUNK_HEADER = re.compile(
- r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
match = None
start1, size1, start2, size2 = -1, -1, -1, -1
skip_hunk = False
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 541279ec3d..68bc4e8c1a 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -1524,7 +1524,7 @@ def process_description(description_full: str) -> Tuple[str, List]:
pattern_back = r'\s*(.*?)(.*?).*?
\s*
\s*(.*?)\n\n\s*(.*?) '
res = re.search(pattern_back, file_data, re.DOTALL)
if not res or res.lastindex != 4:
- pattern_back = r'\s*(.*?)\s*(.*?).*?
\s*
\s*(.*?)\s*-\s*(.*?)\s* ' # looking for hypen ('- ')
+ pattern_back = r'\s*(.*?)\s*(.*?).*?
\s*
\s*(.*?)\s*-\s*(.*?)\s* ' # looking for hyphen ('- ')
res = re.search(pattern_back, file_data, re.DOTALL)
if res and res.lastindex == 4:
short_filename = res.group(1).strip()
diff --git a/pr_agent/custom_merge_loader.py b/pr_agent/custom_merge_loader.py
index 75b07a7718..abb11e3799 100644
--- a/pr_agent/custom_merge_loader.py
+++ b/pr_agent/custom_merge_loader.py
@@ -23,7 +23,7 @@ def load(obj, env=None, silent=True, key=None, filename=None):
None
"""
- MAX_TOML_SIZE_IN_BYTES = 100 * 1024 * 1024 # Prevent out of mem. exceptions by limiting to 100 MBs which is sufficient for upto 1M lines
+ MAX_TOML_SIZE_IN_BYTES = 100 * 1024 * 1024 # Prevent out of mem. exceptions by limiting to 100 MBs which is sufficient for up to 1M lines
# Get the list of files to load
# TODO: hasattr(obj, 'settings_files') for some reason returns False. Need to use 'settings_file'
diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py
index 95f9d911b5..ab02339109 100644
--- a/pr_agent/servers/github_polling.py
+++ b/pr_agent/servers/github_polling.py
@@ -226,7 +226,7 @@ async def polling_loop():
break
task_queue.clear()
- # Dont wait for all processes to complete. Move on to the next iteration
+ # Don't wait for all processes to complete. Move on to the next iteration
# for p in processes:
# p.join()
diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml
index 70238c2d91..b8a4875976 100644
--- a/pr_agent/settings/.secrets_template.toml
+++ b/pr_agent/settings/.secrets_template.toml
@@ -50,7 +50,8 @@ key = "" # Optional, uncomment if you want to use Huggingface Inference API. Acq
api_base = "" # the base url for your huggingface inference endpoint
[ollama]
-api_base = "" # the base url for your local Llama 2, Code Llama, and other models inference endpoint. Acquire through https://ollama.ai/
+api_base = "" # the base url for your Ollama endpoint, e.g. https://ollama.com for Ollama Cloud or http://localhost:11434 for local
+api_key = "" # required for Ollama Cloud (ollama.com); leave empty for local Ollama
[vertexai]
vertex_project = "" # the google cloud platform project name for your vertexai deployment
@@ -107,7 +108,7 @@ pat = ""
[azure_devops_server]
# For Azure devops Server basic auth - configured in the webhook creation
-# Optional, uncomment if you want to use Azure devops webhooks. Value assinged when you create the webhook
+# Optional, uncomment if you want to use Azure devops webhooks. Value assigned when you create the webhook
# webhook_username = ""
# webhook_password = ""
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index d4e52af2fa..177bde9d1a 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -21,7 +21,7 @@ use_wiki_settings_file=true
use_repo_settings_file=true
use_global_settings_file=true
disable_auto_feedback = false
-ai_timeout=120 # 2minutes
+ai_timeout=120 # 2 minutes
skip_keys = []
custom_reasoning_model = false # when true, disables system messages and temperature controls for models that don't support chat-style inputs
response_language="en-US" # Language locales code for PR responses in ISO 3166 and ISO 639 format (e.g., "en-US", "it-IT", "zh-CN", ...)
@@ -136,7 +136,7 @@ use_conversation_history=true
[pr_code_suggestions] # /improve #
commitable_code_suggestions = false
-dual_publishing_score_threshold=-1 # -1 to disable, [0-10] to set the threshold (>=) for publishing a code suggestion both in a table and as commitable
+dual_publishing_score_threshold=-1 # -1 to disable, [0-10] to set the threshold (>=) for publishing a code suggestion both in a table and as committable
focus_only_on_problems=true
findings_metadata = false
findings_metadata_badges = false
diff --git a/pr_agent/settings/pr_help_docs_headings_prompts.toml b/pr_agent/settings/pr_help_docs_headings_prompts.toml
index da9d6e5334..05bc579116 100644
--- a/pr_agent/settings/pr_help_docs_headings_prompts.toml
+++ b/pr_agent/settings/pr_help_docs_headings_prompts.toml
@@ -1,7 +1,7 @@
[pr_help_docs_headings_prompts]
system="""You are Doc-helper, a language model that ranks documentation files based on their relevance to user questions.
-You will receive a question, a repository url and file names along with optional groups of headings extracted from such files from that repository (either as markdown or as restructred text).
+You will receive a question, a repository url and file names along with optional groups of headings extracted from such files from that repository (either as markdown or as restructured text).
Your task is to rank file paths based on how likely they contain the answer to a user's question, using only the headings from each such file and the file name.
======
diff --git a/pr_agent/settings/pr_help_docs_prompts.toml b/pr_agent/settings/pr_help_docs_prompts.toml
index c73e1d958c..16358a5010 100644
--- a/pr_agent/settings/pr_help_docs_prompts.toml
+++ b/pr_agent/settings/pr_help_docs_prompts.toml
@@ -1,6 +1,6 @@
[pr_help_docs_prompts]
system="""You are Doc-helper, a language model designed to answer questions about a documentation website for a given repository.
-You will receive a question, a repository url and the full documentation content for that repository (either as markdown or as restructred text).
+You will receive a question, a repository url and the full documentation content for that repository (either as markdown or as restructured text).
Your goal is to provide the best answer to the question using the documentation provided.
Additional instructions:
diff --git a/pr_agent/settings/pr_help_prompts.toml b/pr_agent/settings/pr_help_prompts.toml
index 8bd182005a..274940fd2a 100644
--- a/pr_agent/settings/pr_help_prompts.toml
+++ b/pr_agent/settings/pr_help_prompts.toml
@@ -1,5 +1,5 @@
[pr_help_prompts]
-system="""You are Doc-helper, a language models designed to answer questions about a documentation website for an open-soure project called "PR-Agent" (recently renamed to "Qodo Merge").
+system="""You are Doc-helper, a language model designed to answer questions about a documentation website for an open-source project called "PR-Agent" (recently renamed to "Qodo Merge").
You will receive a question, and the full documentation website content.
Your goal is to provide the best answer to the question using the documentation provided.
diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml
index 67f433a5bb..474e74f216 100644
--- a/pr_agent/settings/pr_reviewer_prompts.toml
+++ b/pr_agent/settings/pr_reviewer_prompts.toml
@@ -126,7 +126,7 @@ class Review(BaseModel):
ticket_compliance_check: List[TicketCompliance] = Field(description="A list of compliance checks for the related tickets")
{%- endif %}
{%- if require_estimate_effort_to_review %}
- estimated_effort_to_review_[1-5]: int = Field(description="Estimate, on a scale of 1-5 (inclusive), the time and effort required to review this PR by an experienced and knowledgeable developer. 1 means short and easy review , 5 means long and hard review. Take into account the size, complexity, quality, and the needed changes of the PR code diff.")
+ estimated_effort_to_review_[1-5]: int = Field(description="Estimate, on a scale of 1-5 (inclusive), the time and effort required to review this PR by an experienced and knowledgeable developer. 1 means short and easy review, 5 means long and hard review. Take into account the size, complexity, quality, and the needed changes of the PR code diff.")
{%- endif %}
{%- if require_estimate_contribution_time_cost %}
contribution_time_cost_estimate: ContributionTimeCostEstimate = Field(description="An estimate of the time required to implement the changes, based on the quantity, quality, and complexity of the contribution, as well as the context from the PR description and commit messages.")
@@ -135,20 +135,20 @@ class Review(BaseModel):
score: str = Field(description="Rate this PR on a scale of 0-100 (inclusive), where 0 means the worst possible PR code, and 100 means PR code of the highest quality, without any bugs or performance issues, that is ready to be merged immediately and run in production at scale.")
{%- endif %}
{%- if require_tests %}
- relevant_tests: str = Field(description="yes/no question: does this PR have relevant tests added or updated ?")
+ relevant_tests: str = Field(description="yes/no question: does this PR have relevant tests added or updated?")
{%- endif %}
{%- if question_str %}
insights_from_user_answers: str = Field(description="shortly summarize the insights you gained from the user's answers to the questions")
{%- endif %}
key_issues_to_review: List[KeyIssuesComponentLink] = Field("A concise list (0-{{ num_max_findings }} issues) of bugs, security vulnerabilities, or significant performance concerns introduced in this PR. Only include issues you are confident about. If confidence is limited but the potential impact is high (e.g., data loss, security), you may include it only if you explicitly note what remains uncertain. Each issue must identify a concrete problem with a realistic trigger scenario. An empty list is acceptable if no clear issues are found.")
{%- if require_security_review %}
- security_concerns: str = Field(description="Does this PR code introduce vulnerabilities such as exposure of sensitive information (e.g., API keys, secrets, passwords), or security concerns like SQL injection, XSS, CSRF, and others ? Answer 'No' (without explaining why) if there are no possible issues. If there are security concerns or issues, start your answer with a short header, such as: 'Sensitive information exposure: ...', 'SQL injection: ...', etc. Explain your answer. Be specific and give examples if possible")
+ security_concerns: str = Field(description="Does this PR code introduce vulnerabilities such as exposure of sensitive information (e.g., API keys, secrets, passwords), or security concerns like SQL injection, XSS, CSRF, and others? Answer 'No' (without explaining why) if there are no possible issues. If there are security concerns or issues, start your answer with a short header, such as: 'Sensitive information exposure: ...', 'SQL injection: ...', etc. Explain your answer. Be specific and give examples if possible")
{%- endif %}
{%- if require_todo_scan %}
todo_sections: Union[List[TodoSection], str] = Field(description="A list of TODO comments found in the PR code. Return 'No' (as a string) if there are no TODO comments in the PR")
{%- endif %}
{%- if require_can_be_split_review %}
- can_be_split: List[SubPR] = Field(min_items=0, max_items=3, description="Can this PR, which contains {{ num_pr_files }} changed files in total, be divided into smaller sub-PRs with distinct tasks that can be reviewed and merged independently, regardless of the order ? Make sure that the sub-PRs are indeed independent, with no code dependencies between them, and that each sub-PR represent a meaningful independent task. Output an empty list if the PR code does not need to be split.")
+ can_be_split: List[SubPR] = Field(min_items=0, max_items=3, description="Can this PR, which contains {{ num_pr_files }} changed files in total, be divided into smaller sub-PRs with distinct tasks that can be reviewed and merged independently, regardless of the order? Make sure that the sub-PRs are indeed independent, with no code dependencies between them, and that each sub-PR represents a meaningful independent task. Output an empty list if the PR code does not need to be split.")
{%- endif %}
class PRReview(BaseModel):
diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py
index 36d9d774a6..81ac28399e 100644
--- a/pr_agent/tools/pr_description.py
+++ b/pr_agent/tools/pr_description.py
@@ -10,18 +10,25 @@
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
-from pr_agent.algo.pr_processing import (OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
- get_pr_diff,
- get_pr_diff_multiple_patchs,
- retry_with_fallback_models)
+from pr_agent.algo.pr_processing import (
+ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD,
+ get_pr_diff,
+ get_pr_diff_multiple_patchs,
+ retry_with_fallback_models,
+)
from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import (ModelType, PRDescriptionHeader, clip_tokens,
- get_max_tokens, get_user_labels, load_yaml,
- set_custom_labels,
- show_relevant_configurations)
+from pr_agent.algo.utils import (
+ ModelType,
+ PRDescriptionHeader,
+ clip_tokens,
+ get_max_tokens,
+ get_user_labels,
+ load_yaml,
+ set_custom_labels,
+ show_relevant_configurations,
+)
from pr_agent.config_loader import get_settings
-from pr_agent.git_providers import (GithubProvider, get_git_provider,
- get_git_provider_with_context)
+from pr_agent.git_providers import GithubProvider, get_git_provider, get_git_provider_with_context
from pr_agent.git_providers.git_provider import get_main_pr_language
from pr_agent.log import get_logger
from pr_agent.servers.help import HelpMessage
@@ -29,8 +36,7 @@
class PRDescription:
- def __init__(self, pr_url: str, args: list = None,
- ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
+ def __init__(self, pr_url: str, args: list = None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
@@ -40,14 +46,13 @@ def __init__(self, pr_url: str, args: list = None,
"""
# Initialize the git provider and main PR language
self.git_provider = get_git_provider_with_context(pr_url)
- self.main_pr_language = get_main_pr_language(
- self.git_provider.get_languages(), self.git_provider.get_files()
- )
+ self.main_pr_language = get_main_pr_language(self.git_provider.get_languages(), self.git_provider.get_files())
self.pr_id = self.git_provider.get_pr_id()
self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
- "gfm_markdown"):
+ "gfm_markdown"
+ ):
get_logger().debug(f"Disabling semantic files types for {self.pr_id}, gfm_markdown not supported.")
get_settings().pr_description.enable_semantic_files_types = False
@@ -57,7 +62,9 @@ def __init__(self, pr_url: str, args: list = None,
# Initialize the variables dictionary
self.COLLAPSIBLE_FILE_LIST_THRESHOLD = get_settings().pr_description.get("collapsible_file_list_threshold", 8)
- enable_pr_diagram = get_settings().pr_description.get("enable_pr_diagram", False) and self.git_provider.is_supported("gfm_markdown") # github and gitlab support gfm_markdown
+ enable_pr_diagram = get_settings().pr_description.get(
+ "enable_pr_diagram", False
+ ) and self.git_provider.is_supported("gfm_markdown") # github and gitlab support gfm_markdown
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
@@ -71,7 +78,8 @@ def __init__(self, pr_url: str, args: list = None,
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
"related_tickets": [],
"ticket_compliance_note": "",
- "include_file_summary_changes": len(self.git_provider.get_diff_files()) <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
+ "include_file_summary_changes": len(self.git_provider.get_diff_files())
+ <= self.COLLAPSIBLE_FILE_LIST_THRESHOLD,
"duplicate_prompt_examples": get_settings().config.get("duplicate_prompt_examples", False),
"enable_pr_diagram": enable_pr_diagram,
}
@@ -94,10 +102,12 @@ def __init__(self, pr_url: str, args: list = None,
async def run(self):
try:
get_logger().info(f"Generating a PR description for pr_id: {self.pr_id}")
- relevant_configs = {'pr_description': dict(get_settings().pr_description),
- 'config': dict(get_settings().config)}
+ relevant_configs = {
+ "pr_description": dict(get_settings().pr_description),
+ "config": dict(get_settings().config),
+ }
get_logger().debug("Relevant configs", artifact=relevant_configs)
- if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
+ if get_settings().config.publish_output and not get_settings().config.get("is_auto_command", False):
self.git_provider.publish_comment("Preparing PR description...", is_temporary=True)
# ticket extraction if exists
@@ -125,38 +135,49 @@ async def run(self):
pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer_with_markers()
else:
pr_title, pr_body, changes_walkthrough, pr_file_changes = self._prepare_pr_answer()
- if not self.git_provider.is_supported(
- "publish_file_comments") or not get_settings().pr_description.inline_file_summary:
+ if (
+ not self.git_provider.is_supported("publish_file_comments")
+ or not get_settings().pr_description.inline_file_summary
+ ):
pr_body += "\n\n" + changes_walkthrough + "___\n\n"
get_logger().debug("PR output", artifact={"title": pr_title, "body": pr_body})
# Add help text if gfm_markdown is supported
if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_description.enable_help_text:
- pr_body += "
\n\n ✨ Describe tool usage guide:
\n\n"
+ pr_body += (
+ "
\n\n ✨ Describe tool usage guide:
\n\n"
+ )
pr_body += HelpMessage.get_describe_usage_guide()
pr_body += "\n \n"
elif get_settings().pr_description.enable_help_comment and self.git_provider.is_supported("gfm_markdown"):
if isinstance(self.git_provider, GithubProvider):
- pr_body += ('\n\n___\n\n> Need help?
Type /help how to ... '
- 'in the comments thread for any questions about PR-Agent usage.Check out the '
- 'documentation '
- 'for more information. ')
- else: # gitlab
- pr_body += ("\n\n___\n\nNeed help?
- Type /help how to ... in the comments "
- "thread for any questions about PR-Agent usage.
- Check out the "
- "documentation for more information. ")
+ pr_body += (
+ "\n\n___\n\n> Need help?
Type /help how to ... "
+ "in the comments thread for any questions about PR-Agent usage.Check out the "
+ 'documentation '
+ "for more information. "
+ )
+ else: # gitlab
+ pr_body += (
+ "\n\n___\n\nNeed help?
- Type /help how to ... in the comments "
+ "thread for any questions about PR-Agent usage.
- Check out the "
+ "documentation for more information. "
+ )
# elif get_settings().pr_description.enable_help_comment:
# pr_body += '\n\n___\n\n> 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information'
# Output the relevant configurations if enabled
- if get_settings().get('config', {}).get('output_relevant_configurations', False):
- pr_body += show_relevant_configurations(relevant_section='pr_description')
+ if get_settings().get("config", {}).get("output_relevant_configurations", False):
+ pr_body += show_relevant_configurations(relevant_section="pr_description")
if get_settings().config.publish_output:
-
# publish labels
- if get_settings().pr_description.publish_labels and pr_labels and self.git_provider.is_supported("get_labels"):
+ if (
+ get_settings().pr_description.publish_labels
+ and pr_labels
+ and self.git_provider.is_supported("get_labels")
+ ):
original_labels = self.git_provider.get_pr_labels(update=True)
get_logger().debug(f"original labels", artifact=original_labels)
user_labels = get_user_labels(original_labels)
@@ -172,41 +193,59 @@ async def run(self):
if get_settings().pr_description.publish_description_as_comment:
full_markdown_description = f"## Title\n\n{pr_title.strip()}\n\n___\n{pr_body}"
if get_settings().pr_description.publish_description_as_comment_persistent:
- self.git_provider.publish_persistent_comment(full_markdown_description,
- initial_header="## Title",
- update_header=True,
- name="describe",
- final_update_message=False, )
+ self.git_provider.publish_persistent_comment(
+ full_markdown_description,
+ initial_header="## Title",
+ update_header=True,
+ name="describe",
+ final_update_message=False,
+ )
else:
self.git_provider.publish_comment(full_markdown_description)
else:
self.git_provider.publish_description(pr_title.strip(), pr_body)
# publish final update message
- if (get_settings().pr_description.final_update_message and not get_settings().config.get('is_auto_command', False)):
+ if get_settings().pr_description.final_update_message and not get_settings().config.get(
+ "is_auto_command", False
+ ):
latest_commit_url = self.git_provider.get_latest_commit_url()
if latest_commit_url:
pr_url = self.git_provider.get_pr_url()
- update_comment = f"**[PR Description]({pr_url})** updated to latest commit ({latest_commit_url})"
+ update_comment = (
+ f"**[PR Description]({pr_url})** updated to latest commit ({latest_commit_url})"
+ )
self.git_provider.publish_comment(update_comment)
self.git_provider.remove_initial_comment()
else:
- get_logger().info('PR description, but not published since publish_output is False.')
+ get_logger().info("PR description, but not published since publish_output is False.")
get_settings().data = {"artifact": pr_body}
return
except Exception as e:
- get_logger().error(f"Error generating PR description {self.pr_id}: {e}",
- artifact={"traceback": traceback.format_exc()})
+ get_logger().error(
+ f"Error generating PR description {self.pr_id}: {e}", artifact={"traceback": traceback.format_exc()}
+ )
return ""
async def _prepare_prediction(self, model: str) -> None:
- if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
- get_logger().info("Markers were enabled, but user description does not contain markers. Skipping AI prediction")
+ if get_settings().pr_description.use_description_markers and "pr_agent:" not in self.user_description:
+ get_logger().info(
+ "Markers were enabled, but user description does not contain markers. Skipping AI prediction"
+ )
return None
- large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
- output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
+ large_pr_handling = (
+ get_settings().pr_description.enable_large_pr_handling
+ and "pr_description_only_files_prompts" in get_settings()
+ )
+ output = get_pr_diff(
+ self.git_provider,
+ self.token_handler,
+ model,
+ large_pr_handling=large_pr_handling,
+ return_remaining_files=True,
+ )
if isinstance(output, tuple):
patches_diff, remaining_files_list = output
else:
@@ -224,25 +263,32 @@ async def _prepare_prediction(self, model: str) -> None:
if get_settings().pr_description.enable_semantic_files_types:
self.prediction = await self.extend_uncovered_files(self.prediction)
else:
- get_logger().error(f"Error getting PR diff {self.pr_id}",
- artifact={
- "large_pr_handling": large_pr_handling,
- "patches_diff_type": type(patches_diff).__name__,
- "remaining_files": len(remaining_files_list),
- })
+ get_logger().error(
+ f"Error getting PR diff {self.pr_id}",
+ artifact={
+ "large_pr_handling": large_pr_handling,
+ "patches_diff_type": type(patches_diff).__name__,
+ "remaining_files": len(remaining_files_list),
+ },
+ )
self.prediction = None
else:
# get the diff in multiple patches, with the token handler only for the files prompt
- get_logger().debug('large_pr_handling for describe')
+ get_logger().debug("large_pr_handling for describe")
token_handler_only_files_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_files_prompts.system,
get_settings().pr_description_only_files_prompts.user,
)
- (patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
- files_in_patches_list) = get_pr_diff_multiple_patchs(
- self.git_provider, token_handler_only_files_prompt, model)
+ (
+ patches_compressed_list,
+ total_tokens_list,
+ deleted_files_list,
+ remaining_files_list,
+ file_dict,
+ files_in_patches_list,
+ ) = get_pr_diff_multiple_patchs(self.git_provider, token_handler_only_files_prompt, model)
# get the files prediction for each patch
if not get_settings().pr_description.async_ai_calls:
@@ -250,8 +296,9 @@ async def _prepare_prediction(self, model: str) -> None:
for i, patches in enumerate(patches_compressed_list): # sync calls
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
- prediction_files = await self._get_prediction(model, patches_diff,
- prompt="pr_description_only_files_prompts")
+ prediction_files = await self._get_prediction(
+ model, patches_diff, prompt="pr_description_only_files_prompts"
+ )
results.append(prediction_files)
else: # async calls
tasks = []
@@ -260,15 +307,16 @@ async def _prepare_prediction(self, model: str) -> None:
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
task = asyncio.create_task(
- self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
+ self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts")
+ )
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
file_description_str_list = []
for i, result in enumerate(results):
- prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
- if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
- prediction_files = prediction_files.removeprefix('pr_files:').strip()
+ prediction_files = result.strip().removeprefix("```yaml").strip("`").strip()
+ if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith("pr_files"):
+ prediction_files = prediction_files.removeprefix("pr_files:").strip()
file_description_str_list.append(prediction_files)
else:
get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
@@ -278,7 +326,8 @@ async def _prepare_prediction(self, model: str) -> None:
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_description_prompts.system,
- get_settings().pr_description_only_description_prompts.user)
+ get_settings().pr_description_only_description_prompts.user,
+ )
files_walkthrough = "\n".join(file_description_str_list)
files_walkthrough_prompt = copy.deepcopy(files_walkthrough)
MAX_EXTRA_FILES_TO_PROMPT = 50
@@ -288,7 +337,9 @@ async def _prepare_prediction(self, model: str) -> None:
files_walkthrough_prompt += f"\n- {file}"
if i >= MAX_EXTRA_FILES_TO_PROMPT:
get_logger().debug(f"Too many remaining files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
- files_walkthrough_prompt += f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
+ files_walkthrough_prompt += (
+ f"\n... and {len(remaining_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
+ )
break
if deleted_files_list:
files_walkthrough_prompt += "\n\nAdditional deleted files:"
@@ -296,23 +347,31 @@ async def _prepare_prediction(self, model: str) -> None:
files_walkthrough_prompt += f"\n- {file}"
if i >= MAX_EXTRA_FILES_TO_PROMPT:
get_logger().debug(f"Too many deleted files, clipping to {MAX_EXTRA_FILES_TO_PROMPT}")
- files_walkthrough_prompt += f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
+ files_walkthrough_prompt += (
+ f"\n... and {len(deleted_files_list) - MAX_EXTRA_FILES_TO_PROMPT} more"
+ )
break
tokens_files_walkthrough = len(
- token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt))
+ token_handler_only_description_prompt.encoder.encode(files_walkthrough_prompt)
+ )
total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
max_tokens_model = get_max_tokens(model)
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
# clip files_walkthrough to git the tokens within the limit
- files_walkthrough_prompt = clip_tokens(files_walkthrough_prompt,
- max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
- num_input_tokens=tokens_files_walkthrough)
+ files_walkthrough_prompt = clip_tokens(
+ files_walkthrough_prompt,
+ max_tokens_model
+ - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
+ - token_handler_only_description_prompt.prompt_tokens,
+ num_input_tokens=tokens_files_walkthrough,
+ )
# PR header inference
get_logger().debug(f"PR diff only description", artifact=files_walkthrough_prompt)
- prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough_prompt,
- prompt="pr_description_only_description_prompts")
- prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
+ prediction_headers = await self._get_prediction(
+ model, patches_diff=files_walkthrough_prompt, prompt="pr_description_only_description_prompts"
+ )
+ prediction_headers = prediction_headers.strip().removeprefix("```yaml").strip("`").strip()
# extend the tables with the files not shown
files_walkthrough_extended = await self.extend_uncovered_files(files_walkthrough)
@@ -336,8 +395,8 @@ async def extend_uncovered_files(self, original_prediction: str) -> str:
else:
original_prediction_dict = original_prediction_loaded
if original_prediction_dict:
- files = original_prediction_dict.get('pr_files', [])
- filenames_predicted = [file.get('filename', '').strip() for file in files if isinstance(file, dict)]
+ files = original_prediction_dict.get("pr_files", [])
+ filenames_predicted = [file.get("filename", "").strip() for file in files if isinstance(file, dict)]
else:
filenames_predicted = []
@@ -379,8 +438,12 @@ async def extend_uncovered_files(self, original_prediction: str) -> str:
if counter_extra_files > 0:
get_logger().info(f"Adding {counter_extra_files} unprocessed extra files to table prediction")
prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
- if original_prediction_dict and isinstance(original_prediction_dict, dict) and \
- isinstance(prediction_extra_dict, dict) and "pr_files" in prediction_extra_dict:
+ if (
+ original_prediction_dict
+ and isinstance(original_prediction_dict, dict)
+ and isinstance(prediction_extra_dict, dict)
+ and "pr_files" in prediction_extra_dict
+ ):
if "pr_files" in original_prediction_dict:
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
else:
@@ -396,7 +459,6 @@ async def extend_uncovered_files(self, original_prediction: str) -> str:
get_logger().exception(f"Error extending uncovered files {self.pr_id}", artifact={"error": e})
return original_prediction
-
async def extend_additional_files(self, remaining_files_list) -> str:
prediction = self.prediction
try:
@@ -438,10 +500,7 @@ async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_descri
user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(self.variables)
response, finish_reason = await self.ai_handler.chat_completion(
- model=model,
- temperature=get_settings().config.temperature,
- system=system_prompt,
- user=user_prompt
+ model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt
)
return response
@@ -454,55 +513,41 @@ def _prepare_data(self):
self.data["User Description"] = self.user_description
# re-order keys
- if 'User Description' in self.data:
- self.data['User Description'] = self.data.pop('User Description')
- if 'title' in self.data:
- self.data['title'] = self.data.pop('title')
- if 'type' in self.data:
- self.data['type'] = self.data.pop('type')
- if 'labels' in self.data:
- self.data['labels'] = self.data.pop('labels')
- if 'description' in self.data:
- self.data['description'] = self.data.pop('description')
- if 'changes_diagram' in self.data:
- changes_diagram = self._sanitize_mermaid_diagram(self.data.pop('changes_diagram').strip())
- if changes_diagram.startswith('```'):
- if not changes_diagram.endswith('```'): # fallback for missing closing
- changes_diagram += '\n```'
- self.data['changes_diagram'] = '\n'+ changes_diagram
- if 'pr_files' in self.data:
- self.data['pr_files'] = self.data.pop('pr_files')
+ if "User Description" in self.data:
+ self.data["User Description"] = self.data.pop("User Description")
+ if "title" in self.data:
+ self.data["title"] = self.data.pop("title")
+ if "type" in self.data:
+ self.data["type"] = self.data.pop("type")
+ if "labels" in self.data:
+ self.data["labels"] = self.data.pop("labels")
+ if "description" in self.data:
+ self.data["description"] = self.data.pop("description")
+ if "changes_diagram" in self.data:
+ sanitized = sanitize_diagram(self.data.pop("changes_diagram"))
+ if sanitized:
+ self.data["changes_diagram"] = sanitized
+ if "pr_files" in self.data:
+ self.data["pr_files"] = self.data.pop("pr_files")
@staticmethod
def _sanitize_mermaid_diagram(diagram: str) -> str:
- if not diagram:
- return diagram
-
- diagram_lines = diagram.splitlines()
- if diagram_lines and diagram_lines[0].startswith("```"):
- fence_end = diagram_lines[-1] if diagram_lines[-1].startswith("```") else None
- body_lines = diagram_lines[1:-1] if fence_end else diagram_lines[1:]
- body_lines = [line.replace("`", "") for line in body_lines]
- if fence_end:
- return "\n".join([diagram_lines[0]] + body_lines + [fence_end]).strip()
- return "\n".join([diagram_lines[0]] + body_lines).strip()
-
- return diagram.replace("`", "")
+ return sanitize_diagram(diagram).lstrip("\n")
def _prepare_labels(self) -> List[str]:
pr_labels = []
# If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
- if 'labels' in self.data and self.data['labels']:
- if type(self.data['labels']) == list:
- pr_labels = self.data['labels']
- elif type(self.data['labels']) == str:
- pr_labels = self.data['labels'].split(',')
- elif 'type' in self.data and self.data['type'] and get_settings().pr_description.publish_labels:
- if type(self.data['type']) == list:
- pr_labels = self.data['type']
- elif type(self.data['type']) == str:
- pr_labels = self.data['type'].split(',')
+ if "labels" in self.data and self.data["labels"]:
+ if type(self.data["labels"]) == list:
+ pr_labels = self.data["labels"]
+ elif type(self.data["labels"]) == str:
+ pr_labels = self.data["labels"].split(",")
+ elif "type" in self.data and self.data["type"] and get_settings().pr_description.publish_labels:
+ if type(self.data["type"]) == list:
+ pr_labels = self.data["type"]
+ elif type(self.data["type"]) == str:
+ pr_labels = self.data["type"].split(",")
pr_labels = [label.strip() for label in pr_labels]
# convert lowercase labels to original case
@@ -520,8 +565,8 @@ def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
get_logger().info(f"Using description marker replacements {self.pr_id}")
# Remove the 'PR Title' key from the dictionary
- ai_title = self.data.pop('title', self.vars["title"])
- if (not get_settings().pr_description.generate_ai_title):
+ ai_title = self.data.pop("title", self.vars["title"])
+ if not get_settings().pr_description.generate_ai_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
@@ -534,36 +579,37 @@ def _prepare_pr_answer_with_markers(self) -> Tuple[str, str, str, List[dict]]:
else:
ai_header = ""
- ai_type = self.data.get('type')
- if ai_type and not re.search(r'', body):
+ ai_type = self.data.get("type")
+ if ai_type and not re.search(r"", body):
if isinstance(ai_type, list):
- pr_type = ', '.join(str(t) for t in ai_type)
+ pr_type = ", ".join(str(t) for t in ai_type)
else:
pr_type = ai_type
pr_type = f"{ai_header}{pr_type}"
- body = body.replace('pr_agent:type', pr_type)
+ body = body.replace("pr_agent:type", pr_type)
- ai_summary = self.data.get('description')
- if ai_summary and not re.search(r'', body):
+ ai_summary = self.data.get("description")
+ if ai_summary and not re.search(r"", body):
summary = f"{ai_header}{ai_summary}"
- body = body.replace('pr_agent:summary', summary)
+ body = body.replace("pr_agent:summary", summary)
- ai_walkthrough = self.data.get('pr_files')
+ ai_walkthrough = self.data.get("pr_files")
walkthrough_gfm = ""
pr_file_changes = []
- if ai_walkthrough and not re.search(r'', body):
+ if ai_walkthrough and not re.search(r"", body):
try:
- walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(walkthrough_gfm,
- self.file_label_dict)
- body = body.replace('pr_agent:walkthrough', walkthrough_gfm)
+ walkthrough_gfm, pr_file_changes = self.process_pr_files_prediction(
+ walkthrough_gfm, self.file_label_dict
+ )
+ body = body.replace("pr_agent:walkthrough", walkthrough_gfm)
except Exception as e:
get_logger().error(f"Failing to process walkthrough {self.pr_id}: {e}")
- body = body.replace('pr_agent:walkthrough', "")
+ body = body.replace("pr_agent:walkthrough", "")
# Add support for pr_agent:diagram marker (plain and HTML comment formats)
- ai_diagram = self.data.get('changes_diagram')
+ ai_diagram = self.data.get("changes_diagram")
if ai_diagram:
- body = re.sub(r'|pr_agent:diagram', ai_diagram, body)
+ body = re.sub(r"|pr_agent:diagram", ai_diagram, body)
return title, body, walkthrough_gfm, pr_file_changes
@@ -578,14 +624,14 @@ def _prepare_pr_answer(self) -> Tuple[str, str, str, List[dict]]:
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
# Don't display 'PR Labels'
- if 'labels' in self.data and self.git_provider.is_supported("get_labels"):
- self.data.pop('labels')
+ if "labels" in self.data and self.git_provider.is_supported("get_labels"):
+ self.data.pop("labels")
if not get_settings().pr_description.enable_pr_type:
- self.data.pop('type')
+ self.data.pop("type")
# Remove the 'PR Title' key from the dictionary
- ai_title = self.data.pop('title', self.vars["title"])
- if (not get_settings().pr_description.generate_ai_title):
+ ai_title = self.data.pop("title", self.vars["title"])
+ if not get_settings().pr_description.generate_ai_title:
# Assign the original PR title to the 'title' variable
title = self.vars["title"]
else:
@@ -597,78 +643,93 @@ def _prepare_pr_answer(self) -> Tuple[str, str, str, List[dict]]:
pr_body, changes_walkthrough = "", ""
pr_file_changes = []
for idx, (key, value) in enumerate(self.data.items()):
- if key == 'changes_diagram':
+ if key == "changes_diagram":
pr_body += f"### {PRDescriptionHeader.DIAGRAM_WALKTHROUGH.value}\n\n"
pr_body += f"{value}\n\n"
continue
- if key == 'pr_files':
+ if key == "pr_files":
value = self.file_label_dict
else:
- key_publish = key.rstrip(':').replace("_", " ").capitalize()
+ key_publish = key.rstrip(":").replace("_", " ").capitalize()
if key_publish == "Type":
key_publish = "PR Type"
# elif key_publish == "Description":
# key_publish = "PR Description"
pr_body += f"### **{key_publish}**\n"
- if 'walkthrough' in key.lower():
+ if "walkthrough" in key.lower():
if self.git_provider.is_supported("gfm_markdown"):
pr_body += " files:
\n\n"
for file in value:
- filename = file['filename'].replace("'", "`")
- description = file['changes_in_file']
- pr_body += f'- `{filename}`: {description}\n'
+ filename = file["filename"].replace("'", "`")
+ description = file["changes_in_file"]
+ pr_body += f"- `{filename}`: {description}\n"
if self.git_provider.is_supported("gfm_markdown"):
pr_body += " \n"
- elif 'pr_files' in key.lower() and get_settings().pr_description.enable_semantic_files_types: # 'File Walkthrough' section
- changes_walkthrough_table, pr_file_changes = self.process_pr_files_prediction(changes_walkthrough, value)
- if get_settings().pr_description.get('file_table_collapsible_open_by_default', False):
+ elif (
+ "pr_files" in key.lower() and get_settings().pr_description.enable_semantic_files_types
+ ): # 'File Walkthrough' section
+ changes_walkthrough_table, pr_file_changes = self.process_pr_files_prediction(
+ changes_walkthrough, value
+ )
+ if get_settings().pr_description.get("file_table_collapsible_open_by_default", False):
initial_status = " open"
else:
initial_status = ""
changes_walkthrough = f" {PRDescriptionHeader.FILE_WALKTHROUGH.value}
\n\n"
changes_walkthrough += f"{changes_walkthrough_table}\n\n"
changes_walkthrough += " \n\n"
- elif key.lower().strip() == 'description':
+ elif key.lower().strip() == "description":
if isinstance(value, list):
- value = ', '.join(v.rstrip() for v in value)
- value = value.replace('\n-', '\n\n-').strip() # makes the bullet points more readable by adding double space
+ value = ", ".join(v.rstrip() for v in value)
+ value = value.replace(
+ "\n-", "\n\n-"
+ ).strip() # makes the bullet points more readable by adding double space
pr_body += f"{value}\n"
else:
# if the value is a list, join its items by comma
if isinstance(value, list):
- value = ', '.join(v.rstrip() for v in value)
+ value = ", ".join(v.rstrip() for v in value)
pr_body += f"{value}\n"
if idx < len(self.data) - 1:
pr_body += "\n\n___\n\n"
- return title, pr_body, changes_walkthrough, pr_file_changes,
+ return (
+ title,
+ pr_body,
+ changes_walkthrough,
+ pr_file_changes,
+ )
def _prepare_file_labels(self):
file_label_dict = {}
- if (not self.data or not isinstance(self.data, dict) or
- 'pr_files' not in self.data or not self.data['pr_files']):
+ if not self.data or not isinstance(self.data, dict) or "pr_files" not in self.data or not self.data["pr_files"]:
return file_label_dict
- for file in self.data['pr_files']:
+ for file in self.data["pr_files"]:
try:
- required_fields = ['changes_title', 'filename', 'label']
+ required_fields = ["changes_title", "filename", "label"]
if not all(field in file for field in required_fields):
# can happen for example if a YAML generation was interrupted in the middle (no more tokens)
- get_logger().warning(f"Missing required fields in file label dict {self.pr_id}, skipping file",
- artifact={"file": file})
+ get_logger().warning(
+ f"Missing required fields in file label dict {self.pr_id}, skipping file",
+ artifact={"file": file},
+ )
continue
- if not file.get('changes_title'):
- get_logger().warning(f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
- artifact={"file": file})
+ if not file.get("changes_title"):
+ get_logger().warning(
+ f"Empty changes title or summary in file label dict {self.pr_id}, skipping file",
+ artifact={"file": file},
+ )
continue
- filename = file['filename'].replace("'", "`").replace('"', '`')
- changes_summary = file.get('changes_summary', "")
- if not changes_summary and self.vars.get('include_file_summary_changes', True):
- get_logger().warning(f"Empty changes summary in file label dict, skipping file",
- artifact={"file": file})
+ filename = file["filename"].replace("'", "`").replace('"', "`")
+ changes_summary = file.get("changes_summary", "")
+ if not changes_summary and self.vars.get("include_file_summary_changes", True):
+ get_logger().warning(
+ f"Empty changes summary in file label dict, skipping file", artifact={"file": file}
+ )
continue
changes_summary = changes_summary.strip()
- changes_title = file['changes_title'].strip()
- label = file.get('label').strip().lower()
+ changes_title = file["changes_title"].strip()
+ label = file.get("label").strip().lower()
if label not in file_label_dict:
file_label_dict[label] = []
file_label_dict[label].append((filename, changes_title, changes_summary))
@@ -711,7 +772,9 @@ def process_pr_files_prediction(self, pr_body, value):
filename_publish = filename.split("/")[-1]
if file_changes_title and file_changes_title.strip() != "...":
file_changes_title_code = f"{file_changes_title}"
- file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
+ file_changes_title_code_br = insert_br_after_x_chars(
+ file_changes_title_code, x=(delta - 5)
+ ).strip()
if len(file_changes_title_code_br) < (delta - 5):
file_changes_title_code_br += " " * ((delta - 5) - len(file_changes_title_code_br))
filename_publish = f"{filename_publish}{file_changes_title_code_br}"
@@ -721,7 +784,7 @@ def process_pr_files_prediction(self, pr_body, value):
delta_nbsp = ""
diff_files = self.git_provider.get_diff_files()
for f in diff_files:
- if f.filename.lower().strip('/') == filename.lower().strip('/'):
+ if f.filename.lower().strip("/") == filename.lower().strip("/"):
num_plus_lines = f.num_plus_lines
num_minus_lines = f.num_minus_lines
diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"
@@ -732,18 +795,25 @@ def process_pr_files_prediction(self, pr_body, value):
# try to add line numbers link to code suggestions
link = ""
- if hasattr(self.git_provider, 'get_line_link'):
+ if hasattr(self.git_provider, "get_line_link"):
filename = filename.strip()
link = self.git_provider.get_line_link(filename, relevant_line_start=-1)
- if (not link or not diff_plus_minus) and ('additional files' not in filename.lower()):
+ if (not link or not diff_plus_minus) and ("additional files" not in filename.lower()):
# get_logger().warning(f"Error getting line link for '{filename}'")
link = ""
# continue
# Add file data to the PR body
file_change_description_br = insert_br_after_x_chars(file_change_description, x=(delta - 5))
- pr_body = self.add_file_data(delta_nbsp, diff_plus_minus, file_change_description_br, filename,
- filename_publish, link, pr_body)
+ pr_body = self.add_file_data(
+ delta_nbsp,
+ diff_plus_minus,
+ file_change_description_br,
+ filename,
+ filename_publish,
+ link,
+ pr_body,
+ )
# Close the collapsible file list
if use_collapsible_file_list:
@@ -757,9 +827,9 @@ def process_pr_files_prediction(self, pr_body, value):
pass
return pr_body, pr_comments
- def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link,
- pr_body) -> str:
-
+ def add_file_data(
+ self, delta_nbsp, diff_plus_minus, file_change_description_br, filename, filename_publish, link, pr_body
+ ) -> str:
if not file_change_description_br:
pr_body += f"""
@@ -791,10 +861,35 @@ def add_file_data(self, delta_nbsp, diff_plus_minus, file_change_description_br,
"""
return pr_body
+
+def sanitize_diagram(diagram_raw: str) -> str:
+ """Sanitize a diagram string: fix missing closing fence and remove backticks."""
+ if not isinstance(diagram_raw, str):
+ return ""
+ diagram = diagram_raw.strip()
+ if not diagram.startswith("```mermaid"):
+ return ""
+
+ # fallback missing closing
+ if not diagram.endswith("```"):
+ diagram += "\n```"
+
+ # remove backticks inside node labels: ["`label`"] -> ["label"]
+ result = []
+ for line in diagram.split("\n"):
+ line = re.sub(
+ r'\["([^"]*?)"\]',
+ lambda m: '["' + m.group(1).replace("`", "") + '"]',
+ line,
+ )
+ result.append(line)
+ return "\n" + "\n".join(result)
+
+
def count_chars_without_html(string):
- if '<' not in string:
+ if "<" not in string:
return len(string)
- no_html_string = re.sub('<[^>]+>', '', string)
+ no_html_string = re.sub("<[^>]+>", "", string)
return len(no_html_string)
@@ -817,22 +912,22 @@ def insert_br_after_x_chars(text: str, x=70):
# convert list items to only if the text is identified as a list
if is_list:
# To handle lists that start with indentation
- leading_whitespace = text[:len(text) - len(text.lstrip())]
+ leading_whitespace = text[: len(text) - len(text.lstrip())]
body = text.lstrip()
body = "" + body[2:]
text = leading_whitespace + body
- text = text.replace("\n- ", '
').replace("\n - ", '
')
- text = text.replace("\n* ", '
').replace("\n * ", '
')
+ text = text.replace("\n- ", "
").replace("\n - ", "
")
+ text = text.replace("\n* ", "
").replace("\n * ", "
")
# convert new lines to
- text = text.replace("\n", '
')
+ text = text.replace("\n", "
")
# split text into lines
- lines = text.split('
')
+ lines = text.split("
")
words = []
for i, line in enumerate(lines):
- words += line.split(' ')
+ words += line.split(" ")
if i < len(lines) - 1:
words[-1] += "
"
@@ -864,7 +959,7 @@ def insert_br_after_x_chars(text: str, x=70):
if "" in word:
is_inside_code = False
- processed_text = ''.join(new_text).strip()
+ processed_text = "".join(new_text).strip()
if is_list:
processed_text = f""
@@ -876,7 +971,7 @@ def replace_code_tags(text):
"""
Replace odd instances of ` with and even instances of ` with
"""
- parts = text.split('`')
+ parts = text.split("`")
for i in range(1, len(parts), 2):
- parts[i] = '' + parts[i] + ''
- return ''.join(parts)
+ parts[i] = "" + parts[i] + ""
+ return "".join(parts)
diff --git a/tests/unittest/test_litellm_reasoning_effort.py b/tests/unittest/test_litellm_reasoning_effort.py
index 30e7813c8f..ad7b3d1b46 100644
--- a/tests/unittest/test_litellm_reasoning_effort.py
+++ b/tests/unittest/test_litellm_reasoning_effort.py
@@ -8,20 +8,26 @@
def create_mock_settings(reasoning_effort_value):
"""Create a fake settings object with configurable reasoning_effort."""
- return type('', (), {
- 'config': type('', (), {
- 'reasoning_effort': reasoning_effort_value,
- 'ai_timeout': 120,
- 'custom_reasoning_model': False,
- 'max_model_tokens': 32000,
- 'verbosity_level': 0,
- 'get': lambda self, key, default=None: default
- })(),
- 'litellm': type('', (), {
- 'get': lambda self, key, default=None: default
- })(),
- 'get': lambda self, key, default=None: default
- })()
+ return type(
+ "",
+ (),
+ {
+ "config": type(
+ "",
+ (),
+ {
+ "reasoning_effort": reasoning_effort_value,
+ "ai_timeout": 120,
+ "custom_reasoning_model": False,
+ "max_model_tokens": 32000,
+ "verbosity_level": 0,
+ "get": lambda self, key, default=None: default,
+ },
+ )(),
+ "litellm": type("", (), {"get": lambda self, key, default=None: default})(),
+ "get": lambda self, key, default=None: default,
+ },
+ )()
def create_mock_acompletion_response():
@@ -37,7 +43,7 @@ def create_mock_acompletion_response():
@pytest.fixture
def mock_logger():
"""Mock logger to capture info and warning calls."""
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.get_logger') as mock_log:
+ with patch("pr_agent.algo.ai_handlers.litellm_ai_handler.get_logger") as mock_log:
mock_log_instance = MagicMock()
mock_log.return_value = mock_log_instance
yield mock_log_instance
@@ -66,15 +72,13 @@ async def test_gpt5_valid_reasoning_effort_none(self, monkeypatch, mock_logger):
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
# Mock acompletion to capture kwargs
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Verify the call was made with correct reasoning_effort
assert mock_completion.called
@@ -91,15 +95,13 @@ async def test_gpt5_valid_reasoning_effort_low(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings("low")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "low"
@@ -112,15 +114,13 @@ async def test_gpt5_valid_reasoning_effort_medium(self, monkeypatch, mock_logger
fake_settings = create_mock_settings("medium")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "medium"
@@ -133,15 +133,13 @@ async def test_gpt5_valid_reasoning_effort_high(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings("high")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "high"
@@ -154,15 +152,13 @@ async def test_gpt5_valid_reasoning_effort_xhigh(self, monkeypatch, mock_logger)
fake_settings = create_mock_settings("xhigh")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5.2",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5.2", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "xhigh"
@@ -175,15 +171,13 @@ async def test_gpt5_valid_reasoning_effort_minimal(self, monkeypatch, mock_logge
fake_settings = create_mock_settings("minimal")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "minimal"
@@ -198,15 +192,13 @@ async def test_gpt5_invalid_reasoning_effort_with_warning(self, monkeypatch, moc
fake_settings = create_mock_settings("extreme")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Should default to 'medium'
call_kwargs = mock_completion.call_args[1]
@@ -227,15 +219,13 @@ async def test_gpt5_invalid_reasoning_effort_thinking_model(self, monkeypatch, m
fake_settings = create_mock_settings("invalid_value")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07_thinking",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07_thinking", system="test system", user="test user")
# Should default to 'medium' (no special handling for _thinking models)
call_kwargs = mock_completion.call_args[1]
@@ -253,15 +243,13 @@ async def test_gpt5_none_config_defaults_to_medium(self, monkeypatch, mock_logge
fake_settings = create_mock_settings(None)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Should default to 'medium'
call_kwargs = mock_completion.call_args[1]
@@ -279,15 +267,13 @@ async def test_gpt5_none_config_thinking_model_defaults_to_medium(self, monkeypa
fake_settings = create_mock_settings(None)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07_thinking",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07_thinking", system="test system", user="test user")
# Should default to 'medium' (no special handling for _thinking models)
call_kwargs = mock_completion.call_args[1]
@@ -322,15 +308,13 @@ async def test_gpt5_model_detection_various_versions(self, monkeypatch, mock_log
]
for model in gpt5_models:
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model=model,
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model=model, system="test system", user="test user")
# All should trigger GPT-5 logic
call_kwargs = mock_completion.call_args[1]
@@ -346,15 +330,13 @@ async def test_non_gpt5_model_no_thinking_kwargs(self, monkeypatch, mock_logger)
non_gpt5_models = ["gpt-4o", "gpt-4-turbo", "claude-3-5-sonnet"]
for model in non_gpt5_models:
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model=model,
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model=model, system="test system", user="test user")
# Should not have reasoning_effort in kwargs
call_kwargs = mock_completion.call_args[1]
@@ -366,15 +348,13 @@ async def test_gpt5_suffix_removal(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings("low")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5_thinking",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5_thinking", system="test system", user="test user")
# Model should be transformed to openai/gpt-5
call_kwargs = mock_completion.call_args[1]
@@ -388,15 +368,13 @@ async def test_gpt5_thinking_suffix_default_medium(self, monkeypatch, mock_logge
fake_settings = create_mock_settings(None)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07_thinking",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07_thinking", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "medium"
@@ -408,15 +386,13 @@ async def test_gpt5_regular_suffix_default_medium(self, monkeypatch, mock_logger
fake_settings = create_mock_settings(None)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "medium"
@@ -428,15 +404,13 @@ async def test_gpt5_thinking_suffix_config_overrides_default(self, monkeypatch,
fake_settings = create_mock_settings("high")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07_thinking",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07_thinking", system="test system", user="test user")
# Should use 'high' from config, not 'medium' default
call_kwargs = mock_completion.call_args[1]
@@ -451,15 +425,13 @@ async def test_gpt5_info_logging_configured_value(self, monkeypatch, mock_logger
fake_settings = create_mock_settings("low")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Verify log
mock_logger.info.assert_any_call("Using reasoning_effort='low' for GPT-5 model")
@@ -470,15 +442,13 @@ async def test_gpt5_info_logging_default_value(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings(None)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Verify log
mock_logger.info.assert_any_call("Using reasoning_effort='medium' for GPT-5 model")
@@ -490,15 +460,13 @@ async def test_gpt5_warning_only_for_invalid_non_none(self, monkeypatch, mock_lo
fake_settings = create_mock_settings(None)
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# No warning for None
mock_logger.warning.assert_not_called()
@@ -510,15 +478,13 @@ async def test_gpt5_warning_only_for_invalid_non_none(self, monkeypatch, mock_lo
fake_settings = create_mock_settings("ultra")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Warning should be logged for invalid value
mock_logger.warning.assert_called_once()
@@ -531,15 +497,13 @@ async def test_thinking_kwargs_gpt5_structure(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings("medium")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
@@ -556,21 +520,21 @@ async def test_thinking_kwargs_not_created_for_non_gpt5(self, monkeypatch, mock_
fake_settings = create_mock_settings("high")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-4o",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-4o", system="test system", user="test user")
call_kwargs = mock_completion.call_args[1]
# Should not have reasoning_effort keys
assert "reasoning_effort" not in call_kwargs
- assert call_kwargs.get("allowed_openai_params") is None or "reasoning_effort" not in call_kwargs.get("allowed_openai_params", [])
+ assert call_kwargs.get("allowed_openai_params") is None or "reasoning_effort" not in call_kwargs.get(
+ "allowed_openai_params", []
+ )
# ========== Group 7: Edge Cases ==========
@@ -580,15 +544,13 @@ async def test_empty_string_reasoning_effort(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings("")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Should default to 'medium' and log warning
call_kwargs = mock_completion.call_args[1]
@@ -601,15 +563,13 @@ async def test_case_sensitive_reasoning_effort(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings("LOW")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Should treat uppercase as invalid and default to 'medium'
call_kwargs = mock_completion.call_args[1]
@@ -622,15 +582,13 @@ async def test_whitespace_reasoning_effort(self, monkeypatch, mock_logger):
fake_settings = create_mock_settings(" low ")
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5-2025-08-07",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5-2025-08-07", system="test system", user="test user")
# Should treat value with whitespace as invalid
call_kwargs = mock_completion.call_args[1]
@@ -649,15 +607,13 @@ async def test_gpt5_prefix_match_only(self, monkeypatch, mock_logger):
monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
# Test gpt-50 (will match due to startswith logic)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-50",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-50", system="test system", user="test user")
# Due to startswith('gpt-5'), gpt-50 will match and have reasoning_effort
call_kwargs = mock_completion.call_args[1]
@@ -667,16 +623,54 @@ async def test_gpt5_prefix_match_only(self, monkeypatch, mock_logger):
mock_logger.reset_mock()
# Test gpt-5 (should match)
- with patch('pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion', new_callable=AsyncMock) as mock_completion:
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
mock_completion.return_value = create_mock_acompletion_response()
handler = LiteLLMAIHandler()
- await handler.chat_completion(
- model="gpt-5",
- system="test system",
- user="test user"
- )
+ await handler.chat_completion(model="gpt-5", system="test system", user="test user")
# Should have reasoning_effort
call_kwargs = mock_completion.call_args[1]
assert call_kwargs["reasoning_effort"] == "medium"
+
+
+class TestLiteLLMApiKeyForwarding:
+ @pytest.mark.asyncio
+ async def test_openai_requests_do_not_forward_stale_global_api_key(self, monkeypatch, mock_logger):
+ fake_settings = create_mock_settings(None)
+ monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+ monkeypatch.setenv("OPENAI_API_KEY", "openai-test-key")
+ monkeypatch.setattr(litellm_handler.litellm, "api_key", "stale-ollama-key")
+
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
+ mock_completion.return_value = create_mock_acompletion_response()
+
+ handler = LiteLLMAIHandler()
+ await handler.chat_completion(model="gpt-4o", system="test system", user="test user")
+
+ call_kwargs = mock_completion.call_args[1]
+ assert "api_key" not in call_kwargs
+
+ @pytest.mark.asyncio
+ async def test_ollama_requests_forward_instance_api_key(self, monkeypatch, mock_logger):
+ fake_settings = create_mock_settings(None)
+ fake_settings.get = lambda key, default=None: {"OLLAMA.API_KEY": "ollama-cloud-key"}.get(key, default)
+ fake_settings.ollama = type("", (), {"api_key": "ollama-cloud-key", "api_base": None})()
+ monkeypatch.setattr(litellm_handler, "get_settings", lambda: fake_settings)
+ monkeypatch.setenv("OPENAI_API_KEY", "openai-test-key")
+ monkeypatch.setattr(litellm_handler.litellm, "api_key", None)
+
+ with patch(
+ "pr_agent.algo.ai_handlers.litellm_ai_handler.acompletion", new_callable=AsyncMock
+ ) as mock_completion:
+ mock_completion.return_value = create_mock_acompletion_response()
+
+ handler = LiteLLMAIHandler()
+ await handler.chat_completion(model="ollama/llama3.2", system="test system", user="test user")
+
+ call_kwargs = mock_completion.call_args[1]
+ assert call_kwargs["api_key"] == "ollama-cloud-key"
diff --git a/tests/unittest/test_pr_description.py b/tests/unittest/test_pr_description.py
new file mode 100644
index 0000000000..c99f547e7d
--- /dev/null
+++ b/tests/unittest/test_pr_description.py
@@ -0,0 +1,79 @@
+from unittest.mock import MagicMock, patch
+
+import yaml
+
+from pr_agent.tools.pr_description import PRDescription, sanitize_diagram
+
+KEYS_FIX = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
+
+def _make_instance(prediction_yaml: str):
+ """Create a PRDescription instance, bypassing __init__."""
+ with patch.object(PRDescription, '__init__', lambda self, *a, **kw: None):
+ obj = PRDescription.__new__(PRDescription)
+ obj.prediction = prediction_yaml
+ obj.keys_fix = KEYS_FIX
+ obj.user_description = ""
+ return obj
+
+
+def _mock_settings():
+ """Mock get_settings used by _prepare_data."""
+ settings = MagicMock()
+ settings.pr_description.add_original_user_description = False
+ return settings
+
+
+def _prediction_with_diagram(diagram_value: str) -> str:
+ """Build a minimal YAML prediction string that includes changes_diagram."""
+ return yaml.dump({
+ 'title': 'test',
+ 'description': 'test',
+ 'changes_diagram': diagram_value,
+ })
+
+
+class TestPRDescriptionDiagram:
+
+ @patch('pr_agent.tools.pr_description.get_settings')
+ def test_diagram_not_starting_with_fence_is_removed(self, mock_get_settings):
+ mock_get_settings.return_value = _mock_settings()
+ obj = _make_instance(_prediction_with_diagram('graph LR\nA --> B'))
+ obj._prepare_data()
+ assert 'changes_diagram' not in obj.data
+
+ @patch('pr_agent.tools.pr_description.get_settings')
+ def test_diagram_missing_closing_fence_is_appended(self, mock_get_settings):
+ mock_get_settings.return_value = _mock_settings()
+ obj = _make_instance(_prediction_with_diagram('```mermaid\ngraph LR\nA --> B'))
+ obj._prepare_data()
+ assert obj.data['changes_diagram'] == '\n```mermaid\ngraph LR\nA --> B\n```'
+
+ @patch('pr_agent.tools.pr_description.get_settings')
+ def test_backticks_inside_label_are_removed(self, mock_get_settings):
+ mock_get_settings.return_value = _mock_settings()
+ obj = _make_instance(_prediction_with_diagram('```mermaid\ngraph LR\nA["`file`"] --> B\n```'))
+ obj._prepare_data()
+ assert obj.data['changes_diagram'] == '\n```mermaid\ngraph LR\nA["file"] --> B\n```'
+
+ @patch('pr_agent.tools.pr_description.get_settings')
+ def test_backticks_outside_label_are_kept(self, mock_get_settings):
+ mock_get_settings.return_value = _mock_settings()
+ obj = _make_instance(_prediction_with_diagram('```mermaid\ngraph LR\nA["`file`"] -->|`edge`| B\n```'))
+ obj._prepare_data()
+ assert obj.data['changes_diagram'] == '\n```mermaid\ngraph LR\nA["file"] -->|`edge`| B\n```'
+
+ @patch('pr_agent.tools.pr_description.get_settings')
+ def test_normal_diagram_only_adds_newline(self, mock_get_settings):
+ mock_get_settings.return_value = _mock_settings()
+ obj = _make_instance(_prediction_with_diagram('```mermaid\ngraph LR\nA["file.py"] --> B["output"]\n```'))
+ obj._prepare_data()
+ assert obj.data['changes_diagram'] == '\n```mermaid\ngraph LR\nA["file.py"] --> B["output"]\n```'
+
+ def test_none_input_returns_empty(self):
+ assert sanitize_diagram(None) == ''
+
+ def test_non_string_input_returns_empty(self):
+ assert sanitize_diagram(123) == ''
+
+ def test_non_mermaid_fence_returns_empty(self):
+ assert sanitize_diagram('```python\nprint("hello")\n```') == ''