Skip to content

Commit 4095634

Browse files
committed
add class for custom cache
1 parent b355319 commit 4095634

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
import pickle
3+
import threading
4+
from collections import OrderedDict
5+
from typing import Any
6+
7+
8+
class PerSmilesPerModelLRUCache:
9+
def __init__(self, max_size: int = 100, persist_path: str | None = None):
10+
self._cache = OrderedDict()
11+
self._max_size = max_size
12+
self._lock = threading.Lock()
13+
self._persist_path = persist_path
14+
15+
self.hits = 0
16+
self.misses = 0
17+
18+
if self._persist_path:
19+
self._load_cache()
20+
21+
def get(self, smiles: str, model_name: str) -> Any | None:
22+
key = (smiles, model_name)
23+
with self._lock:
24+
if key in self._cache:
25+
self._cache.move_to_end(key)
26+
self.hits += 1
27+
return self._cache[key]
28+
else:
29+
self.misses += 1
30+
return None
31+
32+
def set(self, smiles: str, model_name: str, value: Any) -> None:
33+
key = (smiles, model_name)
34+
with self._lock:
35+
if key in self._cache:
36+
self._cache.move_to_end(key)
37+
self._cache[key] = value
38+
if len(self._cache) > self._max_size:
39+
self._cache.popitem(last=False)
40+
41+
def clear(self) -> None:
42+
self._save_cache()
43+
with self._lock:
44+
self._cache.clear()
45+
self.hits = 0
46+
self.misses = 0
47+
if self._persist_path and os.path.exists(self._persist_path):
48+
os.remove(self._persist_path)
49+
50+
def stats(self) -> dict:
51+
return {"hits": self.hits, "misses": self.misses}
52+
53+
def _save_cache(self) -> None:
54+
"""Serialize the cache to disk."""
55+
if not self._persist_path:
56+
try:
57+
with open(self._persist_path, "wb") as f:
58+
pickle.dump(self._cache, f)
59+
except Exception as e:
60+
print(f"[Cache Save Error] {e}")
61+
62+
def _load_cache(self) -> None:
63+
"""Load the cache from disk."""
64+
if os.path.exists(self._persist_path):
65+
try:
66+
with open(self._persist_path, "rb") as f:
67+
loaded = pickle.load(f)
68+
if isinstance(loaded, OrderedDict):
69+
self._cache = loaded
70+
except Exception as e:
71+
print(f"[Cache Load Error] {e}")
72+
73+
74+
if __name__ == "__main__":
75+
# Example usage
76+
cache = PerSmilesPerModelLRUCache(max_size=100, persist_path="cache.pkl")

0 commit comments

Comments
 (0)