diff --git a/pasha/context.py b/pasha/context.py index 814e450..c87515c 100644 --- a/pasha/context.py +++ b/pasha/context.py @@ -159,9 +159,9 @@ class PoolContext(MapContext): """ def __init__(self, num_workers=None): - from os import cpu_count - if num_workers is None: + if not num_workers: + from os import cpu_count num_workers = min(cpu_count() // 2, 10) super().__init__(num_workers=num_workers) diff --git a/pasha/tests/test_context.py b/pasha/tests/test_context.py index 9ab60b9..ee1784f 100644 --- a/pasha/tests/test_context.py +++ b/pasha/tests/test_context.py @@ -10,7 +10,7 @@ import numpy as np import pasha as psh -from pasha.context import MapContext +from pasha.context import MapContext, ThreadContext from pasha.functor import Functor @@ -123,3 +123,12 @@ def test_set_default_context_string(ctx_str, expected_type): psh.set_default_context(ctx_str) assert isinstance(psh.get_default_context(), expected_type) + + +@pytest.mark.parametrize( + ['num_workers', 'expected'], [(None, 1), (0, 1), (4, 4)]) +def test_array(num_workers, expected): + """Test ThreadContext validation.""" + + ctx = ThreadContext(num_workers=num_workers) + assert ctx.num_workers >= expected