From 25d3e9c95fd8f64e230f9362752b327cbce3d12e Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Thu, 9 Oct 2025 23:26:22 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20composite=20primary=20with?= =?UTF-8?q?=20AfterValidator/Annotated?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_main.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..c0e936ee72 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,9 +1,14 @@ from typing import List, Optional import pytest +from sqlalchemy import inspect +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from typing_extensions import Annotated + +from .conftest import needs_pydanticv2 def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -125,3 +130,73 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_composite_primary_key(clear_sqlmodel): + class UserPermission(SQLModel, table=True): + user_id: int = Field(primary_key=True) + resource_id: int = Field(primary_key=True) + permission: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + insp: Inspector = inspect(engine) + pk_constraint = insp.get_pk_constraint(str(UserPermission.__tablename__)) + + assert len(pk_constraint["constrained_columns"]) == 2 + assert "user_id" in pk_constraint["constrained_columns"] + assert "resource_id" in pk_constraint["constrained_columns"] + + with Session(engine) as session: + perm1 = UserPermission(user_id=1, resource_id=1, permission="read") + perm2 = UserPermission(user_id=1, resource_id=2, permission="write") + session.add(perm1) + session.add(perm2) + session.commit() + + with pytest.raises(IntegrityError): + with Session(engine) as session: + perm3 = UserPermission(user_id=1, resource_id=1, permission="admin") + session.add(perm3) + session.commit() + + +@needs_pydanticv2 +def test_composite_primary_key_and_validator(clear_sqlmodel): + from pydantic import AfterValidator + + def validate_resource_id(value: int) -> int: + if value < 1: + raise ValueError("Resource ID must be positive") + return value + + class UserPermission(SQLModel, table=True): + user_id: int = Field(primary_key=True) + resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field( + primary_key=True + ) + permission: str + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + insp: Inspector = inspect(engine) + pk_constraint = insp.get_pk_constraint(str(UserPermission.__tablename__)) + + assert len(pk_constraint["constrained_columns"]) == 2 + assert "user_id" in pk_constraint["constrained_columns"] + assert "resource_id" in pk_constraint["constrained_columns"] + + with Session(engine) as session: + perm1 = UserPermission(user_id=1, resource_id=1, permission="read") + perm2 = UserPermission(user_id=1, resource_id=2, permission="write") + session.add(perm1) + session.add(perm2) + session.commit() + + with pytest.raises(IntegrityError): + with Session(engine) as session: + perm3 = UserPermission(user_id=1, resource_id=1, permission="admin") + session.add(perm3) + session.commit()