89 changes: 89 additions & 0 deletions nion/utils/Threading.py
@@ -0,0 +1,89 @@
import threading
from collections.abc import MutableMapping
from typing import Any, TypeVar, Generic, Iterator, Optional, Tuple, List, Dict

K = TypeVar('K')
V = TypeVar('V')

class ThreadSafeDict(MutableMapping[K, V], Generic[K, V]):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
self.__data: Dict[K, V] = dict(*args, **kwargs)
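        # RLock (re-entrant) so code already holding the lock via locked() can call the dict's own methods, which re-acquire it.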
self.__lock = threading.RLock()

# --- Core mapping methods ---
def __getitem__(self, key: K) -> V:
with self.__lock:
return self.__data[key]

def __setitem__(self, key: K, value: V) -> None:
with self.__lock:
self.__data[key] = value

def __delitem__(self, key: K) -> None:
with self.__lock:
del self.__data[key]

def __iter__(self) -> Iterator[K]:
with self.__lock:
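            # Iterate over a snapshot of the keys so iteration stays safe after the lock is released (no RuntimeError on concurrent mutation).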
return iter(list(self.__data))

def __len__(self) -> int:
with self.__lock:
return len(self.__data)

def __contains__(self, key: object) -> bool:
with self.__lock:
return key in self.__data

# --- Optional helpers ---
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
with self.__lock:
return self.__data.get(key, default)

def setdefault(self, key: K, default: V) -> V:
with self.__lock:
return self.__data.setdefault(key, default)

def pop(self, key: K, default: Optional[V] = None) -> Optional[V]:
with self.__lock:
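            # Unlike dict.pop with no default, a missing key returns the default (None) instead of raising KeyError.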
return self.__data.pop(key, default)

def popitem(self) -> Tuple[K, V]:
with self.__lock:
return self.__data.popitem()

    def update(self, *args: Any, **kwargs: Any) -> None:
with self.__lock:
self.__data.update(*args, **kwargs)

def clear(self) -> None:
with self.__lock:
self.__data.clear()

def keys(self) -> List[K]:
with self.__lock:
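            # Return a list snapshot rather than a live view so the result stays valid after the lock is released (values/items below do the same).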
return list(self.__data.keys())

def values(self) -> List[V]:
with self.__lock:
return list(self.__data.values())

def items(self) -> List[Tuple[K, V]]:
with self.__lock:
return list(self.__data.items())

def copy(self) -> 'ThreadSafeDict[K, V]':
with self.__lock:
return ThreadSafeDict(self.__data.copy())

def to_dict(self) -> Dict[K, V]:
with self.__lock:
return dict(self.__data)

def __repr__(self) -> str:
with self.__lock:
return f"{self.__class__.__name__}({self.__data!r})"

# --- Optional: context manager for atomic operations ---
def locked(self) -> threading.RLock:
return self.__lock
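
A minimal usage sketch for locked() (illustrative only, not part of the diff): it exposes the dict's internal RLock so a compound read-modify-write can be made atomic:

    from nion.utils.Threading import ThreadSafeDict

    d = ThreadSafeDict[str, int]()
    with d.locked():  # holds the same RLock that the dict's own methods acquire
        d["count"] = d.get("count", 0) + 1  # the get and the set cannot be interleaved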
138 changes: 138 additions & 0 deletions nion/utils/test/Threading_test.py
@@ -0,0 +1,138 @@
# standard libraries
import logging
import threading
import time
import unittest

# third party libraries
# None

# local libraries
from nion.utils.Threading import ThreadSafeDict

class TestThreadSafeDict(unittest.TestCase):
def test_single_thread_operations(self) -> None:
d = ThreadSafeDict[str, int]()
d["a"] = 1
self.assertEqual(d["a"], 1)
d["b"] = 2
self.assertIn("b", d)
self.assertEqual(len(d), 2)
d.pop("a")
self.assertNotIn("a", d)

    def test_multi_thread_increment_with_locked(self) -> None:
        d = ThreadSafeDict[str, int]()

        def worker(start: int, end: int) -> None:
            for i in range(start, end):
                # A read-modify-write is not atomic on its own, so group the get and the set under the dict's lock.
                key = f"key-{i % 10}"
                with d.locked():
                    d[key] = d.get(key, 0) + 1

threads = [threading.Thread(target=worker, args=(0, 1000)) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

# Each key should have been incremented 5 * (1000/10) = 500 times
for i in range(10):
self.assertEqual(d[f"key-{i}"], 500)

def test_locked_context_atomic(self) -> None:
d = ThreadSafeDict[str, int]()
d["counter"] = 0

        def worker(n: int) -> None:
            with d.locked():
                for local_counter in range(n):
                    # This thread holds the lock for the whole loop, so no other worker can interleave:
                    # the counter advances exactly once per iteration and val % n must equal local_counter.
                    val = d["counter"]
                    assert local_counter == val % n
                    time.sleep(0.0001)  # simulate work between the read and the write
                    d["counter"] = val + 1

threads = [threading.Thread(target=worker, args=(100,)) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

# Total increments = 5 * 100 = 500
self.assertEqual(d["counter"], 500)

def test_fine_grained_locked_increment(self) -> None:
"""
Tests multiple threads performing single-step increments with d.locked(),
demonstrating that fine-grained access is safely serialized.
"""
d = ThreadSafeDict[str, int]()
d["counter"] = 0

num_threads = 5
increments_per_thread = 1000

        def worker() -> None:
for _ in range(increments_per_thread):
# Fine-grained atomic increment
with d.locked():
d["counter"] = d.get("counter", 0) + 1

threads = [threading.Thread(target=worker) for _ in range(num_threads)]
for t in threads:
t.start()
for t in threads:
t.join()

# All increments should be accounted for
expected_total = num_threads * increments_per_thread
self.assertEqual(d["counter"], expected_total)

def test_copy_and_to_dict_thread_safety(self) -> None:
d = ThreadSafeDict[str, int]()
d["a"] = 1
d["b"] = 2

# Test copy
copy_d = d.copy()
self.assertEqual(copy_d.to_dict(), d.to_dict())

# Test to_dict
plain = d.to_dict()
self.assertEqual(plain, {"a": 1, "b": 2})

    def test_threadsafedict_overhead(self) -> None:
N_ITEMS = 1000
N_OPS = 1000
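        # Single-threaded micro-benchmark: the multiplier reflects per-operation lock acquisition plus the extra method-call indirection, not contention.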

# Prepare test data
keys = list(range(N_ITEMS))
values = list(range(N_ITEMS))

# --- Built-in dict ---
d = dict(zip(keys, values))
        start = time.perf_counter()
for _ in range(N_OPS):
for k in keys:
d[k] = d[k] + 1
_ = d[k]
        dict_time = time.perf_counter() - start

# --- ThreadSafeDict ---
d2 = ThreadSafeDict[int, int](zip(keys, values))
        start = time.perf_counter()
for _ in range(N_OPS):
for k in keys:
d2[k] = d2[k] + 1
_ = d2[k]
        tsd_time = time.perf_counter() - start

multiplier = tsd_time / dict_time if dict_time > 0 else float('inf')

print(f"Built-in dict time: {dict_time:.6f}s")
print(f"ThreadSafeDict time: {tsd_time:.6f}s")
print(f"Rough overhead multiplier: {multiplier:.2f}x")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
unittest.main()