diff --git a/annotation/typed.py b/annotation/typed.py index 1c3bb29..577fd11 100644 --- a/annotation/typed.py +++ b/annotation/typed.py @@ -455,6 +455,10 @@ def _check_type_constraint(value, constraint): else: return False +class IncorrectArgument(TypeError): + def __str__(self): + return '{} expected {!r} actual {!r}'.format(*self.args) + def _check_argument_types(signature, *args, **kwargs): """Check that the arguments of a function match the given signature.""" bound_arguments = signature.bind(*args, **kwargs) @@ -464,8 +468,11 @@ def _check_argument_types(signature, *args, **kwargs): if annotation is EMPTY_ANNOTATION: annotation = AnyType if not _check_type_constraint(value, annotation): - raise TypeError('Incorrect type for "{0}"'.format(name)) + raise IncorrectArgument(name, annotation, value) +class IncorrectReturnType(TypeError): + def __str__(self): + return 'expected: {!r} actual: {!r}'.format(*self.args) def _check_return_type(signature, return_value): """Check that the return value of a function matches the signature.""" @@ -473,7 +480,7 @@ def _check_return_type(signature, return_value): if annotation is EMPTY_ANNOTATION: annotation = AnyType if not _check_type_constraint(return_value, annotation): - raise TypeError('Incorrect return type') + raise IncorrectReturnType(annotation, return_value) return return_value diff --git a/tests/test_annontations.py b/tests/test_annontations.py index d1aa269..f9e38a1 100644 --- a/tests/test_annontations.py +++ b/tests/test_annontations.py @@ -2,8 +2,7 @@ from collections import namedtuple from annotation.typed import (typechecked, Interface, union, AnyType, predicate, - optional, typedef, options, only) - + optional, typedef, options, only, IncorrectReturnType, IncorrectArgument) class TypecheckedTest(unittest.TestCase): @@ -15,8 +14,8 @@ def test(a: int): return a self.assertEqual(1, test(1)) - self.assertRaises(TypeError, test, 'string') - self.assertRaises(TypeError, test, 1.2) + self.assertRaises(IncorrectArgument, test, 'string') + self.assertRaises(IncorrectArgument, test, 1.2) def test_single_argument_with_class(self): @@ -29,7 +28,7 @@ def test(a: MyClass): value = MyClass() self.assertEqual(value, test(value)) - self.assertRaises(TypeError, test, 'string') + self.assertRaises(IncorrectArgument, test, 'string') def test_single_argument_with_subclass(self): @@ -42,7 +41,7 @@ def test(a: MyClass): value = MySubClass() self.assertEqual(value, test(value)) - self.assertRaises(TypeError, test, 'string') + self.assertRaises(IncorrectArgument, test, 'string') def test_single_argument_with_union_annotation(self): from decimal import Decimal @@ -54,7 +53,7 @@ def test(a: union(int, float, Decimal)): self.assertEqual(1, test(1)) self.assertEqual(1.5, test(1.5)) self.assertEqual(Decimal('2.5'), test(Decimal('2.5'))) - self.assertRaises(TypeError, test, 'string') + self.assertRaises(IncorrectArgument, test, 'string') def test_single_argument_with_predicate_annotation(self): @@ -63,7 +62,7 @@ def test(a: predicate(lambda x: x > 0)): return a self.assertEqual(1, test(1)) - self.assertRaises(TypeError, test, 0) + self.assertRaises(IncorrectArgument, test, 0) def test_single_argument_with_optional_annotation(self): @@ -91,7 +90,7 @@ def f2(a: str, b: str): pass self.assertEqual({}, test(f1)) - self.assertRaises(TypeError, test, f2) + self.assertRaises(IncorrectArgument, test, f2) def test_single_argument_with_options_annotation(self): @@ -101,7 +100,7 @@ def test(a: options('open', 'write')): self.assertEqual('open', test('open')) self.assertEqual('write', test('write')) - self.assertRaises(TypeError, test, 'other') + self.assertRaises(IncorrectArgument, test, 'other') def test_single_argument_with_only_annotation(self): @@ -110,7 +109,7 @@ def test(a: only(int)): return a self.assertEqual(1, test(1)) - self.assertRaises(TypeError, test, True) + self.assertRaises(IncorrectArgument, test, True) def test_single_argument_with_interface(self): @@ -129,7 +128,7 @@ def test(a: Test): return 1 self.assertEqual(1, test(TestImplementation())) - self.assertRaises(TypeError, test, Other()) + self.assertRaises(IncorrectArgument, test, Other()) def test_single_argument_with_no_annotation(self): @@ -148,9 +147,9 @@ def test(a: int, b: str): return a, b self.assertEqual((1, 'string'), test(1, 'string')) - self.assertRaises(TypeError, test, 1, 1) - self.assertRaises(TypeError, test, 'string', 'string') - self.assertRaises(TypeError, test, 'string', 1) + self.assertRaises(IncorrectArgument, test, 1, 1) + self.assertRaises(IncorrectArgument, test, 'string', 'string') + self.assertRaises(IncorrectArgument, test, 'string', 1) def test_single_argument_with_none_value(self): @@ -158,7 +157,7 @@ def test_single_argument_with_none_value(self): def test(a: int): return a - self.assertRaises(TypeError, test, None) + self.assertRaises(IncorrectArgument, test, None) def test_multiple_arguments_some_with_annotations(self): @@ -168,8 +167,8 @@ def test(a, b: str): self.assertEqual((1, 'string'), test(1, 'string')) self.assertEqual(('string', 'string'), test('string', 'string')) - self.assertRaises(TypeError, test, 1, 1) - self.assertRaises(TypeError, test, 'string', 1) + self.assertRaises(IncorrectArgument, test, 1, 1) + self.assertRaises(IncorrectArgument, test, 'string', 1) def test_return_with_builtin_type(self): @@ -178,8 +177,8 @@ def test(a) -> int: return a self.assertEqual(1, test(1)) - self.assertRaises(TypeError, test, 'string') - self.assertRaises(TypeError, test, 1.2) + self.assertRaises(IncorrectReturnType, test, 'string') + self.assertRaises(IncorrectReturnType, test, 1.2) def test_return_with_class(self): @@ -195,7 +194,7 @@ def test2() -> MyClass: return 1 self.assertIsInstance(test1(), MyClass) - self.assertRaises(TypeError, test2) + self.assertRaises(IncorrectReturnType, test2) def test_return_with_sublass(self): @@ -211,7 +210,7 @@ def test2() -> MyClass: return 1 self.assertIsInstance(test1(), MyClass) - self.assertRaises(TypeError, test2) + self.assertRaises(IncorrectReturnType, test2) def test_return_with_union(self): @@ -221,7 +220,7 @@ def test(a) -> union(int, float): self.assertEqual(1, test(1)) self.assertEqual(1.1, test(1.1)) - self.assertRaises(TypeError, test, 'string') + self.assertRaises(IncorrectReturnType, test, 'string') def test_return_with_interface(self): @@ -244,7 +243,7 @@ def test2() -> Test: return 1 self.assertIsInstance(test1(), TestImplementation) - self.assertRaises(TypeError, test2) + self.assertRaises(IncorrectReturnType, test2) def test_return_with_none_value(self): @@ -252,7 +251,7 @@ def test_return_with_none_value(self): def test(a) -> int: return a - self.assertRaises(TypeError, test, None) + self.assertRaises(IncorrectReturnType, test, None) def test_complex_types(self): simple_types = [ 'a', 1, None, 1.1, False ] @@ -328,12 +327,12 @@ def test(a: check.test) -> check.test: @typechecked def test(a: check.test): pass - self.assertRaises(TypeError, test, value) + self.assertRaises(IncorrectArgument, test, value) @typechecked def test(a) -> check.test: return value - self.assertRaises(TypeError, test, value) + self.assertRaises(IncorrectReturnType, test, value) class UnionTest(unittest.TestCase):