Skip to content

Commit eec352d

Browse files
committed
[api][python] Introduce long-term memory interface.
1 parent 82a2c30 commit eec352d

File tree

4 files changed

+330
-0
lines changed

4 files changed

+330
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#################################################################################
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#################################################################################
18+
import importlib
19+
from abc import ABC, abstractmethod
20+
from datetime import datetime
21+
from enum import Enum
22+
from typing import Any, Dict, List, Type
23+
24+
from pydantic import (
25+
BaseModel,
26+
Field,
27+
field_serializer,
28+
model_validator,
29+
)
30+
31+
from flink_agents.api.chat_message import ChatMessage
32+
from flink_agents.api.prompts.prompt import Prompt
33+
34+
35+
class ReduceStrategy(Enum):
    """Strategy for reducing memory set size.

    Each member's value is the lowercase string form of its name
    (``"trim"``, ``"summarize"``).
    """

    TRIM = "trim"
    SUMMARIZE = "summarize"
40+
41+
42+
class ReduceSetup(BaseModel):
    """Reduce setup for managing the storage of a memory set.

    Attributes:
        strategy: The strategy used to reduce the memory set size.
        arguments: Strategy-specific arguments. The factory methods of this
            class store them under the keys ``"n"``, ``"model"`` and
            ``"prompt"``. Defaults to an empty dict.
    """

    strategy: ReduceStrategy
    arguments: Dict[str, Any] = Field(default_factory=dict)

    @staticmethod
    def trim_setup(n: int) -> "ReduceSetup":
        """Create a setup that uses the TRIM strategy.

        Args:
            n: Value stored under the ``"n"`` key of ``arguments``.

        Returns:
            A ``ReduceSetup`` with ``strategy=ReduceStrategy.TRIM``.
        """
        return ReduceSetup(strategy=ReduceStrategy.TRIM, arguments={"n": n})

    @staticmethod
    def summarize_setup(
        n: int, model: str, prompt: Prompt | None = None
    ) -> "ReduceSetup":
        """Create a setup that uses the SUMMARIZE strategy.

        Args:
            n: Value stored under the ``"n"`` key of ``arguments``.
            model: Value stored under the ``"model"`` key of ``arguments``.
            prompt: Optional prompt stored under the ``"prompt"`` key of
                ``arguments``; ``None`` is stored when omitted.

        Returns:
            A ``ReduceSetup`` with ``strategy=ReduceStrategy.SUMMARIZE``.
        """
        return ReduceSetup(
            strategy=ReduceStrategy.SUMMARIZE,
            arguments={"n": n, "model": model, "prompt": prompt},
        )
62+
63+
64+
class LongTermMemoryBackend(Enum):
    """Backend for Long-Term Memory."""

    CHROMA = "chroma"
68+
69+
70+
class DatetimeRange(BaseModel):
    """Represents a datetime range.

    Attributes:
        start: The start of the range.
        end: The end of the range.
    """

    start: datetime
    end: datetime
75+
76+
77+
class MemoryItem(BaseModel):
    """Represents a long term memory item retrieved from vector store.

    Attributes:
        name: The name of the memory set this item belongs to.
        id: The id of this item.
        value: The value of this item.
        compacted: Whether this item has been compacted. Defaults to False.
        created_at: The time this item was added to the memory set,
            expressed as a datetime range.
        last_accessed_at: The timestamp this item was last accessed.
        access_frequency: The access frequency of this item.
    """

    name: str
    id: str
    value: Any
    compacted: bool = False
    # NOTE(review): typed as a range rather than a single datetime --
    # presumably so a compacted item can cover the span of the items it was
    # built from; confirm against the backend implementation.
    created_at: DatetimeRange
    last_accessed_at: datetime
    access_frequency: int
97+
98+
99+
class MemorySet(BaseModel):
    """Represents a long term memory set that contains memory items.

    Attributes:
        name: The name of this memory set.
        item_type: The type of items stored in this set.
        size: Current items count stored in this set.
        capacity: The capacity of this memory set.
        reduce_setup: Reduce strategy and additional arguments used to reduce
            memory set size.
        item_ids: The ids of items stored in this set.
        reduced: Whether this memory set has been reduced.
        ltm: The long term memory that the operations of this set are
            delegated to. Excluded from serialization; defaults to None.
    """

    name: str
    item_type: Type[str] | Type[ChatMessage]
    size: int = 0
    capacity: int
    reduce_setup: ReduceSetup
    item_ids: List[str] = Field(default_factory=list)
    reduced: bool = False
    # Not part of the serialized form, so a deserialized set has ltm=None
    # until one is assigned; add/get/get_recent/search all dereference it.
    ltm: "BaseLongTermMemory" = Field(default=None, exclude=True)

    @field_serializer("item_type")
    def _serialize_item_type(self, item_type: Type) -> Dict[str, str]:
        """Serialize the item type as its module path and class name."""
        return {"module": item_type.__module__, "name": item_type.__name__}

    @model_validator(mode="before")
    @classmethod
    def _deserialize_item_type(cls, data: Any) -> Any:
        """Resolve a serialized ``item_type`` back into the class it names.

        Input that is not a dict, or whose ``item_type`` is not a
        ``{"module": ..., "name": ...}`` mapping, is passed through
        untouched so that pydantic's regular field validation reports the
        problem instead of a raw ``KeyError``/``TypeError`` escaping here.

        Args:
            data: The raw input handed to model validation.

        Returns:
            The input, with ``item_type`` replaced by the resolved class
            when it was given in serialized form.
        """
        if not isinstance(data, dict):
            return data

        serialized = data.get("item_type")
        if not isinstance(serialized, dict):
            return data
        if "module" not in serialized or "name" not in serialized:
            return data

        # NOTE(security): this imports whatever module the payload names.
        # Only validate data that comes from a trusted source.
        module = importlib.import_module(serialized["module"])

        # Work on a shallow copy so the caller's dict is not mutated.
        resolved = dict(data)
        resolved["item_type"] = getattr(module, serialized["name"])
        return resolved

    def add(self, item: str | ChatMessage) -> None:
        """Add a memory item to the set, currently only support item with
        type str or ChatMessage.

        If the capacity of this memory set is reached, will trigger reduce
        operation to manage the memory set size.

        Args:
            item: The item to be inserted to this set.
        """
        self.ltm.add(memory_set=self, memory_item=item)

    def get(self) -> List[MemoryItem]:
        """Retrieve all memory items.

        Returns:
            All memory items in this set.
        """
        return self.ltm.get(memory_set=self)

    def get_recent(self, n: int) -> List[MemoryItem]:
        """Retrieve n most recent memory items.

        Args:
            n: The number of items to retrieve.

        Returns:
            List of memory items retrieved, sorted by creation timestamp.
        """
        return self.ltm.get_recent(memory_set=self, n=n)

    def search(self, query: str, limit: int, **kwargs: Any) -> List[MemoryItem]:
        """Retrieve memory items related to the query.

        Args:
            query: The query to search for.
            limit: The number of items to retrieve.
            **kwargs: Additional arguments for search.

        Returns:
            Related memory items retrieved.
        """
        return self.ltm.search(memory_set=self, query=query, limit=limit, **kwargs)
173+
174+
175+
class BaseLongTermMemory(ABC, BaseModel):
    """Base abstract class for long term memory.

    Concrete backends implement memory set management (create / get / delete)
    and the per-set item operations (add / get / get_recent / search).
    """

    @abstractmethod
    def create_memory_set(
        self,
        name: str,
        item_type: Type[str] | Type[ChatMessage],
        capacity: int,
        reduce_setup: ReduceSetup,
    ) -> MemorySet:
        """Create a memory set.

        Args:
            name: The name of the memory set.
            item_type: The type of the memory item.
            capacity: The capacity of the memory set.
            reduce_setup: The reduce strategy and arguments for storage
                management.

        Returns:
            The created memory set.
        """

    @abstractmethod
    def get_memory_set(self, name: str) -> MemorySet:
        """Get the memory set.

        Args:
            name: The name of the memory set.

        Returns:
            The memory set with the given name.
        """

    @abstractmethod
    def delete_memory_set(self, name: str) -> None:
        """Delete the memory set.

        Args:
            name: The name of the memory set.
        """

    @abstractmethod
    def add(self, memory_set: MemorySet, memory_item: str | ChatMessage) -> None:
        """Add a memory item to the named set, currently only support item with
        type str or ChatMessage.

        This method will modify the `size` and `item_ids` field of memory_set.

        Args:
            memory_set: The memory set to be inserted.
            memory_item: The item to be inserted to this set.
        """

    @abstractmethod
    def get(self, memory_set: MemorySet) -> List[MemoryItem]:
        """Retrieve all memory items.

        Args:
            memory_set: The set to be retrieved.

        Returns:
            All the memory items of this set.
        """

    @abstractmethod
    def get_recent(self, memory_set: MemorySet, n: int) -> List[MemoryItem]:
        """Retrieve n most recent memory items.

        Args:
            memory_set: The set to be retrieved.
            n: The number of items to retrieve.

        Returns:
            List of memory items retrieved, sorted by creation timestamp.
        """

    @abstractmethod
    def search(
        self, memory_set: MemorySet, query: str, limit: int, **kwargs: Any
    ) -> List[MemoryItem]:
        """Retrieve memory items related to the query.

        Args:
            memory_set: The set to be retrieved.
            query: The query for semantic search.
            limit: The number of items to retrieve.
            **kwargs: Additional arguments for semantic search.

        Returns:
            Related memory items retrieved.
        """
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#################################################################################
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#################################################################################
18+
19+
from flink_agents.api.chat_message import ChatMessage
20+
from flink_agents.api.memory.long_term_memory import MemorySet, ReduceSetup
21+
22+
23+
def test_memory_set_serialization() -> None:  # noqa: D103
    # Round-trip a MemorySet through JSON and check nothing is lost.
    memory_set = MemorySet(
        name="chat_history",
        item_type=ChatMessage,
        capacity=100,
        reduce_setup=ReduceSetup.trim_setup(10),
    )

    json_data = memory_set.model_dump_json()
    memory_set_deserialized = MemorySet.model_validate_json(json_data)

    # item_type is dumped as {"module", "name"}; it must come back as the
    # very same class object, not merely an equal-looking mapping.
    assert memory_set_deserialized.item_type is ChatMessage
    assert memory_set_deserialized == memory_set

0 commit comments

Comments
 (0)