diff --git a/intervals/__init__.py b/intervals/__init__.py index 5241270..4c949ad 100644 --- a/intervals/__init__.py +++ b/intervals/__init__.py @@ -50,6 +50,9 @@ def canonicalize(interval, lower_inc=True, upper_inc=False): if not interval.discrete: raise TypeError('Only discrete ranges can be canonicalized') + if interval.empty: + return interval + lower, lower_inc = canonicalize_lower(interval, lower_inc) upper, upper_inc = canonicalize_upper(interval, upper_inc) @@ -69,6 +72,21 @@ def wrapper(self, arg): return wrapper +class ClosedInterval(type): + """ + Supports initialization of intervals using square brackets and makes them + closed intervals. + + eg. + + IntInterval[1, 4] == IntInterval([1, 4]) + """ + def __getitem__(self, bounds): + lower_inc = upper_inc = True + return self(bounds, lower_inc, upper_inc) + + +@six.add_metaclass(ClosedInterval) @total_ordering class AbstractInterval(object): step = None @@ -148,6 +166,20 @@ def __init__( 30 """ + + # This if-block adds support for parentheses as open intervals. + # Note: If the interval is initialized with the parentheses with two + # objects of same type, eg. + # IntInterval(1, 4) + # the bounds and lower_inc are received of that type and + # upper_inc is None. + # + # eg. + # IntInterval(1, 4) == IntInterval((1, 4)) + if type(bounds) == type(lower_inc) and not upper_inc: + bounds = (bounds, lower_inc) + lower_inc = upper_inc = None + self.lower, self.upper, self.lower_inc, self.upper_inc = ( self.parser(bounds, lower_inc, upper_inc) ) @@ -307,6 +339,18 @@ def radius(self): def degenerate(self): return self.upper == self.lower + @property + def empty(self): + if self.discrete and not self.degenerate: + return ( + self.upper - self.lower == self.step + and not (self.upper_inc or self.lower_inc) + ) + return ( + self.upper == self.lower + and not (self.lower_inc and self.upper_inc) + ) + @property def centre(self): return float((self.lower + self.upper)) / 2 @@ -365,16 +409,27 @@ def __and__(self, other): """ Defines the intersection operator """ + if self.upper < other.lower or other.upper < self.lower: + return self.__class__((0, 0)) if self.lower <= other.lower <= self.upper: - return self.__class__([ + intersection = self.__class__([ other.lower, other.upper if other.upper < self.upper else self.upper ]) + intersection.lower_inc = other.lower_inc + intersection.upper_inc = ( + other.upper_inc if other.upper < self.upper else self.upper_inc + ) elif self.lower <= other.upper <= self.upper: - return self.__class__([ + intersection = self.__class__([ other.lower if other.lower > self.lower else self.lower, other.upper ]) + intersection.lower_inc = ( + other.lower_inc if other.lower > self.lower else self.lower_inc + ) + intersection.upper_inc = other.upper_inc + return intersection class IntInterval(AbstractInterval): diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 6840bb8..eaac817 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -88,6 +88,12 @@ def test_supports_integers(self): assert interval.lower_inc assert interval.upper_inc + def test_uses_two_numbers_with_parentheses_as_open_interval(self): + assert IntInterval(1, 2) == IntInterval((1, 2)) + + def test_uses_two_numbers_with_square_brackets_as_closed_interval(self): + assert IntInterval[1, 2] == IntInterval([1, 2]) + @mark.parametrize('number_range', ( (3, 2), diff --git a/tests/test_operators.py b/tests/test_operators.py index 731dee7..161c9ee 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -63,3 +63,28 @@ class TestDiscreteRangeComparison(object): )) def test_eq_operator(self, interval, interval2): assert IntInterval(interval) == IntInterval(interval2) + + +class TestBinaryOperators(object): + @mark.parametrize(('interval1', 'interval2', 'result'), ( + ((2, 3), (3, 4), (3, 3)), + ((2, 3), [3, 4], '[3, 3)'), + ((2, 5), (3, 10), (3, 5)), + ('(2, 3]', '[3, 4)', [3, 3]), + ('(2, 10]', '[3, 40]', [3, 10]), + ((2, 10), (3, 8), (3, 8)), + )) + def test_and_operator(self, interval1, interval2, result): + assert ( + IntInterval(interval1) & IntInterval(interval2) == + IntInterval(result) + ) + + @mark.parametrize(('interval1', 'interval2', 'empty'), ( + ((2, 3), (3, 4), True), + ((2, 3), [3, 4], True), + ([2, 3], (3, 4), True), + ('(2, 3]', '[3, 4)', False), + )) + def test_and_operator_for_empty_results(self, interval1, interval2, empty): + assert (IntInterval(interval1) & IntInterval(interval2)).empty == empty diff --git a/tests/test_properties.py b/tests/test_properties.py index cf8c4a5..a818586 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -70,6 +70,21 @@ def test_open(self, number_range, is_open): def test_closed(self, number_range, is_closed): assert IntInterval(number_range).closed == is_closed + @mark.parametrize(('number_range', 'empty'), + ( + ((2, 3), True), + ([2, 3], False), + ([2, 2], False), + ((2, 2), True), + ('[2, 2)', True), + ('(2, 2]', True), + ('[2, 3)', False), + ((2, 10), False), + ) + ) + def test_empty(self, number_range, empty): + assert IntInterval(number_range).empty == empty + @mark.parametrize(('number_range', 'degenerate'), ( ((2, 4), False),