From 40c9eb50e82e9d72b6f6d78fad06e767bb6554c9 Mon Sep 17 00:00:00 2001 From: Cyril Danilevski Date: Wed, 10 Nov 2021 21:41:00 +0100 Subject: [PATCH] Validate ThreadContext.num_workers --- pasha/context.py | 4 ++-- pasha/tests/test_context.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pasha/context.py b/pasha/context.py index 814e450..c87515c 100644 --- a/pasha/context.py +++ b/pasha/context.py @@ -159,9 +159,9 @@ class PoolContext(MapContext): """ def __init__(self, num_workers=None): - from os import cpu_count - if num_workers is None: + if not num_workers: + from os import cpu_count num_workers = min(cpu_count() // 2, 10) super().__init__(num_workers=num_workers) diff --git a/pasha/tests/test_context.py b/pasha/tests/test_context.py index 9ab60b9..ee1784f 100644 --- a/pasha/tests/test_context.py +++ b/pasha/tests/test_context.py @@ -10,7 +10,7 @@ import numpy as np import pasha as psh -from pasha.context import MapContext +from pasha.context import MapContext, ThreadContext from pasha.functor import Functor @@ -123,3 +123,12 @@ def test_set_default_context_string(ctx_str, expected_type): psh.set_default_context(ctx_str) assert isinstance(psh.get_default_context(), expected_type) + + +@pytest.mark.parametrize( + ['num_workers', 'expected'], [(None, 1), (0, 1), (4, 4)]) +def test_array(num_workers, expected): + """Test ThreadContext validation.""" + + ctx = ThreadContext(num_workers=num_workers) + assert ctx.num_workers >= expected