Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/DM-53999.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved the handling and validation of ``Column`` overrides
68 changes: 44 additions & 24 deletions python/felis/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,25 @@ def check_votable_xtype(self) -> Column:
self.votable_xtype = "timestamp"
return self

def _update_from_overrides(self, overrides: ColumnOverrides) -> None:
"""Update the column attributes from the given overrides.

Parameters
----------
overrides
The column overrides to apply or `None` to skip applying overrides.

Notes
-----
Using ``model_fields_set`` allows updating only the fields that are
explicitly set in the `overrides` object. This prevents overwriting
existing column attributes which were not explicitly provided.
"""
if overrides.model_fields_set:
logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set)
for field in overrides.model_fields_set:
setattr(self, field, getattr(overrides, field))


class Constraint(BaseObject):
"""Table constraint model."""
Expand Down Expand Up @@ -819,10 +838,15 @@ def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
class ColumnOverrides(BaseModel):
"""Allowed overrides for a referenced column.

All fields are optional; missing values mean "inherit from the base
column".
Notes
-----
All of these fields are optional. Values of None may be explicitly set to
override the corresponding attribute in the referenced column but only
for certain fields (see validation in `_check_non_nullable_overrides`).
"""

model_config = CONFIG.copy()

datatype: DataType | None = None
"""New datatype for the column."""

Expand All @@ -832,7 +856,7 @@ class ColumnOverrides(BaseModel):
description: str | None = None
"""New description for the column."""

nullable: bool = True
nullable: bool | None = None
"""New nullable flag for the column."""

tap_principal: int | None = Field(default=None, alias="tap:principal")
Expand All @@ -841,6 +865,18 @@ class ColumnOverrides(BaseModel):
tap_column_index: int | None = Field(default=None, alias="tap:column_index")
"""Override for the TAP_SCHEMA column index."""

@model_validator(mode="before")
@classmethod
def _check_non_nullable_overrides(cls, data: Any) -> Any:
"""Check that certain fields are not overridden to null."""
if not isinstance(data, dict):
return data
non_nullable_fields = ("datatype", "length", "nullable", "tap_principal")
for name in non_nullable_fields:
if name in data and data[name] is None:
raise ValueError(f"The '{name}' field cannot be overridden to null")
return data

@field_serializer("datatype")
def serialize_datatype(self, value: DataType | None) -> str | None:
"""Convert `DataType` to string when serializing to JSON/YAML.
Expand Down Expand Up @@ -1331,21 +1367,6 @@ def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema:
table.column_refs = {}
return self

@classmethod
def _copy_overrides_to_column(
cls, column_ref: ColumnResourceRef, column_copy: Column
) -> ColumnOverrides | None:
"""Copy overrides from a column ref to a column."""
if column_ref.overrides is not None:
overrides = column_ref.overrides
override_fields = overrides.model_fields_set
for field_name in override_fields:
if hasattr(column_copy, field_name):
# Use attribute assignment to avoid type conversion issues
# which can occur with using model_dump and model_copy
setattr(column_copy, field_name, getattr(overrides, field_name))
return column_ref.overrides

@classmethod
def _process_column_refs(
cls,
Expand Down Expand Up @@ -1394,8 +1415,7 @@ def _process_column_refs(
f"from resource '{resource_schema.name}' and no ref_name provided"
)

# Create a copy of the base column and apply
# overrides
# Create a copy of the base column
column_copy = base_column.model_copy()

# Set the local name (key from the mapping)
Expand All @@ -1406,10 +1426,10 @@ def _process_column_refs(
# written out during serialization
column_copy._is_resource_ref = True

# Apply overrides to the original column definition
overrides: ColumnOverrides | None = None
if column_ref is not None:
overrides = cls._copy_overrides_to_column(column_ref, column_copy)
# Apply overrides to the referenced column definition
overrides = column_ref.overrides if column_ref is not None else None
if overrides is not None:
column_copy._update_from_overrides(overrides)

# Manually set the ID of the copied column as ID generation has
# already occurred by now
Expand Down
202 changes: 202 additions & 0 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import difflib
import os
import pathlib
import re
import shutil
import tempfile
import unittest
Expand All @@ -35,6 +36,7 @@
CheckConstraint,
Column,
ColumnGroup,
ColumnOverrides,
Constraint,
DataType,
ForeignKeyConstraint,
Expand Down Expand Up @@ -1551,5 +1553,205 @@ def test_tap_column_index_with_overrides(self) -> None:
self.fail(f"Unexpected column name: {column.name}")


class ColumnOverridesTestCase(unittest.TestCase):
"""Test application of overrides to a column, setting all allowed
fields.
"""

def test_all_override_fields_exist_on_column(self) -> None:
"""Ensure every ColumnOverrides field corresponds to an attribute on
Column.
"""
override_fields = set(ColumnOverrides.model_fields)
column_fields = set(Column.model_fields)

missing = override_fields - column_fields

self.assertFalse(
missing,
f"Column is missing attributes for override fields: {sorted(missing)}",
)

def test_overrides_all(self) -> None:
"""Test updating all allowed column fields from overrides."""
# Create a base column
base_column = Column(
name="base_column",
id="#base_column",
description="Base column",
datatype="char",
length=64,
nullable=False,
tap_principal=1,
tap_column_index=10,
)

# Override all allowed fields with different values
overrides = ColumnOverrides(
description="Ref column",
datatype="string",
length=256,
nullable=True,
tap_principal=0,
tap_column_index=100,
)

# Apply overrides
base_column._update_from_overrides(overrides)

# Check that the attributes were updated correctly
self.assertEqual(base_column.description, "Ref column")
self.assertEqual(base_column.datatype, "string")
self.assertEqual(base_column.length, 256)
self.assertEqual(base_column.nullable, True)
self.assertEqual(base_column.tap_principal, 0)
self.assertEqual(base_column.tap_column_index, 100)

def test_overrides_subset(self) -> None:
"""Test updating a subset of allowed column fields from overrides."""
# Create a base column
base_column = Column(
name="base_column",
id="#base_column",
description="Base column",
datatype="char",
length=64,
nullable=False,
tap_principal=1,
tap_column_index=10,
)

# Override all allowed fields with different values
overrides = ColumnOverrides(
description="Ref column",
tap_column_index=100,
)

# Apply overrides
base_column._update_from_overrides(overrides)

# Check that the attributes were updated correctly
self.assertEqual(base_column.description, "Ref column")
self.assertEqual(base_column.datatype, "char")
self.assertEqual(base_column.length, 64)
self.assertEqual(base_column.nullable, False)
self.assertEqual(base_column.tap_principal, 1)
self.assertEqual(base_column.tap_column_index, 100)

def test_overrides_default(self) -> None:
"""Test that applying the default overrides is a no-op."""
# Create a base column
base_column = Column(
name="base_column",
id="#base_column",
description="Base column",
datatype="char",
length=64,
nullable=False,
tap_principal=1,
tap_column_index=10,
)

# Apply overrides
base_column._update_from_overrides(ColumnOverrides())

# Check that the attributes remain unchanged
self.assertEqual(base_column.description, "Base column")
self.assertEqual(base_column.datatype, "char")
self.assertEqual(base_column.length, 64)
self.assertEqual(base_column.nullable, False)
self.assertEqual(base_column.tap_principal, 1)
self.assertEqual(base_column.tap_column_index, 10)

def test_overrides_with_explicit_none_values(self) -> None:
"""Test that passing explicit None values in overrides does update
the column attributes where allowed and raises errors if it is not.
"""
# Create a base column
base_column = Column(
name="base_column",
id="#base_column",
description="Base column",
datatype="int",
length=64,
nullable=False,
tap_principal=1,
tap_column_index=10,
)

# Create overrides with explicit None values for nullable fields
overrides = ColumnOverrides(
description=None,
tap_column_index=None,
)

# Apply overrides
base_column._update_from_overrides(overrides)

# Check that the attributes were updated to None where allowed
self.assertIsNone(base_column.description)
self.assertIsNone(base_column.tap_column_index)

# Check that setting non-nullable fields to None raise a specific
# ValueError on ColumnOverrides creation
for non_nullable_field in ("datatype", "length", "nullable", "tap_principal"):
with self.assertRaisesRegex(
ValueError,
re.escape(f"The '{non_nullable_field}' field cannot be overridden to null"),
):
ColumnOverrides(**{non_nullable_field: None})

def test_extra_fields_in_overrides(self) -> None:
"""Test that extra fields in ColumnOverrides raise a
ValidationError.
"""
with self.assertRaises(ValidationError) as cm:
ColumnOverrides(
description="Test column",
extra_field="This should not be allowed",
)

self.assertIn("Extra inputs are not permitted", str(cm.exception))

def test_overrides_accept_alias_keys(self) -> None:
"""Test that alias keys for TAP fields are accepted and populate the
corresponding model fields.
"""
overrides = ColumnOverrides(**{"tap:principal": 1, "tap:column_index": 42})

self.assertEqual(overrides.tap_principal, 1)
self.assertEqual(overrides.tap_column_index, 42)

# Ensure these count as explicitly provided (for model_fields_set
# logic).
self.assertIn("tap_principal", overrides.model_fields_set)
self.assertIn("tap_column_index", overrides.model_fields_set)

def test_datatype_deserialize_and_serialize(self) -> None:
"""Test that datatype is deserialized from a string to DataType and
serialized back to a string.
"""
overrides = ColumnOverrides(datatype="char")

# Deserialization should yield a DataType instance (not a raw str).
self.assertIsInstance(overrides.datatype, DataType)
self.assertEqual(str(overrides.datatype), "char")

# Serialization should produce a JSON-friendly string value.
dumped = overrides.model_dump(mode="json")
self.assertEqual(dumped["datatype"], "char")

# None should remain None on serialization.
overrides_none = ColumnOverrides()
dumped_none = overrides_none.model_dump(mode="json")
self.assertIsNone(dumped_none["datatype"])

def test_non_nullable_overrides_data_is_none(self) -> None:
"""Test that passing None to ``_check_non_nullable_overrides`` does not
raise an error.
"""
ColumnOverrides()._check_non_nullable_overrides(None)


if __name__ == "__main__":
unittest.main()
Loading