Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import List, Optional

import pytest
from sqlalchemy import inspect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import RelationshipProperty
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
from typing_extensions import Annotated

from .conftest import needs_pydanticv2


def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):
Expand Down Expand Up @@ -125,3 +130,73 @@ class Hero(SQLModel, table=True):
# The next statement should not raise an AttributeError
assert hero_rusty_man.team
assert hero_rusty_man.team.name == "Preventers"


def test_composite_primary_key(clear_sqlmodel):
class UserPermission(SQLModel, table=True):
user_id: int = Field(primary_key=True)
resource_id: int = Field(primary_key=True)
permission: str

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

insp: Inspector = inspect(engine)
pk_constraint = insp.get_pk_constraint(str(UserPermission.__tablename__))

assert len(pk_constraint["constrained_columns"]) == 2
assert "user_id" in pk_constraint["constrained_columns"]
assert "resource_id" in pk_constraint["constrained_columns"]

with Session(engine) as session:
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
session.add(perm1)
session.add(perm2)
session.commit()

with pytest.raises(IntegrityError):
with Session(engine) as session:
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
session.add(perm3)
session.commit()


@needs_pydanticv2
def test_composite_primary_key_and_validator(clear_sqlmodel):
from pydantic import AfterValidator

def validate_resource_id(value: int) -> int:
if value < 1:
raise ValueError("Resource ID must be positive")
return value

class UserPermission(SQLModel, table=True):
user_id: int = Field(primary_key=True)
resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(
primary_key=True
)
permission: str

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

insp: Inspector = inspect(engine)
pk_constraint = insp.get_pk_constraint(str(UserPermission.__tablename__))

assert len(pk_constraint["constrained_columns"]) == 2
assert "user_id" in pk_constraint["constrained_columns"]
assert "resource_id" in pk_constraint["constrained_columns"]

with Session(engine) as session:
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
session.add(perm1)
session.add(perm2)
session.commit()

with pytest.raises(IntegrityError):
with Session(engine) as session:
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
session.add(perm3)
session.commit()
Loading