diff --git a/sqeleton/abcs/database_types.py b/sqeleton/abcs/database_types.py index 59843c5..43e1a56 100644 --- a/sqeleton/abcs/database_types.py +++ b/sqeleton/abcs/database_types.py @@ -24,10 +24,6 @@ class PrecisionType(ColType): rounds: Union[bool, Unknown] = Unknown -class Boolean(ColType): - precision = 0 - - class TemporalType(PrecisionType): pass @@ -82,6 +78,10 @@ def python_type(self) -> type: return decimal.Decimal +class Boolean(ColType, IKey): + precision = 0 + python_type = bool + @dataclass class StringType(ColType): python_type = str diff --git a/tests/test_database.py b/tests/test_database.py index 461e9b5..8096a1c 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -9,7 +9,7 @@ from sqeleton.queries import table, current_timestamp, NormalizeAsString, ForeignKey, Compiler from .common import TEST_MYSQL_CONN_STRING from .common import str_to_checksum, make_test_each_database_in_list, get_conn, random_table_suffix -from sqeleton.abcs.database_types import TimestampTZ +from sqeleton.abcs.database_types import TimestampTZ, Boolean, IKey TEST_DATABASES = { dbs.MySQL, @@ -95,7 +95,31 @@ def test_type_mapping(self): db.query(tbl.drop()) assert not db.query(q) - +@test_each_database +class TestColTypes(unittest.TestCase): + """Test column type implementations, especially IKey interface""" + + def test_boolean_as_ikey(self): + """Test that Boolean type implements IKey interface correctly""" + boolean_type = Boolean() + + # Test that Boolean is an instance of IKey + self.assertIsInstance(boolean_type, IKey) + + # Test that python_type property returns bool + self.assertEqual(boolean_type.python_type, bool) + + # Test precision is 0 + self.assertEqual(boolean_type.precision, 0) + + # Test make_value method converts values to bool + self.assertEqual(boolean_type.make_value(True), True) + self.assertEqual(boolean_type.make_value(False), False) + self.assertEqual(boolean_type.make_value(1), True) + self.assertEqual(boolean_type.make_value(0), False) + self.assertEqual(boolean_type.make_value("True"), True) + self.assertEqual(boolean_type.make_value(""), False) + @test_each_database class TestQueries(unittest.TestCase): def test_current_timestamp(self):