Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1bd8558
Update LLMSlot and LLMGroupSlot to use LLM_API instead of model
NotBioWaste905 Apr 9, 2025
f8eac43
add _ainvoke method to LLM_API, add history usage in LLM Slots
NotBioWaste905 Apr 10, 2025
80c0ffb
move imports in llm/filters inder TYPE_CHECKING, add llm slots tutorial
NotBioWaste905 Apr 14, 2025
4567fbf
add __future__.annotations import
NotBioWaste905 Apr 14, 2025
16a990f
rework LLMGroupSlot to extract values from nested slots via models sp…
NotBioWaste905 Apr 16, 2025
13e7dd4
format
NotBioWaste905 Apr 16, 2025
6f22311
group slots by model
NotBioWaste905 Apr 23, 2025
e56ae42
Refactor imports to use TYPE_CHECKING and enhance LLM slot functional…
NotBioWaste905 Apr 28, 2025
8b42740
move typehints under TYPE_CHECKING, fix default return_type value in …
NotBioWaste905 Apr 28, 2025
3048fa6
fix missing awaits, update tutorial
NotBioWaste905 Apr 30, 2025
d679927
fix groupslot test
NotBioWaste905 May 12, 2025
34d4ee5
Enhance LLMSlot and LLMGroupSlot functionality by updating prompt han…
NotBioWaste905 May 15, 2025
ca0e5f0
Add history attribute to LLMSlot and update prompt handling in tests
NotBioWaste905 May 15, 2025
53f4c71
fix condition
NotBioWaste905 May 15, 2025
453bd05
add a few words about llm slot prompting
NotBioWaste905 May 16, 2025
79329b3
Update documentation to clarify prompt usage for LLMSlot only
NotBioWaste905 May 16, 2025
ac91cd0
remove init from LLMSlot
NotBioWaste905 May 16, 2025
343200d
Add prompt field to LLMGroupSlot and update slot extraction in tutorial
NotBioWaste905 May 22, 2025
1722919
Merge remote-tracking branch 'origin/dev' into feat/improve_llm_slots
NotBioWaste905 May 23, 2025
14980f7
fixed parameters related bugs
NotBioWaste905 May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chatsky/__rebuild_pydantic_models__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from chatsky.core.transition import Transition
from chatsky.llm import LLM_API
from chatsky.messengers.telegram.abstract import TelegramMetadata
from chatsky.slots import GroupSlot

ContextMainInfo.model_rebuild()
ContextDict.model_rebuild()
Expand Down
4 changes: 2 additions & 2 deletions chatsky/core/ctx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, field_serializer, field_validator

from chatsky.slots.slots import SlotManager

if TYPE_CHECKING:
from chatsky.slots.slots import SlotManager
from chatsky.core.service import ComponentExecutionState
from chatsky.core.script import Node
from chatsky.core.pipeline import Pipeline
Expand Down Expand Up @@ -63,7 +63,7 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True):
"""
stats: Dict[str, Any] = Field(default_factory=dict)
"Enables complex stats collection across multiple turns."
slot_manager: SlotManager = Field(default_factory=SlotManager)
slot_manager: SlotManager = Field(default_factory=dict, validate_default=True)
"Stores extracted slots."
transition: Optional[Transition] = Field(default=None, exclude=True)
"""
Expand Down
6 changes: 4 additions & 2 deletions chatsky/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from chatsky.context_storages import DBContextStorage, MemoryContextStorage
from chatsky.messengers.console import CLIMessengerInterface
from chatsky.messengers.common import MessengerInterface
from chatsky.slots.slots import GroupSlot

# to TYPE_CHECKING
from chatsky.core.service.group import ServiceGroup, ServiceGroupInitTypes
from chatsky.core.service.extra import ComponentExtraHandlerInitTypes, BeforeHandler, AfterHandler
from .service import Service
Expand All @@ -33,6 +34,7 @@
from chatsky.core.script_parsing import JSONImporter, Path

if TYPE_CHECKING:
from chatsky.slots.slots import GroupSlot
from chatsky.llm.llm_api import LLM_API

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -79,7 +81,7 @@ class Pipeline(BaseModel, extra="forbid", arbitrary_types_allowed=True):

Defaults to ``1.0``.
"""
slots: GroupSlot = Field(default_factory=GroupSlot)
slots: GroupSlot = Field(default_factory=dict, validate_default=True)
"""
Slots configuration.
"""
Expand Down
11 changes: 7 additions & 4 deletions chatsky/llm/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
This module contains a collection of basic functions for history filtering to avoid cluttering LLMs context window.
"""

from __future__ import annotations

import abc
from enum import Enum
from logging import Logger
from typing import Union, Optional
from typing import Union, Optional, TYPE_CHECKING

from pydantic import BaseModel

from chatsky.core.message import Message
from chatsky.core.context import Context
if TYPE_CHECKING:
from chatsky.core import Context
from chatsky.core.message import Message


logger = Logger(name=__name__)
Expand Down Expand Up @@ -149,7 +152,7 @@ def single_message_filter_call(self, ctx: Context, message: Optional[Message], l

class FromModel(BaseHistoryFilter):
"""
Filter that checks if the response of the turn is generated by the currently
Filter that checks if the response of the turn is generated by the currently used model.
"""

def call(
Expand Down
8 changes: 6 additions & 2 deletions chatsky/llm/langchain_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@
The Utils module contains functions for converting Chatsky's objects to an LLM_API and langchain compatible versions.
"""

from __future__ import annotations

import re
import logging
from typing import Literal, Union
from typing import Literal, Union, TYPE_CHECKING
import asyncio

from chatsky.core import Context, Message
from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
from chatsky.llm.filters import BaseHistoryFilter, Return
from chatsky.llm.prompt import Prompt, PositionConfig

if TYPE_CHECKING:
from chatsky.core import Context, Message


logger = logging.getLogger(__name__)

Expand Down
19 changes: 16 additions & 3 deletions chatsky/llm/llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ async def respond(
result = await self.parser.ainvoke(await self.model.ainvoke(history))
return Message(text=result)
elif issubclass(message_schema, Message):
result = await self._ainvoke(history, message_schema)
return Message.model_validate(result)
elif issubclass(message_schema, BaseModel):
result = await self._ainvoke(history, message_schema)
return Message(text=result.model_dump_json())
else:
raise ValueError

async def _ainvoke(
self,
history: list[BaseMessage],
message_schema: Union[Type[Message], Type[BaseModel]],
) -> Union[Message, BaseModel]:
# call the model and return result as BaseMessage or BaseModel
if issubclass(message_schema, Message):
# Case if the message_schema describes Message structure
structured_model = self.model.with_structured_output(message_schema, method="json_mode")
model_result = await structured_model.ainvoke(history)
Expand All @@ -69,9 +84,7 @@ async def respond(
# Case if the message_schema describes Message.text structure
structured_model = self.model.with_structured_output(message_schema)
model_result = await structured_model.ainvoke(history)
return Message(text=message_schema.model_validate(model_result).model_dump_json())
else:
raise ValueError
return message_schema.model_validate(model_result)

async def condition(self, history: list[BaseMessage], method: BaseMethod) -> bool:
"""
Expand Down
7 changes: 6 additions & 1 deletion chatsky/llm/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
These methods return bool values based on LLM result.
"""

from __future__ import annotations

import abc
from typing import TYPE_CHECKING

from pydantic import BaseModel

from chatsky.core.context import Context
from chatsky.llm._langchain_imports import LLMResult

if TYPE_CHECKING:
from chatsky.core.context import Context


class BaseMethod(BaseModel, abc.ABC):
"""
Expand Down
143 changes: 103 additions & 40 deletions chatsky/slots/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

from pydantic import BaseModel, Field, create_model

from chatsky.llm.langchain_context import context_to_history, message_to_langchain, get_langchain_context
from chatsky.llm.filters import DefaultFilter
from chatsky.slots.slots import ValueSlot, SlotNotExtracted, GroupSlot, ExtractedGroupSlot, ExtractedValueSlot
from chatsky.llm.prompt import Prompt

if TYPE_CHECKING:
from chatsky.core import Context
from chatsky.core.message import Message


logger = logging.getLogger(__name__)
Expand All @@ -29,29 +33,47 @@ class LLMSlot(ValueSlot, frozen=True):
`caption` parameter using LLM.
"""

# TODO:
# add history (and overall update the class)

caption: str
return_type: type = str
llm_model_name: str = ""

def __init__(self, caption, llm_model_name=""):
super().__init__(caption=caption, llm_model_name=llm_model_name)
prompt: Prompt = Field(
default="You are an expert extraction algorithm. "
"Only extract relevant information from the text. "
"If you do not know the value of an attribute asked to extract, "
"return null for the attribute's value.",
validate_default=True,
)
history: int = 0

async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]:
request_text = ctx.last_request.text
if request_text == "":
return SlotNotExtracted()
model_instance = ctx.pipeline.models[self.llm_model_name].model

history_messages = await get_langchain_context(
system_prompt=await ctx.pipeline.models[self.llm_model_name].system_prompt(ctx),
call_prompt=self.prompt,
ctx=ctx,
length=self.history,
filter_func=DefaultFilter(),
llm_model_name=self.llm_model_name,
max_size=1000,
)
if history_messages == []:
print("No history messages found, using last request")
history_messages = [await message_to_langchain(ctx.last_request, ctx)]
# Dynamically create a Pydantic model based on the caption
return_type = self.return_type

class DynamicModel(BaseModel):
value: self.return_type = Field(description=self.caption)
value: return_type = Field(description=self.caption)

print(f"History messages: {history_messages}")

structured_model = model_instance.with_structured_output(DynamicModel)
result: DynamicModel = await ctx.pipeline.models[self.llm_model_name]._ainvoke(
history=history_messages, message_schema=DynamicModel
)

result = await structured_model.ainvoke(request_text)
return result.value


Expand All @@ -64,29 +86,60 @@ class LLMGroupSlot(GroupSlot):

__pydantic_extra__: Dict[str, Union[LLMSlot, "LLMGroupSlot"]]
llm_model_name: str
prompt: Prompt = Field(
default="You are an expert extraction algorithm. "
"Only extract relevant information from the text. "
"If you do not know the value of an attribute asked to extract, "
"return null for the attribute's value.",
validate_default=True,
)
history: int = 0

async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
request_text = ctx.last_request.text
if request_text == "":
return ExtractedGroupSlot()
flat_items = self._flatten_llm_group_slot(self)
captions = {}
for child_name, slot_item in flat_items.items():
captions[child_name] = (slot_item.return_type, Field(description=slot_item.caption, default=None))

logger.debug(f"Flattened group slot: {flat_items}")
DynamicGroupModel = create_model("DynamicGroupModel", **captions)
logger.debug(f"DynamicGroupModel: {DynamicGroupModel}")
# Get all slots grouped by their model names
model_groups = self._group_slots_by_model(self)

# Process each model group separately
all_results = {}
for model_name, slots in model_groups.items():
if not slots:
continue

# Create dynamic model for this group
captions = {}
for child_name, slot_item in slots.items():
captions[child_name] = (slot_item.return_type, Field(description=slot_item.caption, default=None))

DynamicGroupModel = create_model("DynamicGroupModel", **captions)
logger.debug(f"DynamicGroupModel for {model_name}: {DynamicGroupModel}")

model_instance = ctx.pipeline.models[self.llm_model_name].model
structured_model = model_instance.with_structured_output(DynamicGroupModel)
result = await structured_model.ainvoke(request_text)
result_json = result.model_dump()
logger.debug(f"Result JSON: {result_json}")
# swith to get_langchain_context
history_messages = await context_to_history(
ctx, self.history, filter_func=DefaultFilter(), llm_model_name=model_name, max_size=1000
)
if history_messages == []:
history_messages = [await message_to_langchain(ctx.last_request, ctx)]

# Get model and process request
model = ctx.pipeline.models.get(model_name)
if model is None:
logger.warning(f"Model {model_name} not found in pipeline.models")
continue

result: Message = await model._ainvoke(history=history_messages, message_schema=DynamicGroupModel)
result_json = result.model_dump()
logger.debug(f"Result JSON for {model_name}: {result_json}")

# Add results to all_results
all_results.update(result_json)

# Convert flat dict to nested structure
nested_result = {}
for key, value in result_json.items():
for key, value in all_results.items():
if value is None and self.allow_partial_extraction:
continue

Expand All @@ -107,27 +160,37 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot:

return self._dict_to_extracted_slots(nested_result)

def _dict_to_extracted_slots(self, d):
def _group_slots_by_model(self, slot, parent_key="") -> Dict[str, Dict[str, LLMSlot]]:
"""
Convert nested dictionary of ExtractedValueSlots into an ExtractedGroupSlot.
Group slots by their llm_model_name.
Returns a dictionary where keys are model names and values are dictionaries
of slot paths to slot objects.
"""
if not isinstance(d, dict):
return d
return ExtractedGroupSlot(**{k: self._dict_to_extracted_slots(v) for k, v in d.items()})

def _flatten_llm_group_slot(self, slot, parent_key="") -> Dict[str, LLMSlot]:
"""
Convert potentially nested group slot into a dictionary with
flat keys.
Nested keys are flattened as concatenations via ".".
model_groups = {}

As such, values in the returned dictionary are only of type :py:class:`LLMSlot`.
"""
items = {}
for key, value in slot.__pydantic_extra__.items():
new_key = f"{parent_key}.{key}" if parent_key else key

if isinstance(value, LLMGroupSlot):
items.update(self._flatten_llm_group_slot(value, new_key))
# Recursively process nested group slots
nested_groups = self._group_slots_by_model(value, new_key)
for model_name, slots in nested_groups.items():
if model_name not in model_groups:
model_groups[model_name] = {}
model_groups[model_name].update(slots)
else:
items[new_key] = value
return items
# Use the slot's model name or fall back to the group's model name
model_name = value.llm_model_name or self.llm_model_name
if model_name not in model_groups:
model_groups[model_name] = {}
model_groups[model_name][new_key] = value

return model_groups

def _dict_to_extracted_slots(self, d):
"""
Convert nested dictionary of ExtractedValueSlots into an ExtractedGroupSlot.
"""
if not isinstance(d, dict):
return d
return ExtractedGroupSlot(**{k: self._dict_to_extracted_slots(v) for k, v in d.items()})
Loading
Loading