feat: add AI module for LLM interaction and a heuristic for checking code–docstring consistency #1121
Open
AmineRaouane wants to merge 5 commits into oracle:main from AmineRaouane:matching-docstring
Changes from all commits (5):

- `b9c1921` feat: add AI module for LLM interaction and a heuristic for checking…
- `65e54a1` feat(ai): improve robustness of AI client
- `6da5458` feat: add Inconsistent Description heuristic
- `d0a0d65` refactor: move threshold configuration to defaults.ini
- `2df3401` chore(tests): improve test coverage and apply minor heuristic changes
**src/macaron/ai/README.md** (new file):

````markdown
# Macaron AI Module

This module provides the foundation for interacting with Large Language Models (LLMs) in a provider-agnostic way. It includes an abstract client definition, provider-specific client implementations, a client factory, and utility functions for processing responses.

## Module Components

- **clients/base.py**
  Defines the abstract [`AIClient`](./clients/base.py) class. This class stores the system prompt and LLM configuration and serves as the base for all concrete AI client implementations.

- **clients/openai_client.py**
  Implements the [`OpenAiClient`](./clients/openai_client.py) class, a concrete subclass of [`AIClient`](./clients/base.py). This client interacts with OpenAI-like APIs by sending HTTP requests and processing the responses.

- **clients/ai_factory.py**
  Contains the [`AIClientFactory`](./clients/ai_factory.py) class, which reads the provider configuration from the defaults and creates the matching AI client instance.

- **ai_tools.py**
  Offers utility functions such as `extract_json` to assist with parsing and validating the JSON response returned by an LLM.

## Usage

1. **Configuration:**
   The module reads the LLM configuration from the application defaults (using the `defaults` module). Make sure that the `llm` section in your configuration includes valid settings for `enabled`, `provider`, `api_key`, `api_endpoint`, and `model`.

2. **Creating a client:**
   Use the [`AIClientFactory`](./clients/ai_factory.py) to create an AI client instance. The factory checks the configured provider and returns a client (e.g., an instance of [`OpenAiClient`](./clients/openai_client.py)) that can be used to invoke the LLM.

   Example:
   ```py
   from macaron.ai.clients.ai_factory import AIClientFactory

   factory = AIClientFactory()
   client = factory.create_client(system_prompt="You are a helpful assistant.")
   response = client.invoke("Hello, how can you assist me?")
   print(response)
   ```

3. **Response processing:**
   When a structured response is required, pass a JSON schema via the `response_format` argument of `invoke`. The [`ai_tools.py`](./ai_tools.py) module takes care of extracting and parsing the JSON object from the response.

## Logging and Error Handling

- The module uses Python's logging framework to report important events, such as token usage and warnings when prompts exceed the allowed context window.
- Configuration errors (e.g., a missing API key or endpoint) are handled by raising descriptive exceptions, such as [`ConfigurationError`](../errors.py).

## Extensibility

The design of the AI module is provider-agnostic. To add support for an additional LLM provider:

- Implement a new client by subclassing [`AIClient`](./clients/base.py).
- Add the new client to [`PROVIDER_MAPPING`](./clients/__init__.py).
- Update the configuration defaults accordingly.
````
New file (license header only):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
```
**src/macaron/ai/ai_tools.py** (new file):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides utility functions for Large Language Models (LLMs)."""
import json
import logging
import re
from typing import Any

logger: logging.Logger = logging.getLogger(__name__)


def extract_json(response_text: str) -> Any:
    """
    Parse the response from the LLM.

    If raw JSON parsing fails, attempt to extract a JSON object from the text.

    Parameters
    ----------
    response_text: str
        The response text from the LLM.

    Returns
    -------
    dict[str, Any] | None
        The structured JSON object, or None if no valid JSON could be extracted.
    """
    try:
        data = json.loads(response_text)
    except json.JSONDecodeError:
        logger.debug("Full JSON parse failed; trying to extract JSON from text.")
        # If the response is not valid JSON, try to extract a JSON object from the text.
        match = re.search(r"\{.*\}", response_text, re.DOTALL)
        if not match:
            return None
        try:
            data = json.loads(match.group(0))
        except json.JSONDecodeError as e:
            logger.debug("Failed to parse extracted JSON: %s", e)
            return None

    return data
```
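A quick sketch of how `extract_json` behaves on typical LLM output. The function body is copied from the file above so the snippet runs stand-alone; the sample strings are invented for illustration.

```python
import json
import re


def extract_json(response_text):
    # Copy of extract_json above (logging omitted) for a self-contained demo.
    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", response_text, re.DOTALL)
        if not match:
            return None
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return None


# Plain JSON parses directly.
print(extract_json('{"consistent": true}'))  # {'consistent': True}

# JSON wrapped in chatty text is recovered by the regex fallback.
print(extract_json('Sure! Here is the result: {"score": 0.9} Hope that helps.'))  # {'score': 0.9}

# No JSON object at all yields None rather than an exception.
print(extract_json("no json here"))  # None
```

Note that the greedy `\{.*\}` pattern grabs everything from the first `{` to the last `}`, so a response containing two separate JSON objects would fail the fallback parse and also return None.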
**src/macaron/ai/clients/__init__.py** (new file):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides a mapping of AI client providers to their respective client classes."""

from macaron.ai.clients.base import AIClient
from macaron.ai.clients.openai_client import OpenAiClient

PROVIDER_MAPPING: dict[str, type[AIClient]] = {"openai": OpenAiClient}
```
**src/macaron/ai/clients/ai_factory.py** (new file):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module defines the AIClientFactory class for creating AI clients based on provider configuration."""

import logging

from macaron.ai.clients import PROVIDER_MAPPING
from macaron.ai.clients.base import AIClient
from macaron.config.defaults import defaults
from macaron.errors import ConfigurationError

logger: logging.Logger = logging.getLogger(__name__)


class AIClientFactory:
    """Factory to create AI clients based on provider configuration."""

    def __init__(self) -> None:
        """
        Initialize the factory.

        The LLM configuration is read from defaults.
        """
        self.params = self._load_defaults()

    def _load_defaults(self) -> dict | None:
        section_name = "llm"
        default_values = {
            "enabled": False,
            "provider": "",
            "api_key": "",
            "api_endpoint": "",
            "model": "",
        }

        if defaults.has_section(section_name):
            section = defaults[section_name]
            default_values["enabled"] = section.getboolean("enabled", default_values["enabled"])
            for key, default_value in default_values.items():
                if isinstance(default_value, str):
                    value = str(section.get(key, default_value)).strip()
                    # Only the provider name is case-insensitive; lowercasing the
                    # API key or endpoint would corrupt them.
                    default_values[key] = value.lower() if key == "provider" else value

        if default_values["enabled"]:
            for key, value in default_values.items():
                if not value:
                    raise ConfigurationError(
                        f"AI client configuration '{key}' is required but not set in the defaults."
                    )

        return default_values

    def create_client(self, system_prompt: str) -> AIClient | None:
        """Create an AI client based on the configured provider."""
        if not self.params or not self.params["enabled"]:
            return None

        client_class = PROVIDER_MAPPING.get(self.params["provider"])
        if client_class is None:
            logger.error("Provider '%s' is not supported.", self.params["provider"])
            return None
        return client_class(system_prompt, self.params)
```
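The validation rule in `_load_defaults` can be illustrated stand-alone: when the `llm` section is enabled, every remaining key must be non-empty. `validate_llm_config` below is a hypothetical simplification that raises `ValueError` in place of macaron's `ConfigurationError` and skips the configparser plumbing.

```python
# Hypothetical stand-alone sketch of the validation step in
# AIClientFactory._load_defaults (not the real macaron API).
def validate_llm_config(values: dict) -> dict:
    if values.get("enabled"):
        # Every key, including "provider", must be set once the client is enabled.
        for key, value in values.items():
            if not value:
                raise ValueError(f"AI client configuration '{key}' is required but not set.")
    return values


# A disabled config passes through untouched, even with empty keys.
print(validate_llm_config({"enabled": False, "provider": "", "api_key": ""}))

# An enabled config with a missing key raises.
try:
    validate_llm_config({"enabled": True, "provider": "openai", "api_key": ""})
except ValueError as err:
    print(err)  # AI client configuration 'api_key' is required but not set.
```

This fail-fast check means a half-configured `[llm]` section surfaces at factory construction time rather than as an opaque HTTP error on the first `invoke` call.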
**src/macaron/ai/clients/base.py** (new file):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module defines the abstract AIClient class for implementing AI clients."""

from abc import ABC, abstractmethod


class AIClient(ABC):
    """This abstract class is used to implement AI clients."""

    def __init__(self, system_prompt: str, params: dict) -> None:
        """
        Initialize the AI client.

        Parameters
        ----------
        system_prompt: str
            The system prompt to send with every request.
        params: dict
            The LLM configuration, as loaded from the defaults.
        """
        self.system_prompt = system_prompt
        self.params = params

    @abstractmethod
    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        response_format: dict | None = None,
    ) -> dict:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The temperature for the LLM response.
        response_format: dict | None
            The JSON schema to validate the response against.

        Returns
        -------
        dict
            The parsed response from the LLM.
        """
```
**src/macaron/ai/clients/openai_client.py** (new file):

```python
# Copyright (c) 2024 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module provides a client for interacting with a Large Language Model (LLM) behind an OpenAI-like API."""

import logging
from typing import Any, TypeVar

from pydantic import BaseModel

from macaron.ai.ai_tools import extract_json
from macaron.ai.clients.base import AIClient
from macaron.errors import ConfigurationError, HeuristicAnalyzerValueError
from macaron.util import send_post_http_raw

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


class OpenAiClient(AIClient):
    """A client for interacting with a Large Language Model behind an OpenAI-like API."""

    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        response_format: dict | None = None,
        max_tokens: int = 4000,
        seed: int = 42,
        timeout: int = 30,
    ) -> Any:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The temperature for the LLM response.
        response_format: dict | None
            The JSON schema to validate the response against. If provided, the response will be parsed and validated.
        max_tokens: int
            The maximum number of tokens for the LLM response.
        seed: int
            The sampling seed, to make responses as deterministic as the provider allows.
        timeout: int
            The timeout for the HTTP request in seconds.

        Returns
        -------
        Any
            The JSON object parsed from the LLM response, or None if no JSON could be extracted.

        Raises
        ------
        HeuristicAnalyzerValueError
            If there is an error in getting or parsing the response.
        """
        if not self.params["enabled"]:
            raise ConfigurationError("AI client is not enabled. Please check your configuration.")

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.params['api_key']}"}
        payload = {
            "model": self.params["model"],
            "messages": [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}],
            "response_format": response_format,
            "temperature": temperature,
            "seed": seed,
            "max_tokens": max_tokens,
        }

        try:
            response = send_post_http_raw(
                url=self.params["api_endpoint"], json_data=payload, headers=headers, timeout=timeout
            )
            if not response:
                raise HeuristicAnalyzerValueError("No response received from the LLM.")
            response_json = response.json()
            usage = response_json.get("usage", {})

            if usage:
                usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
                logger.info("LLM call token usage: %s", usage_str)

            message_content = response_json["choices"][0]["message"]["content"]
            return extract_json(message_content)

        except Exception as e:
            logger.error("Error during LLM invocation: %s", e)
            raise HeuristicAnalyzerValueError(f"Failed to get or validate LLM response: {e}") from e
```
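To make the parsing steps in `invoke` concrete, here is a stand-alone walk-through of the fields it reads from an OpenAI-style chat-completions response. The payload values are invented for illustration, and the final extraction step uses a plain string slice in place of `extract_json` so the snippet has no macaron dependency.

```python
import json

# A representative OpenAI-style chat-completions response body (hypothetical
# values), showing the two fields OpenAiClient.invoke reads: usage and
# choices[0].message.content.
response_json = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": 'Here you go: {"consistent": true, "score": 0.92}',
            }
        }
    ],
    "usage": {"prompt_tokens": 120, "completion_tokens": 18, "total_tokens": 138},
}

# Token usage is logged as a comma-separated summary.
usage = response_json.get("usage", {})
usage_str = ", ".join(f"{key} = {value}" for key, value in usage.items())
print("LLM call token usage:", usage_str)

# The assistant message often wraps the JSON in chatty text; extract_json
# (from ai_tools.py) recovers the embedded object. For this well-formed
# example, slicing from the first '{' to the last '}' is equivalent.
message_content = response_json["choices"][0]["message"]["content"]
start, end = message_content.index("{"), message_content.rindex("}") + 1
parsed = json.loads(message_content[start:end])
print(parsed)  # {'consistent': True, 'score': 0.92}
```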
New file (license header only):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
```

New file (license header only):

```python
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
```
Review comment:

> Is `pydantic` currently used? I can see it being used for the `T` type in `openai_client.py`, but I do not see it used elsewhere. The `src/macaron/ai/README.md` states it is passed as an argument to `invoke`, but I don't seem to see anywhere that is performed?