From ad6ad644aeedf5f41185ad7961fbae7dc6f41fff Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Thu, 20 Nov 2025 16:00:01 +0800 Subject: [PATCH 1/2] [api][python] Introduce long-term memory interface. rename memory item --- python/flink_agents/api/memory/__init__.py | 17 ++ .../api/memory/long_term_memory.py | 263 ++++++++++++++++++ .../flink_agents/api/memory/tests/__init__.py | 17 ++ .../api/memory/tests/test_long_term_memory.py | 33 +++ 4 files changed, 330 insertions(+) create mode 100644 python/flink_agents/api/memory/__init__.py create mode 100644 python/flink_agents/api/memory/long_term_memory.py create mode 100644 python/flink_agents/api/memory/tests/__init__.py create mode 100644 python/flink_agents/api/memory/tests/test_long_term_memory.py diff --git a/python/flink_agents/api/memory/__init__.py b/python/flink_agents/api/memory/__init__.py new file mode 100644 index 00000000..e154fadd --- /dev/null +++ b/python/flink_agents/api/memory/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# diff --git a/python/flink_agents/api/memory/long_term_memory.py b/python/flink_agents/api/memory/long_term_memory.py new file mode 100644 index 00000000..98318a3d --- /dev/null +++ b/python/flink_agents/api/memory/long_term_memory.py @@ -0,0 +1,263 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# +import importlib +from abc import ABC, abstractmethod +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Type + +from pydantic import ( + BaseModel, + Field, + field_serializer, + model_validator, +) + +from flink_agents.api.chat_message import ChatMessage +from flink_agents.api.prompts.prompt import Prompt + + +class ReduceStrategy(Enum): + """Strategy for reducing memory set size.""" + + TRIM = "trim" + SUMMARIZE = "summarize" + + +class ReduceSetup(BaseModel): + """Reduce setup for managing the storage of memory set.""" + + strategy: ReduceStrategy + arguments: Dict[str, Any] = Field(default_factory=dict) + + @staticmethod + def trim_setup(n: int) -> "ReduceSetup": + """Create trim setup.""" + return ReduceSetup(strategy=ReduceStrategy.TRIM, arguments={"n": n}) + + @staticmethod + def summarize_setup( + n: int, model: str, prompt: Prompt | None = None + ) -> "ReduceSetup": + """Create summarize setup.""" + return ReduceSetup( + strategy=ReduceStrategy.SUMMARIZE, + arguments={"n": n, "model": model, "prompt": prompt}, + ) + + +class LongTermMemoryBackend(Enum): + """Backend for Long-Term Memory.""" + + CHROMA = "chroma" + + +class DatetimeRange(BaseModel): + """Represents a datetime range.""" + + start: datetime + end: datetime + + +class MemorySetItem(BaseModel): + """Represents a long term memory item retrieved from vector store. + + Attributes: + memory_set_name: The name of the memory set this item belongs to. + id: The id of this item. + value: The value of this item. + compacted: Whether this item has been compacted. + created_time: The datetime range in which this item was created. + last_accessed_time: The timestamp this item was last accessed. 
+ """ + + memory_set_name: str + id: str + value: Any + compacted: bool = False + created_time: DatetimeRange + last_accessed_time: datetime + +class MemorySet(BaseModel): + """Represents a long term memory set contains memory items. + + Attributes: + name: The name of this memory set. + item_type: The type of items stored in this set. + size: Current items count stored in this set. + capacity: The capacity of this memory set. + reduce_setup: Reduce strategy and additional arguments used to reduce memory + set size. + item_ids: The indices of items stored in this set. + reduced: Whether this memory set has been reduced. + """ + + name: str + item_type: Type[str] | Type[ChatMessage] + size: int = 0 + capacity: int + reduce_setup: ReduceSetup + item_ids: List[str] = Field(default_factory=list) + reduced: bool = False + ltm: "BaseLongTermMemory" = Field(default=None, exclude=True) + + @field_serializer("item_type") + def _serialize_item_type(self, item_type: Type) -> Dict[str, str]: + return {"module": item_type.__module__, "name": item_type.__name__} + + @model_validator(mode="before") + def _deserialize_item_type(self) -> "MemorySet": + if isinstance(self["item_type"], Dict): + module = importlib.import_module(self["item_type"]["module"]) + self["item_type"] = getattr(module, self["item_type"]["name"]) + return self + + def add(self, item: str | ChatMessage) -> None: + """Add a memory item to the set, currently only support item with + type str or ChatMessage. + + If the capacity of this memory set is reached, will trigger reduce + operation to manage the memory set size. + + Args: + item: The item to be inserted to this set. + """ + self.ltm.add(memory_set=self, memory_item=item) + + def get(self) -> List[MemorySetItem]: + """Retrieve all memory items. + + Returns: + All memory items in this set. + """ + return self.ltm.get(memory_set=self) + + def get_recent(self, n: int) -> List[MemorySetItem]: + """Retrieve n most recent memory items. 
+ + Args: + n: The number of items to retrieve. + + Returns: + List of memory items retrieved, sorted by creation timestamp. + """ + return self.ltm.get_recent(memory_set=self, n=n) + + def search(self, query: str, limit: int, **kwargs: Any) -> List[MemorySetItem]: + """Retrieve n memory items related to the query. + + Args: + query: The query to search for. + limit: The number of items to retrieve. + **kwargs: Additional arguments for search. + """ + return self.ltm.search(memory_set=self, query=query, limit=limit, **kwargs) + + +class BaseLongTermMemory(ABC, BaseModel): + """Base Abstract class for long term memory.""" + + @abstractmethod + def create_memory_set( + self, + name: str, + item_type: Type[str] | Type[ChatMessage], + capacity: int, + reduce_setup: ReduceSetup, + ) -> MemorySet: + """Create a memory set, if the memory set already exists, return it. + + Args: + name: The name of the memory set. + item_type: The type of the memory item. + capacity: The capacity of the memory set. + reduce_setup: The reduce strategy and arguments for storage management. + + Returns: + The created memory set. + """ + + @abstractmethod + def get_memory_set(self, name: str) -> MemorySet: + """Get the memory set. + + Args: + name: The name of the memory set. + + Returns: + The memory set. + """ + + @abstractmethod + def delete_memory_set(self, name: str) -> None: + """Delete the memory set. + + Args: + name: The name of the memory set. + """ + + @abstractmethod + def add(self, memory_set: MemorySet, memory_item: str | ChatMessage) -> None: + """Add a memory item to the named set, currently only support item with + type str or ChatMessage. + + This method may trigger reduce operation to manage the memory set size. + + Args: + memory_set: The memory set to be inserted. + memory_item: The item to be inserted to this set. + """ + + @abstractmethod + def get(self, memory_set: MemorySet) -> List[MemorySetItem]: + """Retrieve all memory items. 
+ + Args: + memory_set: The set to be retrieved. + + Returns: + All the memory items of this set. + """ + + @abstractmethod + def get_recent(self, memory_set: MemorySet, n: int) -> List[MemorySetItem]: + """Retrieve n most recent memory items. + + Args: + memory_set: The set to be retrieved. + n: The number of items to retrieve. + + Returns: + List of memory items retrieved, sorted by creation timestamp. + """ + + @abstractmethod + def search( + self, memory_set: MemorySet, query: str, limit: int, **kwargs: Any + ) -> List[MemorySetItem]: + """Retrieve n memory items related to the query. + + Args: + memory_set: The set to be retrieved. + query: The query for semantic search. + limit: The number of items to retrieve. + **kwargs: Additional arguments for semantic search. + + Returns: + Related memory items retrieved. + """ diff --git a/python/flink_agents/api/memory/tests/__init__.py b/python/flink_agents/api/memory/tests/__init__.py new file mode 100644 index 00000000..e154fadd --- /dev/null +++ b/python/flink_agents/api/memory/tests/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# diff --git a/python/flink_agents/api/memory/tests/test_long_term_memory.py b/python/flink_agents/api/memory/tests/test_long_term_memory.py new file mode 100644 index 00000000..659f5a43 --- /dev/null +++ b/python/flink_agents/api/memory/tests/test_long_term_memory.py @@ -0,0 +1,33 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# + +from flink_agents.api.chat_message import ChatMessage +from flink_agents.api.memory.long_term_memory import MemorySet, ReduceSetup + + +def test_memory_set_serialization() -> None: #noqa:D103 + memory_set = MemorySet(name="chat_history", + item_type=ChatMessage, + capacity=100, + reduce_setup=ReduceSetup.trim_setup(10)) + + json_data = memory_set.model_dump_json() + + memory_set_deserialized = MemorySet.model_validate_json(json_data) + + assert memory_set_deserialized == memory_set From 7c14985fc4a9ef305d0ad0a17b0b4e68ba1b523e Mon Sep 17 00:00:00 2001 From: WenjinXie Date: Thu, 20 Nov 2025 16:00:50 +0800 Subject: [PATCH 2/2] [runtime][python] Implement chroma based long term memory. refactor --- .../flink_agents/runtime/memory/__init__.py | 17 + .../runtime/memory/chroma_long_term_memory.py | 456 ++++++++++++++++++ .../runtime/memory/reduce_functions.py | 80 +++ .../runtime/memory/tests/__init__.py | 17 + .../memory/tests/start_chroma_server.sh | 23 + .../tests/test_chroma_long_term_memory.py | 246 ++++++++++ 6 files changed, 839 insertions(+) create mode 100644 python/flink_agents/runtime/memory/__init__.py create mode 100644 python/flink_agents/runtime/memory/chroma_long_term_memory.py create mode 100644 python/flink_agents/runtime/memory/reduce_functions.py create mode 100644 python/flink_agents/runtime/memory/tests/__init__.py create mode 100644 python/flink_agents/runtime/memory/tests/start_chroma_server.sh create mode 100644 python/flink_agents/runtime/memory/tests/test_chroma_long_term_memory.py diff --git a/python/flink_agents/runtime/memory/__init__.py b/python/flink_agents/runtime/memory/__init__.py new file mode 100644 index 00000000..e154fadd --- /dev/null +++ b/python/flink_agents/runtime/memory/__init__.py @@ -0,0 +1,17 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# 
or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# diff --git a/python/flink_agents/runtime/memory/chroma_long_term_memory.py b/python/flink_agents/runtime/memory/chroma_long_term_memory.py new file mode 100644 index 00000000..7bd1a324 --- /dev/null +++ b/python/flink_agents/runtime/memory/chroma_long_term_memory.py @@ -0,0 +1,456 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Type, cast + +import chromadb +from chromadb import ClientAPI as ChromaClient +from chromadb import CloudClient, GetResult, Metadata, Settings +from chromadb.api.types import Document +from chromadb.errors import NotFoundError +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import override + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.embedding_models.embedding_model import BaseEmbeddingModelSetup +from flink_agents.api.memory.long_term_memory import ( + BaseLongTermMemory, + DatetimeRange, + MemorySet, + MemorySetItem, + ReduceSetup, + ReduceStrategy, +) +from flink_agents.api.resource import ResourceType +from flink_agents.api.runner_context import RunnerContext +from flink_agents.runtime.memory.reduce_functions import summarize + +if TYPE_CHECKING: + from chromadb import GetResult + + +class ChromaLongTermMemory(BaseLongTermMemory): + """Long-Term Memory based on ChromaDB.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + embedding_model: str | BaseEmbeddingModelSetup | None = Field( + default=None, + description="Custom embedding model used to generate item embedding. " + "If not provided, will use the default embedding function of chroma.", + ) + runner_context: RunnerContext = Field( + description="The runner context for this long term memory.", + ) + + # Connection configuration + persist_directory: str | None = Field( + default=None, + description="Directory for persistent storage. 
If None, uses in-memory client.", + ) + host: str | None = Field( + default=None, + description="Host for ChromaDB server connection.", + ) + port: int | None = Field( + default=8000, + description="Port for ChromaDB server connection.", + ) + api_key: str | None = Field( + default=None, + description="API key for Chroma Cloud connection.", + ) + client_settings: Settings | None = Field( + default=None, + description="ChromaDB client settings for advanced configuration.", + ) + tenant: str = Field( + default="default_tenant", + description="ChromaDB tenant for multi-tenancy support.", + ) + database: str = Field( + default="default_database", + description="ChromaDB database name.", + ) + + __client: ChromaClient | None = None + + def __init__( + self, + *, + runner_context: RunnerContext, + embedding_model: str | BaseEmbeddingModelSetup | None = None, + persist_directory: str | None = None, + host: str | None = None, + port: int | None = 8000, + api_key: str | None = None, + client_settings: Settings | None = None, + tenant: str = "default_tenant", + database: str = "default_database", + **kwargs: Any, + ) -> None: + """Initialize the ChromaDB based long-term memory.""" + super().__init__( + runner_context=runner_context, + embedding_model=embedding_model, + persist_directory=persist_directory, + host=host, + port=port, + api_key=api_key, + client_settings=client_settings, + tenant=tenant, + database=database, + **kwargs, + ) + + @property + def client(self) -> ChromaClient: + """Return ChromaDB client, creating it if necessary.""" + if self.__client is None: + # Choose client type based on configuration + if self.api_key is not None: + # Cloud mode + self.__client = CloudClient( + tenant=self.tenant, + database=self.database, + api_key=self.api_key, + ) + elif self.host is not None: + # Client-Server Mode + # create database if needed + admin_client = chromadb.AdminClient( + Settings( + chroma_api_impl="chromadb.api.fastapi.FastAPI", + chroma_server_host=self.host, + chroma_server_http_port=self.port, + ) + ) + try: + 
admin_client.get_database(name=self.database, tenant=self.tenant) + except Exception: + admin_client.create_database(name=self.database, tenant=self.tenant) + self.__client = chromadb.HttpClient( + host=self.host, + port=self.port, + settings=self.client_settings, + tenant=self.tenant, + database=self.database, + ) + else: + err_msg = ( + "As long-term memory backend, chromadb must run in cloud mode or client-server mode, " + "user should provide cloud api-key or server address." + ) + raise RuntimeError(err_msg) + + return self.__client + + @override + def create_memory_set( + self, + name: str, + item_type: type[str] | Type[ChatMessage], + capacity: int, + reduce_setup: ReduceSetup, + ) -> MemorySet: + try: + return self.get_memory_set(name=name) + except NotFoundError: + memory_set = MemorySet( + name=name, + item_type=item_type, + size=0, + capacity=capacity, + reduce_setup=reduce_setup, + ltm=self, + ) + self.client.get_or_create_collection( + name=name, metadata={"serialization": memory_set.model_dump_json()} + ) + return memory_set + + @override + def get_memory_set(self, name: str) -> MemorySet: + metadata = self.client.get_collection(name=name).metadata + memory_set = MemorySet.model_validate_json(metadata["serialization"]) + memory_set.ltm = self + return memory_set + + @override + def delete_memory_set(self, name: str) -> None: + self.client.delete_collection(name=name) + + @override + def add(self, memory_set: MemorySet, memory_item: str | ChatMessage) -> None: + if memory_set.size >= memory_set.capacity: + # trigger reduce operation to manage memory set size. 
+ self._reduce(memory_set) + + if issubclass(memory_set.item_type, BaseModel): + memory_item = memory_item.model_dump_json() + + embedding = self._generate_embedding(text=memory_item) + + item_id = str(uuid.uuid4()) + timestamp = datetime.now(timezone.utc).isoformat() + self.client.get_collection(name=memory_set.name).add( + ids=item_id, + embeddings=embedding, + documents=memory_item, + metadatas={ + "compacted": False, + "created_time_start": timestamp, + "created_time_end": timestamp, + "last_accessed_time": timestamp, + }, + ) + memory_set.size = memory_set.size + 1 + memory_set.item_ids.append(item_id) + self._update_set_metadata(memory_set) + + @override + def get(self, memory_set: MemorySet) -> List[MemorySetItem]: + return self.slice(memory_set=memory_set) + + @override + def get_recent(self, memory_set: MemorySet, n: int) -> List[MemorySetItem]: + offset = memory_set.size - n if memory_set.size > n else 0 + return self.slice(memory_set=memory_set, offset=offset, n=n) + + @override + def search( + self, memory_set: MemorySet, query: str, limit: int, **kwargs: Any + ) -> List[MemorySetItem]: + embedding = self._generate_embedding(text=query) + result = self.client.get_collection(name=memory_set.name).query( + query_embeddings=[embedding] if embedding else None, + query_texts=[query], + n_results=limit, + where=kwargs.get("where"), + include=["documents", "metadatas"], + ) + + ids = result["ids"][0] + documents = result["documents"][0] + metadatas = result["metadatas"][0] + + self._update_items_metadata(memory_set=memory_set, ids=ids, metadatas=metadatas) + + return self._convert_to_items( + memory_set=memory_set, ids=ids, documents=documents, metadatas=metadatas + ) + + def slice( + self, + *, + memory_set: MemorySet, + offset: int | None = None, + n: int | None = None, + update_metadata: bool = True, + ) -> List[MemorySetItem]: + """Retrieve memory items up to limit starting at offset. + + Args: + memory_set: The memory set to be retrieved from. 
+ offset: The offset to start retrieving from. + n: The number of items to retrieve. + update_metadata: If True, update the items metadata in chroma store. + + Returns: + Retrieved memory items. + """ + result: GetResult = self.client.get_collection(name=memory_set.name).get( + offset=offset, limit=n + ) + + ids = result["ids"] + metadatas = result["metadatas"] + documents = result["documents"] + + if update_metadata: + self._update_items_metadata( + memory_set=memory_set, ids=ids, metadatas=metadatas + ) + + return self._convert_to_items( + memory_set=memory_set, ids=ids, documents=documents, metadatas=metadatas + ) + + def delete(self, memory_set: MemorySet, offset: int, n: int) -> None: + """Delete memory items. + + Args: + memory_set: The memory set to be deleted from. + offset: The offset to start delete from. + n: The number of items to delete. + """ + self.client.get_collection(name=memory_set.name).delete( + memory_set.item_ids[offset : offset + n] + ) + del memory_set.item_ids[offset : offset + n] + memory_set.size = memory_set.size - n + + self.client.get_collection(name=memory_set.name).modify( + metadata={"serialization": memory_set.model_dump_json()} + ) + + def update( + self, memory_set: MemorySet, item_id: str, text: str, metadata: Dict[str, Any] + ) -> None: + """Update memory item. + + Args: + memory_set: The memory set item belongs to. + item_id: The memory item to be updated. + text: The updated text. + metadata: The updated metadata. 
+ """ + embedding = self._generate_embedding(text=text) + self.client.get_collection(name=memory_set.name).update( + ids=item_id, + embeddings=embedding, + documents=text, + metadatas=metadata, + ) + + @staticmethod + def _convert_to_items( + memory_set: MemorySet, + ids: List[str], + documents: List[Document], + metadatas: List[Metadata], + ) -> List[MemorySetItem]: + """Convert retrieval result to memory items.""" + return [ + MemorySetItem( + memory_set_name=memory_set.name, + id=item_id, + value=document + if memory_set.item_type is str + else memory_set.item_type.model_validate_json(document), + compacted=metadata["compacted"], + created_time=DatetimeRange( + start=datetime.fromisoformat( + cast("str", metadata["created_time_start"]) + ), + end=datetime.fromisoformat(cast("str", metadata["created_time_end"])), + ), + last_accessed_time=datetime.fromisoformat(metadata["last_accessed_time"]), + ) + for item_id, document, metadata in zip( + ids, documents, metadatas, strict=False + ) + ] + + def _generate_embedding(self, text: str) -> list[float] | None: + """Generate embedding for text. + + If no embedding model configured, return None to use the default embedding + generated by chroma. 
+ """ + if self.embedding_model is not None: + if isinstance(self.embedding_model, str): + self.embedding_model = cast( + "BaseEmbeddingModelSetup", + self.runner_context.get_resource( + self.embedding_model, ResourceType.EMBEDDING_MODEL + ), + ) + return self.embedding_model.embed(text=text) + else: + return None + + def _update_set_metadata(self, memory_set: MemorySet) -> None: + """Update metadata for memory set.""" + self.client.get_collection(name=memory_set.name).modify( + metadata={"serialization": memory_set.model_dump_json()} + ) + + def _update_items_metadata( + self, memory_set: MemorySet, ids: List[str], metadatas: List[Metadata] + ) -> None: + """Update metadata for retrieved memory items.""" + current_timestamp = datetime.now(timezone.utc).isoformat() + update_metadatas = [] + for metadata in metadatas: + metadata = dict(metadata) + metadata["last_accessed_time"] = current_timestamp + update_metadatas.append(metadata) + self.client.get_collection(name=memory_set.name).update( + ids=ids, metadatas=update_metadatas + ) + + def _reduce(self, memory_set: MemorySet) -> None: + """Reduce memory set size.""" + reduce_setup: ReduceSetup = memory_set.reduce_setup + if reduce_setup.strategy == ReduceStrategy.TRIM: + self._trim(memory_set) + elif reduce_setup.strategy == ReduceStrategy.SUMMARIZE: + self._summarize(memory_set) + else: + msg = f"Unknown reduce strategy: {reduce_setup.strategy}" + raise RuntimeError(msg) + memory_set.reduced = True + self._update_set_metadata(memory_set) + + def _trim(self, memory_set: MemorySet) -> None: + reduce_setup: ReduceSetup = memory_set.reduce_setup + n = reduce_setup.arguments.get("n") + self.delete(memory_set=memory_set, offset=0, n=n) + + def _summarize(self, memory_set: MemorySet) -> None: + # get arguments + reduce_setup: ReduceSetup = memory_set.reduce_setup + n = reduce_setup.arguments.get("n") + + # retrieve items involved + items: List[MemorySetItem] = self.slice(memory_set=memory_set, offset=0, n=n) + + response: 
ChatMessage = summarize( + items, memory_set.item_type, memory_set.reduce_setup, self.runner_context + ) + + # update memory set + if memory_set.item_type == ChatMessage: + text = ChatMessage( + role=MessageRole.USER, content=response.content + ).model_dump_json() + else: + text = response.content + + start = min([item.created_time.start for item in items]).isoformat() + end = max([item.created_time.end for item in items]).isoformat() + + # to keep the addition order for items in collection, update the exist item + # rather than add a new item. + self.update( + memory_set=memory_set, + item_id=memory_set.item_ids[0], + text=text, + metadata={ + "compacted": True, + "created_time_start": start, + "created_time_end": end, + "last_accessed_time": max( + [item.last_accessed_time for item in items] + ).isoformat(), + }, + ) + + # delete other items involved in reduction. + self.delete(memory_set=memory_set, offset=1, n=n - 1) diff --git a/python/flink_agents/runtime/memory/reduce_functions.py b/python/flink_agents/runtime/memory/reduce_functions.py new file mode 100644 index 00000000..1902c566 --- /dev/null +++ b/python/flink_agents/runtime/memory/reduce_functions.py @@ -0,0 +1,80 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import TYPE_CHECKING, List, Type, cast + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.memory.long_term_memory import MemorySetItem, ReduceSetup +from flink_agents.api.resource import ResourceType +from flink_agents.api.runner_context import RunnerContext + +if TYPE_CHECKING: + from flink_agents.api.chat_models.chat_model import BaseChatModelSetup + from flink_agents.api.prompts.prompt import Prompt + + +def summarize( + memory_set_items: List[MemorySetItem], + item_type: Type, + reduce_setup: ReduceSetup, + ctx: RunnerContext, +) -> ChatMessage: + """Util functions to summarize the items by llm.""" + # get arguments + reduce_setup.arguments.get("n") + model_name = reduce_setup.arguments.get("model") + prompt = reduce_setup.arguments.get("prompt") + + msgs: List[ChatMessage] + if item_type == ChatMessage: + msgs = [item.value for item in memory_set_items] + else: + msgs = [ + ChatMessage(role=MessageRole.USER, content=str(item.value)) + for item in memory_set_items + ] + + # generate summary + model: BaseChatModelSetup = cast( + "BaseChatModelSetup", + ctx.get_resource(name=model_name, type=ResourceType.CHAT_MODEL), + ) + input_variable = {} + for msg in msgs: + input_variable.update(msg.extra_args) + + if prompt is not None: + if isinstance(prompt, str): + prompt: Prompt = cast( + "Prompt", + ctx.get_resource(prompt, ResourceType.PROMPT), + ) + prompt_messages = prompt.format_messages( + role=MessageRole.USER, **input_variable + ) + msgs.extend(prompt_messages) + else: + msgs.append( + ChatMessage( + role=MessageRole.USER, + content="Create a summary of the conversation above", + ) + ) + + response: ChatMessage = model.chat(messages=msgs) + + return response diff --git a/python/flink_agents/runtime/memory/tests/__init__.py 
#!/usr/bin/env bash

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

# Runs `chroma run` with the directory given as first argument as --path.
#
# Usage: start_chroma_server.sh <data-path>

if [ "$#" -lt 1 ] || [ -z "$1" ]; then
    echo "Usage: $0 <data-path>" >&2
    exit 1
fi

path=$1

# Quote the path so directories containing spaces or glob characters are
# passed as a single argument. `exec` replaces this shell with the chroma
# process, so the PID of this script is the PID of the server.
exec chroma run --path "$path"
#################################################################################
import os
import signal
import subprocess
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Tuple
from unittest.mock import create_autospec

import pytest

from flink_agents.api.chat_message import ChatMessage, MessageRole
from flink_agents.api.memory.long_term_memory import (
    MemorySet,
    ReduceSetup,
)
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.runner_context import RunnerContext
from flink_agents.integrations.chat_models.ollama_chat_model import (
    OllamaChatModelConnection,
    OllamaChatModelSetup,
)
from flink_agents.integrations.embedding_models.local.ollama_embedding_model import (
    OllamaEmbeddingModelConnection,
    OllamaEmbeddingModelSetup,
)
from flink_agents.runtime.memory.chroma_long_term_memory import ChromaLongTermMemory

if TYPE_CHECKING:
    from flink_agents.api.memory.long_term_memory import (
        MemorySetItem,
    )

current_dir = Path(__file__).parent

# Name of the memory set every test in this module creates and deletes.
_MEMORY_SET_NAME = "chat_history"

# Upper bound, in seconds, for the server script to exit after SIGTERM.
_SHUTDOWN_TIMEOUT_SECONDS = 10


@pytest.fixture(scope="module")
def start_chroma() -> Generator:
    """Run a ChromaDB server in the background for the tests of this module.

    The server is launched through ``start_chroma_server.sh`` with a fresh
    temporary directory as its only argument. On teardown the whole process
    group of the script is terminated and the script process is reaped.
    """
    chromadb_path = tempfile.mkdtemp()
    print(f"Starting ChromaDB in {chromadb_path}...")
    # start_new_session=True puts the script into its own session and process
    # group, so the complete group can be signalled on teardown. It replaces
    # preexec_fn=os.setsid, which the subprocess documentation advises against.
    process = subprocess.Popen(
        ["bash", f"{current_dir}/start_chroma_server.sh", chromadb_path],
        start_new_session=True,
    )
    # Fixed delay before the tests start talking to the server.
    time.sleep(10)
    try:
        yield
    finally:
        # clean up running chroma process
        pgid = os.getpgid(process.pid)
        os.killpg(pgid, signal.SIGTERM)
        try:
            # Reap the child so no zombie process is left behind.
            process.wait(timeout=_SHUTDOWN_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            os.killpg(pgid, signal.SIGKILL)
            process.wait()


@pytest.fixture(scope="module")
def long_term_memory() -> ChromaLongTermMemory:
    """Create a ``ChromaLongTermMemory`` connected to ``localhost:8000``.

    The runner context is an autospec mock whose ``get_resource`` returns an
    Ollama chat model setup for ``ResourceType.CHAT_MODEL`` and an Ollama
    embedding model setup for every other resource type. The embedding model
    is only configured on the memory when the ``USE_OLLAMA`` environment
    variable is set to a non-empty value.
    """
    embedding_model_connection = OllamaEmbeddingModelConnection()

    def get_embed_connection(name: str, type: ResourceType) -> Resource:
        return embedding_model_connection

    embedding_model = OllamaEmbeddingModelSetup(
        get_resource=get_embed_connection,
        connection="embedding_model_connection",
        model="nomic-embed-text",
    )

    chat_model_connection = OllamaChatModelConnection()

    def get_chat_connection(name: str, type: ResourceType) -> Resource:
        return chat_model_connection

    chat_model = OllamaChatModelSetup(
        get_resource=get_chat_connection,
        connection="chat_model_connection",
        model="qwen3:8b",
    )

    def get_resource(name: str, type: ResourceType) -> Resource:
        # The requested name is ignored; only the resource type selects the
        # object that is returned.
        if type == ResourceType.CHAT_MODEL:
            return chat_model
        else:
            return embedding_model

    mock_runner_context = create_autospec(RunnerContext, instance=True)
    mock_runner_context.get_resource = get_resource

    use_ollama = os.environ.get("USE_OLLAMA")

    return ChromaLongTermMemory(
        runner_context=mock_runner_context,
        database="bc0b2ad61ecd4a615d92ce25390f61ad.00001",
        embedding_model="embedding_model" if use_ollama else None,
        host="localhost",
        port=8000,
    )


def prepare_memory_set(
    long_term_memory: ChromaLongTermMemory,
    reduce_setup: ReduceSetup = ReduceSetup.trim_setup(10),  # noqa:B008
) -> Tuple[MemorySet, List[ChatMessage]]:
    """Create the ``chat_history`` memory set and add 20 user messages to it.

    Args:
        long_term_memory: Memory in which the set is created.
        reduce_setup: Reduce setup handed to ``create_memory_set``.

    Returns:
        The created memory set and the list of the 20 messages that were
        added, in insertion order.
    """
    memory_set: MemorySet = long_term_memory.create_memory_set(
        name=_MEMORY_SET_NAME,
        item_type=ChatMessage,
        capacity=100,
        reduce_setup=reduce_setup,
    )

    msgs: List[ChatMessage] = []
    for i in range(20):
        msg = ChatMessage(role=MessageRole.USER, content=f"This is the no.{i} message.")
        msgs.append(msg)
        memory_set.add(item=msg)

    return memory_set, msgs


def test_get_memory_set(
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    """A memory set fetched by name equals the one that was created."""
    memory_set, _ = prepare_memory_set(long_term_memory)
    try:
        retrieved = long_term_memory.get_memory_set(memory_set.name)
        assert retrieved == memory_set
    finally:
        # Delete even when the assertion fails so later tests start clean.
        long_term_memory.delete_memory_set(name=_MEMORY_SET_NAME)


def test_add_and_get(
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    """``get()`` returns the values of all added items in insertion order."""
    memory_set, msgs = prepare_memory_set(long_term_memory)
    try:
        retrieved: List[MemorySetItem] = memory_set.get()
        retrieved_msgs = [item.value for item in retrieved]

        assert retrieved_msgs == msgs
    finally:
        long_term_memory.delete_memory_set(name=_MEMORY_SET_NAME)


def test_get_recent_n(
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    """``get_recent(10)`` returns the last 10 of the 20 added messages."""
    memory_set, msgs = prepare_memory_set(long_term_memory)
    try:
        retrieved: List[MemorySetItem] = memory_set.get_recent(10)
        retrieved_msgs = [item.value for item in retrieved]

        assert retrieved_msgs == msgs[10:]
    finally:
        long_term_memory.delete_memory_set(name=_MEMORY_SET_NAME)


def test_search(
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    """Searching for "The no.10 message" with limit 1 returns message 10."""
    memory_set, msgs = prepare_memory_set(long_term_memory)
    try:
        retrieved: List[MemorySetItem] = memory_set.search(
            query="The no.10 message", limit=1
        )
        retrieved_msgs = [item.value for item in retrieved]

        assert retrieved_msgs == msgs[10:11]
    finally:
        long_term_memory.delete_memory_set(name=_MEMORY_SET_NAME)


def test_reduce_trim(
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    """With the default trim setup only the 100 newest messages remain.

    The set starts with the 20 messages from ``prepare_memory_set`` (capacity
    100, ``ReduceSetup.trim_setup(10)``) and receives 100 more; afterwards
    ``get()`` is expected to return exactly those 100 later messages.
    """
    memory_set, _ = prepare_memory_set(long_term_memory)
    try:
        msgs: List[ChatMessage] = []

        for i in range(100):
            msg = ChatMessage(
                role=MessageRole.USER, content=f"This is the no.{i + 20} message."
            )
            msgs.append(msg)
            memory_set.add(item=msg)

        retrieved: List[MemorySetItem] = memory_set.get()
        retrieved_msgs = [item.value for item in retrieved]

        assert retrieved_msgs == msgs
    finally:
        long_term_memory.delete_memory_set(name=_MEMORY_SET_NAME)


@pytest.mark.skip("Depend on ollama server")
def test_reduce_summarize(
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    """With ``summarize_setup(n=20)`` the oldest items become one compacted item.

    After 20 + 100 additions the set is expected to hold 82 items: a compacted
    first item whose creation time is a range, followed by the messages
    ``msgs[19:]`` of the second batch.
    """
    memory_set, _ = prepare_memory_set(
        long_term_memory,
        reduce_setup=ReduceSetup.summarize_setup(n=20, model="chat_model"),
    )
    try:
        msgs: List[ChatMessage] = []

        for i in range(100):
            msg = ChatMessage(
                role=MessageRole.USER, content=f"This is the no.{i + 20} message."
            )
            msgs.append(msg)
            memory_set.add(item=msg)

        retrieved: List[MemorySetItem] = memory_set.get()
        retrieved_msgs = [item.value for item in retrieved]

        assert retrieved[0].compacted
        assert retrieved[0].created_time.start < retrieved[0].created_time.end
        assert memory_set.size == 82
        assert len(retrieved_msgs) == 82
        assert retrieved_msgs[1:] == msgs[19:]
    finally:
        long_term_memory.delete_memory_set(name=_MEMORY_SET_NAME)


def test_update_metadata(  # noqa:D103
    start_chroma: Generator, long_term_memory: ChromaLongTermMemory
) -> None:
    memory_set, msgs = prepare_memory_set(long_term_memory)

    # The same query is issued twice and only the second result is inspected.
    # NOTE(review): presumably the first search updates the access metadata
    # that the second search then returns - confirm against the implementation.
    memory_set.search(query="The no.10 message", limit=1)
    retrieved: List[MemorySetItem] = memory_set.search(
        query="The no.10 message", limit=1
    )
    item = retrieved[0]

    assert item.last_accessed_time > item.created_time.start