diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 26c0303a2d..b13fb408ed 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -13,6 +13,7 @@ from chatsky.core.transition import Transition from chatsky.llm import LLM_API from chatsky.messengers.telegram.abstract import TelegramMetadata +from chatsky.slots import GroupSlot ContextMainInfo.model_rebuild() ContextDict.model_rebuild() diff --git a/chatsky/core/ctx_utils.py b/chatsky/core/ctx_utils.py index bafdd867ec..2640dc2ade 100644 --- a/chatsky/core/ctx_utils.py +++ b/chatsky/core/ctx_utils.py @@ -15,9 +15,9 @@ from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, field_serializer, field_validator -from chatsky.slots.slots import SlotManager if TYPE_CHECKING: + from chatsky.slots.slots import SlotManager from chatsky.core.service import ComponentExecutionState from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline @@ -63,7 +63,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True): """ stats: Dict[str, Any] = Field(default_factory=dict) "Enables complex stats collection across multiple turns." - slot_manager: SlotManager = Field(default_factory=SlotManager) + slot_manager: SlotManager = Field(default_factory=dict, validate_default=True) "Stores extracted slots." 
transition: Optional[Transition] = Field(default=None, exclude=True) """ diff --git a/chatsky/core/pipeline.py b/chatsky/core/pipeline.py index 90adb336c0..cee820f5d4 100644 --- a/chatsky/core/pipeline.py +++ b/chatsky/core/pipeline.py @@ -23,7 +23,8 @@ from chatsky.context_storages import DBContextStorage, MemoryContextStorage from chatsky.messengers.console import CLIMessengerInterface from chatsky.messengers.common import MessengerInterface -from chatsky.slots.slots import GroupSlot + +# to TYPE_CHECKING from chatsky.core.service.group import ServiceGroup, ServiceGroupInitTypes from chatsky.core.service.extra import ComponentExtraHandlerInitTypes, BeforeHandler, AfterHandler from .service import Service @@ -33,6 +34,7 @@ from chatsky.core.script_parsing import JSONImporter, Path if TYPE_CHECKING: + from chatsky.slots.slots import GroupSlot from chatsky.llm.llm_api import LLM_API logger = logging.getLogger(__name__) @@ -79,7 +81,7 @@ class Pipeline(BaseModel, extra="forbid", arbitrary_types_allowed=True): Defaults to ``1.0``. """ - slots: GroupSlot = Field(default_factory=GroupSlot) + slots: GroupSlot = Field(default_factory=dict, validate_default=True) """ Slots configuration. """ diff --git a/chatsky/llm/filters.py b/chatsky/llm/filters.py index 2a60d20d33..ecf3e84d83 100644 --- a/chatsky/llm/filters.py +++ b/chatsky/llm/filters.py @@ -4,15 +4,18 @@ This module contains a collection of basic functions for history filtering to avoid cluttering LLMs context window. 
""" +from __future__ import annotations + import abc from enum import Enum from logging import Logger -from typing import Union, Optional +from typing import Union, Optional, TYPE_CHECKING from pydantic import BaseModel -from chatsky.core.message import Message -from chatsky.core.context import Context +if TYPE_CHECKING: + from chatsky.core import Context + from chatsky.core.message import Message logger = Logger(name=__name__) @@ -149,7 +152,7 @@ def single_message_filter_call(self, ctx: Context, message: Optional[Message], l class FromModel(BaseHistoryFilter): """ - Filter that checks if the response of the turn is generated by the currently + Filter that checks if the response of the turn is generated by the currently used model. """ def call( diff --git a/chatsky/llm/langchain_context.py b/chatsky/llm/langchain_context.py index e0f0a19a96..69f8bf9d9e 100644 --- a/chatsky/llm/langchain_context.py +++ b/chatsky/llm/langchain_context.py @@ -4,16 +4,20 @@ The Utils module contains functions for converting Chatsky's objects to an LLM_API and langchain compatible versions. 
""" +from __future__ import annotations + import re import logging -from typing import Literal, Union +from typing import Literal, Union, TYPE_CHECKING import asyncio -from chatsky.core import Context, Message from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available from chatsky.llm.filters import BaseHistoryFilter, Return from chatsky.llm.prompt import Prompt, PositionConfig +if TYPE_CHECKING: + from chatsky.core import Context, Message + logger = logging.getLogger(__name__) diff --git a/chatsky/llm/llm_api.py b/chatsky/llm/llm_api.py index 3170d4bb93..f052d16914 100644 --- a/chatsky/llm/llm_api.py +++ b/chatsky/llm/llm_api.py @@ -60,6 +60,21 @@ async def respond( result = await self.parser.ainvoke(await self.model.ainvoke(history)) return Message(text=result) elif issubclass(message_schema, Message): + result = await self._ainvoke(history, message_schema) + return Message.model_validate(result) + elif issubclass(message_schema, BaseModel): + result = await self._ainvoke(history, message_schema) + return Message(text=result.model_dump_json()) + else: + raise ValueError + + async def _ainvoke( + self, + history: list[BaseMessage], + message_schema: Union[Type[Message], Type[BaseModel]], + ) -> Union[Message, BaseModel]: + # call the model and return result as BaseMessage or BaseModel + if issubclass(message_schema, Message): # Case if the message_schema describes Message structure structured_model = self.model.with_structured_output(message_schema, method="json_mode") model_result = await structured_model.ainvoke(history) @@ -69,9 +84,7 @@ async def respond( # Case if the message_schema describes Message.text structure structured_model = self.model.with_structured_output(message_schema) model_result = await structured_model.ainvoke(history) - return Message(text=message_schema.model_validate(model_result).model_dump_json()) - else: - raise ValueError + return message_schema.model_validate(model_result) async def 
condition(self, history: list[BaseMessage], method: BaseMethod) -> bool: """ diff --git a/chatsky/llm/methods.py b/chatsky/llm/methods.py index 8867a0c3b5..c7a89e9dc6 100644 --- a/chatsky/llm/methods.py +++ b/chatsky/llm/methods.py @@ -5,13 +5,18 @@ These methods return bool values based on LLM result. """ +from __future__ import annotations + import abc +from typing import TYPE_CHECKING from pydantic import BaseModel -from chatsky.core.context import Context from chatsky.llm._langchain_imports import LLMResult +if TYPE_CHECKING: + from chatsky.core.context import Context + class BaseMethod(BaseModel, abc.ABC): """ diff --git a/chatsky/slots/llm.py b/chatsky/slots/llm.py index 285cece2bf..fc6fefc050 100644 --- a/chatsky/slots/llm.py +++ b/chatsky/slots/llm.py @@ -14,10 +14,14 @@ from pydantic import BaseModel, Field, create_model +from chatsky.llm.langchain_context import context_to_history, message_to_langchain, get_langchain_context +from chatsky.llm.filters import DefaultFilter from chatsky.slots.slots import ValueSlot, SlotNotExtracted, GroupSlot, ExtractedGroupSlot, ExtractedValueSlot +from chatsky.llm.prompt import Prompt if TYPE_CHECKING: from chatsky.core import Context + from chatsky.core.message import Message logger = logging.getLogger(__name__) @@ -29,29 +33,47 @@ class LLMSlot(ValueSlot, frozen=True): `caption` parameter using LLM. """ - # TODO: - # add history (and overall update the class) - caption: str return_type: type = str llm_model_name: str = "" - - def __init__(self, caption, llm_model_name=""): - super().__init__(caption=caption, llm_model_name=llm_model_name) + prompt: Prompt = Field( + default="You are an expert extraction algorithm. " + "Only extract relevant information from the text. 
" + "If you do not know the value of an attribute asked to extract, " + "return null for the attribute's value.", + validate_default=True, + ) + history: int = 0 async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: request_text = ctx.last_request.text if request_text == "": return SlotNotExtracted() - model_instance = ctx.pipeline.models[self.llm_model_name].model + history_messages = await get_langchain_context( + system_prompt=await ctx.pipeline.models[self.llm_model_name].system_prompt(ctx), + call_prompt=self.prompt, + ctx=ctx, + length=self.history, + filter_func=DefaultFilter(), + llm_model_name=self.llm_model_name, + max_size=1000, + ) + if history_messages == []: + print("No history messages found, using last request") + history_messages = [await message_to_langchain(ctx.last_request, ctx)] # Dynamically create a Pydantic model based on the caption + return_type = self.return_type + class DynamicModel(BaseModel): - value: self.return_type = Field(description=self.caption) + value: return_type = Field(description=self.caption) + + print(f"History messages: {history_messages}") - structured_model = model_instance.with_structured_output(DynamicModel) + result: DynamicModel = await ctx.pipeline.models[self.llm_model_name]._ainvoke( + history=history_messages, message_schema=DynamicModel + ) - result = await structured_model.ainvoke(request_text) return result.value @@ -64,29 +86,60 @@ class LLMGroupSlot(GroupSlot): __pydantic_extra__: Dict[str, Union[LLMSlot, "LLMGroupSlot"]] llm_model_name: str + prompt: Prompt = Field( + default="You are an expert extraction algorithm. " + "Only extract relevant information from the text. 
" + "If you do not know the value of an attribute asked to extract, " + "return null for the attribute's value.", + validate_default=True, + ) + history: int = 0 async def get_value(self, ctx: Context) -> ExtractedGroupSlot: request_text = ctx.last_request.text if request_text == "": return ExtractedGroupSlot() - flat_items = self._flatten_llm_group_slot(self) - captions = {} - for child_name, slot_item in flat_items.items(): - captions[child_name] = (slot_item.return_type, Field(description=slot_item.caption, default=None)) - logger.debug(f"Flattened group slot: {flat_items}") - DynamicGroupModel = create_model("DynamicGroupModel", **captions) - logger.debug(f"DynamicGroupModel: {DynamicGroupModel}") + # Get all slots grouped by their model names + model_groups = self._group_slots_by_model(self) + + # Process each model group separately + all_results = {} + for model_name, slots in model_groups.items(): + if not slots: + continue + + # Create dynamic model for this group + captions = {} + for child_name, slot_item in slots.items(): + captions[child_name] = (slot_item.return_type, Field(description=slot_item.caption, default=None)) + + DynamicGroupModel = create_model("DynamicGroupModel", **captions) + logger.debug(f"DynamicGroupModel for {model_name}: {DynamicGroupModel}") - model_instance = ctx.pipeline.models[self.llm_model_name].model - structured_model = model_instance.with_structured_output(DynamicGroupModel) - result = await structured_model.ainvoke(request_text) - result_json = result.model_dump() - logger.debug(f"Result JSON: {result_json}") + # swith to get_langchain_context + history_messages = await context_to_history( + ctx, self.history, filter_func=DefaultFilter(), llm_model_name=model_name, max_size=1000 + ) + if history_messages == []: + history_messages = [await message_to_langchain(ctx.last_request, ctx)] + + # Get model and process request + model = ctx.pipeline.models.get(model_name) + if model is None: + logger.warning(f"Model {model_name} not 
found in pipeline.models") + continue + + result: Message = await model._ainvoke(history=history_messages, message_schema=DynamicGroupModel) + result_json = result.model_dump() + logger.debug(f"Result JSON for {model_name}: {result_json}") + + # Add results to all_results + all_results.update(result_json) # Convert flat dict to nested structure nested_result = {} - for key, value in result_json.items(): + for key, value in all_results.items(): if value is None and self.allow_partial_extraction: continue @@ -107,27 +160,37 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: return self._dict_to_extracted_slots(nested_result) - def _dict_to_extracted_slots(self, d): + def _group_slots_by_model(self, slot, parent_key="") -> Dict[str, Dict[str, LLMSlot]]: """ - Convert nested dictionary of ExtractedValueSlots into an ExtractedGroupSlot. + Group slots by their llm_model_name. + Returns a dictionary where keys are model names and values are dictionaries + of slot paths to slot objects. """ - if not isinstance(d, dict): - return d - return ExtractedGroupSlot(**{k: self._dict_to_extracted_slots(v) for k, v in d.items()}) - - def _flatten_llm_group_slot(self, slot, parent_key="") -> Dict[str, LLMSlot]: - """ - Convert potentially nested group slot into a dictionary with - flat keys. - Nested keys are flattened as concatenations via ".". + model_groups = {} - As such, values in the returned dictionary are only of type :py:class:`LLMSlot`. 
- """ - items = {} for key, value in slot.__pydantic_extra__.items(): new_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, LLMGroupSlot): - items.update(self._flatten_llm_group_slot(value, new_key)) + # Recursively process nested group slots + nested_groups = self._group_slots_by_model(value, new_key) + for model_name, slots in nested_groups.items(): + if model_name not in model_groups: + model_groups[model_name] = {} + model_groups[model_name].update(slots) else: - items[new_key] = value - return items + # Use the slot's model name or fall back to the group's model name + model_name = value.llm_model_name or self.llm_model_name + if model_name not in model_groups: + model_groups[model_name] = {} + model_groups[model_name][new_key] = value + + return model_groups + + def _dict_to_extracted_slots(self, d): + """ + Convert nested dictionary of ExtractedValueSlots into an ExtractedGroupSlot. + """ + if not isinstance(d, dict): + return d + return ExtractedGroupSlot(**{k: self._dict_to_extracted_slots(v) for k, v in d.items()}) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 8e0df51cee..a4d17af177 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -17,7 +17,7 @@ if not langchain_available: pytest.skip(allow_module_level=True, reason="Langchain not available.") -from chatsky.llm._langchain_imports import AIMessage, LLMResult, HumanMessage, SystemMessage +from chatsky.llm._langchain_imports import AIMessage, LLMResult, HumanMessage, SystemMessage, BaseMessage from langchain_core.outputs.chat_generation import ChatGeneration @@ -71,22 +71,22 @@ def __init__(self, root_model): self.root = root_model async def ainvoke(self, history): - if isinstance(history, list): - inst = self.root(history=history) - else: - # For LLMSlot - fields = {} - for field in self.root.model_fields: - fields[field] = "test_data" - inst = self.root(**fields) + fields = {} + print(f"Root model fields: {self.root}") + print(f"History: 
{history}") + for field in self.root.model_fields: + if field == "history": + fields[field] = history + elif self.root.model_fields[field].annotation is int: + fields[field] = len(history) + elif self.root.model_fields[field].annotation is str: + fields[field] = str(history) + inst = self.root(**fields) return inst - def with_structured_output(self, message_schema): - return message_schema - class MessageSchema(BaseModel): - history: list[str] + history: list[BaseMessage] def __call__(self): return self.model_dump() @@ -128,13 +128,20 @@ async def test_structured_output(self, monkeypatch, mock_structured_model): llm_api = LLM_API(MockChatOpenAI()) # Test data - history = ["message1", "message2"] + history = [HumanMessage("message1"), AIMessage("message2")] # Call the respond method result = await llm_api.respond(message_schema=MessageSchema, history=history) + print(f"Result: {result}") + # Assert the result - expected_result = Message(text='{"history":["message1","message2"]}') + expected_result = Message( + text='{"history":[{"content":"message1","additional_kwargs":{},' + '"response_metadata":{},"type":"human","name":null,"id":null},' + '{"content":"message2","additional_kwargs":{},' + '"response_metadata":{},"type":"ai","name":null,"id":null}]}' + ) assert result == expected_result @@ -259,6 +266,17 @@ async def test_context_to_history(self, context): ] assert res == expected + res = await context_to_history( + ctx=context, length=2, filter_func=DefaultFilter(), llm_model_name="test_model", max_size=100 + ) + expected = [ + HumanMessage(content=[{"type": "text", "text": "Request 2"}]), + AIMessage(content=[{"type": "text", "text": "Response 2"}]), + HumanMessage(content=[{"type": "text", "text": "Request 3"}]), + AIMessage(content=[{"type": "text", "text": "Response 3"}]), + ] + assert res == expected + async def test_context_with_response_to_history(self, filter_context): res = await context_to_history( ctx=filter_context, length=-1, 
filter_func=DefaultFilter(), llm_model_name="test_model", max_size=100 @@ -447,19 +465,37 @@ async def test_logprob_method(self, filter_context, llmresult): class TestSlots: - async def test_llm_slot(self, pipeline, context): + async def test_empty_llm_slot(self, context): + # Test empty request slot = LLMSlot(caption="test_caption", llm_model_name="test_model") context.current_turn_id = 5 - # Test empty request context.requests[5] = "" assert isinstance(await slot.extract_value(context), SlotNotExtracted) + async def test_llm_slot(self, context): # Test normal request + slot = LLMSlot(caption="test_caption", llm_model_name="test_model") context.requests[5] = "test request" result = await slot.extract_value(context) + print(f"Extracted normal request result: {result}") + assert isinstance(result, str) + + async def test_llm_slot_with_history(self, context): + # Test request with history + slot = LLMSlot(caption="test_caption", llm_model_name="test_model", history=2) + context.requests[5] = "test request with history" + result = await slot.extract_value(context) + print(f"Extracted request with history result: {result}") assert isinstance(result, str) - async def test_llm_group_slot(self, pipeline, context): + async def test_int_llm_slot(self, context): + slot = LLMSlot(caption="test_caption", return_type=int, llm_model_name="test_model", history=2) + context.requests[5] = "test request with history" + result = await slot.extract_value(context) + print(f"Extracted request with history result: {result}") + assert result == 8 + + async def test_llm_group_slot(self, context): slot = LLMGroupSlot( llm_model_name="test_model", name=LLMSlot(caption="Extract person's name"), @@ -475,6 +511,18 @@ async def test_llm_group_slot(self, pipeline, context): print(f"Extracted result: {result}") - assert result.name.extracted_value == "test_data" - assert result.age.extracted_value == "test_data" - assert result.nested.city.extracted_value == "test_data" + assert ( + 
result.name.extracted_value == "[HumanMessage(content=[{'type': 'text', " + "'text': 'John is 25 years old and lives in New York'}], " + "additional_kwargs={}, response_metadata={})]" + ) + assert ( + result.age.extracted_value == "[HumanMessage(content=[{'type': 'text', 'text': " + "'John is 25 years old and lives in New York'}], " + "additional_kwargs={}, response_metadata={})]" + ) + assert ( + result.nested.city.extracted_value == "[HumanMessage(content=[{'type': 'text', 'text': '" + "John is 25 years old and lives in New York'}], " + "additional_kwargs={}, response_metadata={})]" + ) diff --git a/tutorials/llm/5_llm_slots.py b/tutorials/llm/5_llm_slots.py new file mode 100644 index 0000000000..0fff12cdb8 --- /dev/null +++ b/tutorials/llm/5_llm_slots.py @@ -0,0 +1,135 @@ +# %% [markdown] +""" +# LLM: 5. LLM Slots + +When we need to retrieve specific information from user input—such as a name, +address, or email we can use Chatsky's Slot system along with regexes or other +formally specified data retrieval techniques. +However, if the data is more nuanced or not explicitly stated in the user's +utterance, we recommend using Chatsky's **LLM Slots**. + +In this tutorial, we will explore how to set up Slots that leverage LLMs +to extract more complex or implicit information from user input. +""" +# %pip install chatsky[llm] langchain-openai +# %% +from chatsky import ( + RESPONSE, + TRANSITIONS, + PRE_TRANSITION, + GLOBAL, + LOCAL, + Pipeline, + Transition as Tr, + conditions as cnd, + processing as proc, + responses as rsp, +) +from langchain_openai import ChatOpenAI + +from chatsky.utils.testing import ( + is_interactive_mode, +) +from chatsky.slots.llm import LLMSlot, LLMGroupSlot +from chatsky.llm import LLM_API + +import os + +openai_api_key = os.getenv("OPENAI_API_KEY") + +# %% [markdown] +""" +In this example, we define an **LLM Group Slot** containing two **LLM Slots**. 
+While these slots can be used independently as regular slots, +grouping them together is recommended when extracting multiple LLM Slots +simultaneously. This approach optimizes performance and improves convenience. + +- In the `LLMSlot.caption` parameter, provide a description of the data you +want to retrieve. More specific descriptions yield better results, +especially when using smaller models. +- Note that we pass the name of the model from the `pipeline.models` +dictionary to the `LLMGroupSlot.llm_model_name` field. +- Additionally, the `allow_partial_extraction` flag is set to `True` for the +"person" slot. This allows the slot to be filled across multiple messages. +For more details on partial extraction, +refer to the tutorial: %mddoclink(tutorial,slots.2_partial_extraction). +""" + +# %% +slot_model = LLM_API( + ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0) +) + +another_slot_model = LLM_API( + ChatOpenAI(model="gpt-4.1-nano", api_key=openai_api_key, temperature=0) +) + +# You can pass additional prompts to the LLMSlot +# using the `prompt` parameter to fine-tune the extraction process. +SLOTS = { + "person": LLMGroupSlot( + username=LLMSlot( + caption="User's username in uppercase", + prompt="You are an expert extraction algorithm."
+ "Extract the user's full name that can be " + "scattered troughout the text.", + ), + job=LLMSlot( + llm_model_name="another_slot_model", + caption="User's occupation, job, profession", + ), + age=LLMSlot(caption="User's age", return_type=int), + llm_model_name="slot_model", + allow_partial_extraction=True, + ) +} + +script = { + GLOBAL: { + TRANSITIONS: [ + Tr(dst=("user_flow", "ask"), cnd=cnd.Regexp(pattern=r"^[sS]tart")) + ] + }, + "user_flow": { + LOCAL: { + PRE_TRANSITION: {"get_slot": proc.Extract(slots=["person"])}, + TRANSITIONS: [ + Tr( + dst=("user_flow", "tell"), + cnd=cnd.SlotsExtracted(slots=["person"]), + priority=1.2, + ), + Tr(dst=("user_flow", "repeat_question"), priority=0.8), + ], + }, + "start": {RESPONSE: "", TRANSITIONS: [Tr(dst=("user_flow", "ask"))]}, + "ask": { + RESPONSE: "Hello! Tell me about yourself: what are you doing for " + "the living or your hobbies, your age... " + "And don't forget to introduce yourself!", + }, + "tell": { + RESPONSE: rsp.FilledTemplate( + template="So you are {person.username}, {person.age} and your " + "occupation is {person.job}, right?" + ), + TRANSITIONS: [Tr(dst=("user_flow", "ask"))], + }, + "repeat_question": { + RESPONSE: "I didn't quite understand you...", + }, + }, +} + +pipeline = Pipeline( + script=script, + start_label=("user_flow", "start"), + fallback_label=("user_flow", "repeat_question"), + slots=SLOTS, + models={"slot_model": slot_model, "another_slot_model": another_slot_model}, +) + + +if __name__ == "__main__": + if is_interactive_mode(): + pipeline.run()