diff --git a/pgvector/bit.py b/pgvector/bit.py index 26a9d8d..cecd180 100644 --- a/pgvector/bit.py +++ b/pgvector/bit.py @@ -62,9 +62,12 @@ def from_binary(cls, value): @classmethod def _to_db(cls, value): + if value is None: + return value + if not isinstance(value, cls): - raise ValueError('expected bit') - + value = cls(value) + return value.to_text() @classmethod @@ -73,3 +76,9 @@ def _to_db_binary(cls, value): raise ValueError('expected bit') return value.to_binary() + + @classmethod + def _from_db(cls, value): + if value is None or isinstance(value, cls): + return value + return cls.from_text(value) \ No newline at end of file diff --git a/pgvector/sqlalchemy/bit.py b/pgvector/sqlalchemy/bit.py index 1ea85c3..b5f64f2 100644 --- a/pgvector/sqlalchemy/bit.py +++ b/pgvector/sqlalchemy/bit.py @@ -1,6 +1,8 @@ +import asyncpg +from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg from sqlalchemy.dialects.postgresql.base import ischema_names from sqlalchemy.types import UserDefinedType, Float - +from .. import Bit class BIT(UserDefinedType): cache_ok = True diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index c59c12e..6900e3a 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -605,7 +605,7 @@ async def test_bit(self, engine): async with async_session() as session: async with session.begin(): - embedding = asyncpg.BitString('101') if engine == asyncpg_engine else '101' + embedding = '101' session.add(Item(id=1, binary_embedding=embedding)) item = await session.get(Item, 1) assert item.binary_embedding == embedding