diff --git a/cpp/csp/adapters/kafka/KafkaInputAdapter.cpp b/cpp/csp/adapters/kafka/KafkaInputAdapter.cpp index 296df8daa..bbef00584 100644 --- a/cpp/csp/adapters/kafka/KafkaInputAdapter.cpp +++ b/cpp/csp/adapters/kafka/KafkaInputAdapter.cpp @@ -113,6 +113,10 @@ void KafkaInputAdapter::processMessage( RdKafka::Message* message, bool live, cs if( m_tickTimestampField ) msgTime = m_tickTimestampField->value(tick.get()); + if (!tick.get() -> validate()) + CSP_THROW( ValueError, "Struct validation failed for Kafka message, fields missing" ); + + bool pushLive = shouldPushLive(live, msgTime); if( shouldProcessMessage( pushLive, msgTime ) ) pushTick(pushLive, msgTime, std::move(tick), batch); diff --git a/cpp/csp/adapters/parquet/ParquetReaderColumnAdapter.cpp b/cpp/csp/adapters/parquet/ParquetReaderColumnAdapter.cpp index b8380ce53..f95ff6704 100644 --- a/cpp/csp/adapters/parquet/ParquetReaderColumnAdapter.cpp +++ b/cpp/csp/adapters/parquet/ParquetReaderColumnAdapter.cpp @@ -520,6 +520,9 @@ void ParquetStructAdapter::dispatchValue( const utils::Symbol *symbol, bool isNu { fieldSetter( s ); } + + CSP_TRUE_OR_THROW_RUNTIME( s -> validate(), "Struct validation failed for Parquet message, some fields are missing" ); + dispatchedValue = &s; } diff --git a/cpp/csp/adapters/utils/JSONMessageStructConverter.cpp b/cpp/csp/adapters/utils/JSONMessageStructConverter.cpp index 574f21084..585b3ee6d 100644 --- a/cpp/csp/adapters/utils/JSONMessageStructConverter.cpp +++ b/cpp/csp/adapters/utils/JSONMessageStructConverter.cpp @@ -145,6 +145,9 @@ StructPtr JSONMessageStructConverter::convertJSON( const char * fieldname, const } ); } + if( !struct_ -> validate() ) + CSP_THROW( ValueError, "JSON conversion of struct " << sType.meta() -> name() << " failed; some required fields were not set" ); + return struct_; } @@ -251,6 +254,7 @@ csp::StructPtr JSONMessageStructConverter::asStruct( void * bytes, size_t size ) } ); } + // root struct validation (validate()) deferred to adapter level return data; } diff --git a/cpp/csp/adapters/websocket/ClientAdapterManager.cpp b/cpp/csp/adapters/websocket/ClientAdapterManager.cpp index 423f2a234..e472357d8 100644 --- a/cpp/csp/adapters/websocket/ClientAdapterManager.cpp +++ b/cpp/csp/adapters/websocket/ClientAdapterManager.cpp @@ -52,8 +52,15 @@ void ClientAdapterManager::start( DateTime starttime, DateTime endtime ) if( m_inputAdapter ) { m_endpoint -> setOnMessage( [ this ]( void* c, size_t t ) { - PushBatch batch( m_engine -> rootEngine() ); - m_inputAdapter -> processMessage( c, t, &batch ); + try + { + PushBatch batch( m_engine -> rootEngine() ); + m_inputAdapter -> processMessage( c, t, &batch ); + } + catch( csp::Exception & err ) + { + pushStatus( StatusLevel::ERROR, ClientStatusType::GENERIC_ERROR, err.what() ); + } } ); } else { diff --git a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp index e4b0b7ff7..b727341ed 100644 --- a/cpp/csp/adapters/websocket/ClientInputAdapter.cpp +++ b/cpp/csp/adapters/websocket/ClientInputAdapter.cpp @@ -31,6 +31,8 @@ void ClientInputAdapter::processMessage( void* c, size_t t, PushBatch* batch ) if( dataType() -> type() == CspType::Type::STRUCT ) { auto tick = m_converter -> asStruct( c, t ); + if (!tick.get() -> validate()) + CSP_THROW( ValueError, "Struct validation failed for WebSocket message, fields missing" ); pushTick( std::move(tick), batch ); } else if ( dataType() -> type() == CspType::Type::STRING ) { diff --git a/cpp/csp/cppnodes/baselibimpl.cpp b/cpp/csp/cppnodes/baselibimpl.cpp index 52a5537d9..5530bb446 100644 --- a/cpp/csp/cppnodes/baselibimpl.cpp +++ b/cpp/csp/cppnodes/baselibimpl.cpp @@ -705,6 +705,9 @@ DECLARE_CPPNODE( struct_fromts ) ); } + if( !out.get() -> validate( ) ) + CSP_THROW( ValueError, "Struct " << cls.value() -> name() << " is not valid; some required fields did not tick" ); + CSP_OUTPUT( std::move( out ) ); } @@ -758,6 +761,9 @@ DECLARE_CPPNODE( struct_collectts ) } ); } + + if( !out.get() -> validate( ) ) + CSP_THROW( ValueError, "Struct " << cls.value() -> name() << " is not valid; some required fields did not tick" ); CSP_OUTPUT( std::move( out ) ); } diff --git a/cpp/csp/engine/BasketInfo.cpp b/cpp/csp/engine/BasketInfo.cpp index f43e425e9..327ecec66 100644 --- a/cpp/csp/engine/BasketInfo.cpp +++ b/cpp/csp/engine/BasketInfo.cpp @@ -161,6 +161,8 @@ void DynamicOutputBasketInfo::addShapeChange( const DialectGenericType & key, bo { auto events = autogen::DynamicBasketEvents::create(); events -> set_events( {} ); + if( !events -> validate() ) + CSP_THROW( ValueError, "DynamicBasketEvents struct is not valid; some required fields were not set" ); m_shapeTs.outputTickTyped( m_parentNode -> rootEngine() -> cycleCount(), m_parentNode -> rootEngine() -> now(), events, false ); @@ -171,6 +173,8 @@ void DynamicOutputBasketInfo::addShapeChange( const DialectGenericType & key, bo auto event = autogen::DynamicBasketEvent::create(); event -> set_key( key ); event -> set_added( added ); + if( !event -> validate() ) + CSP_THROW( ValueError, "DynamicBasketEvent struct is not valid; some required fields were not set" ); const_cast &>( events ).emplace_back( event ); } diff --git a/cpp/csp/engine/Struct.cpp b/cpp/csp/engine/Struct.cpp index 42830357e..3b501036b 100644 --- a/cpp/csp/engine/Struct.cpp +++ b/cpp/csp/engine/Struct.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include namespace csp { @@ -33,8 +35,8 @@ and adjustments required for the hidden fields */ -StructMeta::StructMeta( const std::string & name, const Fields & fields, - std::shared_ptr base ) : m_name( name ), m_base( base ), m_fields( fields ), +StructMeta::StructMeta( const std::string & name, const Fields & fields, bool isStrict, + std::shared_ptr base ) : m_name( name ), m_base( base ), m_isStrict( isStrict ), m_fields( fields ), m_size( 0 ), m_partialSize( 0 ), m_partialStart( 0 ), m_nativeStart( 0 ), m_basePadding( 0 ), m_maskLoc( 0 ), m_maskSize( 0 ), m_firstPartialField( 0 ), m_firstNativePartialField( 0 ), m_isPartialNative( true ), m_isFullyNative( true ) @@ -128,6 +130,18 @@ StructMeta::StructMeta( const std::string & name, const Fields & fields, if( !rv.second ) CSP_THROW( ValueError, "csp Struct " << name << " attempted to add existing field " << m_fields[ idx ] -> fieldname() ); } + + // A non-strict struct may not inherit (directly or indirectly) from a strict base + bool encountered_non_strict = false; + for ( const StructMeta * cur = this; cur; cur = cur -> m_base.get() ) + { + encountered_non_strict |= !cur -> isStrict(); + if ( encountered_non_strict && cur -> isStrict() ) + CSP_THROW( ValueError, + "Strict '" << m_name + << "' has non-strict inheritance of strict base '" + << cur -> name() << "'" ); + } } StructMeta::~StructMeta() @@ -494,6 +508,24 @@ void StructMeta::destroy( Struct * s ) const m_base -> destroy( s ); } +[[nodiscard]] bool StructMeta::validate( const Struct * s ) const +{ + for ( const StructMeta * cur = this; cur; cur = cur -> m_base.get() ) + { + if ( !cur -> isStrict() ) + continue; + + // Note that we do not recursively validate nested struct. + // We assume after any creation on the C++ side, these structs + // are validated properly prior to being set as field values + if ( !cur -> allFieldsSet( s ) ) + return false; + } + return true; +} + + + Struct::Struct( const std::shared_ptr & meta ) { //Initialize meta shared_ptr diff --git a/cpp/csp/engine/Struct.h b/cpp/csp/engine/Struct.h index 64b51ecae..f3b0b63bf 100644 --- a/cpp/csp/engine/Struct.h +++ b/cpp/csp/engine/Struct.h @@ -587,7 +587,7 @@ class StructMeta : public std::enable_shared_from_this using FieldNames = std::vector; //Fields will be re-arranged and assigned their offsets in StructMeta for optimal performance - StructMeta( const std::string & name, const Fields & fields, std::shared_ptr base = nullptr ); + StructMeta( const std::string & name, const Fields & fields, bool isStrict, std::shared_ptr base = nullptr ); virtual ~StructMeta(); const std::string & name() const { return m_name; } @@ -595,6 +595,7 @@ class StructMeta : public std::enable_shared_from_this size_t partialSize() const { return m_partialSize; } bool isNative() const { return m_isFullyNative; } + bool isStrict() const { return m_isStrict; } const Fields & fields() const { return m_fields; } const FieldNames & fieldNames() const { return m_fieldnames; } @@ -602,6 +603,8 @@ class StructMeta : public std::enable_shared_from_this size_t maskLoc() const { return m_maskLoc; } size_t maskSize() const { return m_maskSize; } + [[nodiscard]] bool validate( const Struct * s ) const; + const StructFieldPtr & field( const char * name ) const { static StructFieldPtr s_empty; @@ -652,7 +655,8 @@ class StructMeta : public std::enable_shared_from_this std::shared_ptr m_base; StructPtr m_default; FieldMap m_fieldMap; - + bool m_isStrict; + //fields in order, memory owners of field objects which in turn own the key memory //m_fields includes all base fields as well. m_fieldnames maintains the proper iteration order of fields Fields m_fields; @@ -738,6 +742,11 @@ class Struct return meta() -> allFieldsSet( this ); } + [[nodiscard]] bool validate() const + { + return meta() -> validate( this ); + } + //used to cache dialect representations of this struct, if needed void * dialectPtr() const { return hidden() -> dialectPtr; } diff --git a/cpp/csp/python/PyStruct.cpp b/cpp/csp/python/PyStruct.cpp index 037e3f63a..9fedd486d 100644 --- a/cpp/csp/python/PyStruct.cpp +++ b/cpp/csp/python/PyStruct.cpp @@ -20,8 +20,8 @@ class PyObjectStructField final : public DialectGenericStructField public: using BASE = DialectGenericStructField; PyObjectStructField( const std::string & name, - PyTypeObjectPtr pytype ) : DialectGenericStructField( name, sizeof( PyObjectPtr ), alignof( PyObjectPtr ) ), - m_pytype( pytype ) + PyTypeObjectPtr pytype ) : BASE( name, sizeof( PyObjectPtr ), alignof( PyObjectPtr ) ), + m_pytype( pytype ) {} @@ -42,8 +42,8 @@ class PyObjectStructField final : public DialectGenericStructField }; DialectStructMeta::DialectStructMeta( PyTypeObject * pyType, const std::string & name, - const Fields & flds, std::shared_ptr base ) : - StructMeta( name, flds, base ), + const Fields & flds, bool isStrict, std::shared_ptr base ) : + StructMeta( name, flds, isStrict, base ), m_pyType( pyType ) { } @@ -110,12 +110,19 @@ static PyObject * PyStructMeta_new( PyTypeObject *subtype, PyObject *args, PyObj { PyObject *key, *type; Py_ssize_t pos = 0; + PyObject *optional_fields = PyDict_GetItemString( dict, "__optional_fields__" ); + + while( PyDict_Next( metadata, &pos, &key, &type ) ) { const char * keystr = PyUnicode_AsUTF8( key ); if( !keystr ) CSP_THROW( PythonPassthrough, "" ); + if (!PySet_Check(optional_fields)) + CSP_THROW( TypeError, "Struct metadata for key " << keystr << " expected a set, got " << PyObjectPtr::incref( optional_fields ) ); + + if( !PyType_Check( type ) && !PyList_Check( type ) ) CSP_THROW( TypeError, "Struct metadata for key " << keystr << " expected a type, got " << PyObjectPtr::incref( type ) ); @@ -151,7 +158,7 @@ static PyObject * PyStructMeta_new( PyTypeObject *subtype, PyObject *args, PyObj default: CSP_THROW( ValueError, "Unexpected csp type " << csptype -> type() << " on struct " << name ); } - + fields.emplace_back( field ); } } @@ -188,7 +195,12 @@ static PyObject * PyStructMeta_new( PyTypeObject *subtype, PyObject *args, PyObj | | PyStruct -------------------------- */ - auto structMeta = std::make_shared( ( PyTypeObject * ) pymeta, name, fields, metabase ); + + PyObject * strict_enabled = PyDict_GetItemString( dict, "__strict_enabled__" ); + if( !strict_enabled ) + CSP_THROW( KeyError, "StructMeta missing __strict_enabled__" ); + bool isStrict = strict_enabled == Py_True; + auto structMeta = std::make_shared( ( PyTypeObject * ) pymeta, name, fields, isStrict, metabase ); //Setup fast attr dict lookup pymeta -> attrDict = PyObjectPtr::own( PyDict_New() ); @@ -347,6 +359,7 @@ static PyObject * PyStructMeta_metadata_info( PyStructMeta * m ) return out.release(); } + static PyMethodDef PyStructMeta_methods[] = { {"_layout", (PyCFunction) PyStructMeta_layout, METH_NOARGS, "debug view of structs internal mem layout"}, {"_metadata_info", (PyCFunction) PyStructMeta_metadata_info, METH_NOARGS, "provide detailed information about struct layout"}, @@ -456,6 +469,9 @@ void PyStruct::setattr( Struct * s, PyObject * attr, PyObject * value ) if( !field ) CSP_THROW( AttributeError, "'" << s -> meta() -> name() << "' object has no attribute '" << PyUnicode_AsUTF8( attr ) << "'" ); + if ( s -> meta() -> isStrict() && value == nullptr ) + CSP_THROW( AttributeError, "Strict struct " << s -> meta() -> name() << " does not allow the deletion of field " << PyUnicode_AsUTF8( attr ) ); + try { switchCspType( field -> type(), [field,&struct_=s,value]( auto tag ) @@ -795,6 +811,8 @@ int PyStruct_init( PyStruct * self, PyObject * args, PyObject * kwargs ) CSP_BEGIN_METHOD; PyStruct_setattrs( self, args, kwargs, "__init__" ); + if( !self -> struct_ -> validate() ) + CSP_THROW( ValueError, "Struct " << self -> struct_ -> meta() -> name() << " is not valid; some required fields were not set on init" ); CSP_RETURN_INT; } diff --git a/cpp/csp/python/PyStruct.h b/cpp/csp/python/PyStruct.h index 2034268b8..2c794cdd9 100644 --- a/cpp/csp/python/PyStruct.h +++ b/cpp/csp/python/PyStruct.h @@ -25,7 +25,7 @@ class CSPTYPESIMPL_EXPORT DialectStructMeta : public StructMeta { public: DialectStructMeta( PyTypeObject * pyType, const std::string & name, - const Fields & fields, std::shared_ptr base = nullptr ); + const Fields & fields, bool isStrict, std::shared_ptr base = nullptr ); ~DialectStructMeta() {} PyTypeObject * pyType() const { return m_pyType; } diff --git a/csp/impl/struct.py b/csp/impl/struct.py index 30b201285..f1580434a 100644 --- a/csp/impl/struct.py +++ b/csp/impl/struct.py @@ -15,7 +15,7 @@ class StructMeta(_csptypesimpl.PyStructMeta): - def __new__(cls, name, bases, dct): + def __new__(cls, name, bases, dct, allow_unset=True): full_metadata = {} full_metadata_typed = {} metadata = {} @@ -29,12 +29,17 @@ def __new__(cls, name, bases, dct): defaults.update(base.__defaults__) annotations = dct.get("__annotations__", None) + optional_fields = set() if annotations: for k, v in annotations.items(): actual_type = v # Lists need to be normalized too as potentially we need to add a boolean flag to use FastList if v == FastList: raise TypeError(f"{v} annotation is not supported without args") + if CspTypingUtils.is_optional_type(v): + if (not allow_unset) and (k not in dct): + raise TypeError(f"Optional field {k} must have a default value") + optional_fields.add(k) if ( CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v) @@ -72,6 +77,8 @@ def __new__(cls, name, bases, dct): dct["__full_metadata_typed__"] = full_metadata_typed dct["__metadata__"] = metadata dct["__defaults__"] = defaults + dct["__optional_fields__"] = optional_fields + dct["__strict_enabled__"] = not allow_unset res = super().__new__(cls, name, bases, dct) # This is how we make sure we construct the pydantic schema from the new class @@ -174,6 +181,14 @@ def metadata(cls, typed=False): else: return cls.__full_metadata__ + @classmethod + def optional_fields(cls): + return cls.__optional_fields__ + + @classmethod + def is_strict(cls): + return cls.__strict_enabled__ + @classmethod def fromts(cls, trigger=None, /, **kwargs): """convert valid inputs into ts[ struct ] @@ -237,12 +252,13 @@ def _obj_from_python(cls, json, obj_type): elif issubclass(obj_type, Struct): if not isinstance(json, dict): raise TypeError("Representation of struct as json is expected to be of dict type") - res = obj_type() + obj_args = {} for k, v in json.items(): expected_type = obj_type.__full_metadata_typed__.get(k, None) if expected_type is None: raise KeyError(f"Unexpected key {k} for type {obj_type}") - setattr(res, k, cls._obj_from_python(v, expected_type)) + obj_args[k] = cls._obj_from_python(v, expected_type) + res = obj_type(**obj_args) return res else: if isinstance(json, obj_type): diff --git a/csp/impl/types/typing_utils.py b/csp/impl/types/typing_utils.py index f852168d2..3511888b2 100644 --- a/csp/impl/types/typing_utils.py +++ b/csp/impl/types/typing_utils.py @@ -83,6 +83,13 @@ def is_numpy_nd_array_type(cls, typ): def is_union_type(cls, typ): return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union + @classmethod + def is_optional_type(cls, typ): + if cls.is_union_type(typ): + args = typing.get_args(typ) + return type(None) in args + return False + @classmethod def is_literal_type(cls, typ): return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal diff --git a/csp/impl/wiring/edge.py b/csp/impl/wiring/edge.py index 072b0a424..8ebf460e7 100644 --- a/csp/impl/wiring/edge.py +++ b/csp/impl/wiring/edge.py @@ -202,6 +202,11 @@ def __getattr__(self, key): elemtype = typ.metadata(typed=True).get(key) if elemtype is None: raise AttributeError("'%s' object has no attribute '%s'" % (self.tstype.typ.__name__, key)) + if (key in typ.optional_fields()) and (typ.is_strict()): + raise AttributeError( + "Cannot access optional field '%s' on strict struct object '%s' at graph time" + % (key, self.tstype.typ.__name__) + ) return csp.struct_field(self, key, elemtype) raise AttributeError("'Edge' object has no attribute '%s'" % (key)) diff --git a/csp/tests/test_strict_structs.py b/csp/tests/test_strict_structs.py new file mode 100644 index 000000000..817ec7d2b --- /dev/null +++ b/csp/tests/test_strict_structs.py @@ -0,0 +1,359 @@ +import unittest +from datetime import datetime, timedelta +from typing import Optional + +import csp +from csp import ts +from csp.impl.wiring.base_parser import CspParseError + + +class TestStrictStructs(unittest.TestCase): + def test_backwards_compatibility(self): + """quick test that existing struct behavior is unchanged""" + + class OldStruct(csp.Struct, allow_unset=True): + a: int + b: str + + s = OldStruct(a=5) + self.assertFalse(hasattr(s, "b")) + with self.assertRaisesRegex(AttributeError, "b"): + _ = s.b + del s.a + self.assertFalse(hasattr(s, "a")) + + def test_strict_struct_initialization(self): + """test initialization rules for strict structs. + + notably, + * Setting fields works as expected + * Initialize all non-default fields (including Optional) + * Missing required fields fail + """ + + class MyStrictStruct(csp.Struct, allow_unset=False): + req_int: int + opt_str: Optional[str] = None + def_int: int = 123 + opt_str_2: Optional[str] = None + + # Valid initialization + s1 = MyStrictStruct(req_int=10, opt_str="hello") + self.assertEqual(s1.req_int, 10) + self.assertEqual(s1.opt_str, "hello") + self.assertEqual(s1.def_int, 123) + self.assertIsNone(s1.opt_str_2) + + with self.assertRaisesRegex( + ValueError, "Struct MyStrictStruct is not valid; some required fields were not set on init" + ): + MyStrictStruct() + + def test_strict_struct_hasattr_delattr(self): + """test hasattr and delattr behavior for strict structs""" + + class MyStrictStruct1(csp.Struct, allow_unset=False): + req_int: int + opt_str: Optional[str] = None + + class MyStrictStruct2(csp.Struct, allow_unset=False): + req_int: int + opt_str: Optional[str] = None + + s = MyStrictStruct1(req_int=10, opt_str="hello") + r = MyStrictStruct2(req_int=5) + + # hasattr will always be True for all defined fields + self.assertTrue(hasattr(s, "req_int")) + self.assertTrue(hasattr(s, "opt_str")) + self.assertTrue(hasattr(r, "req_int")) + self.assertTrue(hasattr(r, "opt_str")) + + # delattr is forbidden + with self.assertRaisesRegex( + AttributeError, "Strict struct MyStrictStruct1 does not allow the deletion of field req_int" + ): + del s.req_int + + with self.assertRaisesRegex( + AttributeError, "Strict struct MyStrictStruct1 does not allow the deletion of field opt_str" + ): + del s.opt_str + + def test_strict_struct_serialization(self): + """test to_dict and from_dict behavior""" + + class MyStrictStruct(csp.Struct, allow_unset=False): + req_int: int + opt_str: Optional[str] = None + def_int: int = 100 + req_opt_str: Optional[str] = None + + s = MyStrictStruct(req_int=50, req_opt_str="NoneStr") + expected_dict = {"req_int": 50, "opt_str": None, "def_int": 100, "req_opt_str": "NoneStr"} + self.assertEqual(s.to_dict(), expected_dict) + + with self.assertRaisesRegex( + ValueError, "Struct MyStrictStruct is not valid; some required fields were not set on init" + ): + MyStrictStruct.from_dict({"opt_str": "hello", "def_int": 13}) + + MyStrictStruct.from_dict({"req_int": 60, "opt_str": None, "req_opt_str": None}) + s2 = MyStrictStruct.from_dict({"req_int": 60, "req_opt_str": None}) + self.assertEqual(s2.req_int, 60) + self.assertIsNone(s2.opt_str) + self.assertEqual(s2.def_int, 100) + + def test_strict_struct_wiring_access_1(self): + """test accessing fields on a time series at graph wiring time""" + + class MyStrictStruct(csp.Struct, allow_unset=False): + req_int: int + opt_str: Optional[str] = None + + # check that at graph and wire time we are able to access required fields just fine: + + @csp.node + def ok_node(x: csp.ts[MyStrictStruct]): + int_val = x.req_int + + @csp.graph + def g(): + s_ts = csp.const(MyStrictStruct(req_int=1)) + req_ts = s_ts.req_int + csp.add_graph_output("req_ts", req_ts) + + res = csp.run(g, starttime=datetime(2023, 1, 1)) + self.assertEqual(res["req_ts"][0][1], 1) + + # check that at graph time we cannot access optional fields: + + @csp.graph + def g_fail(): + s_ts = csp.const(MyStrictStruct(req_int=1)) + opt_ts = s_ts.opt_str + + with self.assertRaisesRegex( + AttributeError, + "Cannot access optional field 'opt_str' on strict struct object 'MyStrictStruct' at graph time", + ): + csp.run(g_fail, starttime=datetime(2023, 1, 1)) + + def test_strict_struct_fromts(self): + """fromts requires all non-defaulted fields to tick together""" + + class MyStrictStruct(csp.Struct, allow_unset=False): + req_int1: int + req_int2: int + opt_str: Optional[str] = None + req_default_str: str = "default" + + @csp.node + def make_ts(x: csp.ts[int]) -> csp.ts[int]: + if x % 2 == 0: + return x + + @csp.graph + def g(): + ts1 = make_ts(csp.const(2)) + ts2 = make_ts(csp.const(1)) + + # ts1 and ts2 don't tick together + s_ts = MyStrictStruct.fromts(req_int1=ts1, req_int2=ts2) + csp.add_graph_output("output", s_ts) + + with self.assertRaisesRegex( + ValueError, "Struct MyStrictStruct is not valid; some required fields did not tick" + ): + csp.run(g, starttime=datetime(2023, 1, 1)) + + @csp.graph + def g_ok(): + ts1 = csp.const(2) + ts2 = csp.const(4) + + # ts1 and ts2 tick together + s_ts = MyStrictStruct.fromts(req_int1=ts1, req_int2=ts2) + csp.add_graph_output("output", s_ts) + + csp.run(g_ok, starttime=datetime(2023, 1, 1)) + + @csp.graph + def g_ok_with_optional(): + beat = csp.timer(timedelta(days=1)) + even = csp.eq(csp.mod(csp.count(beat), csp.const(2)), csp.const(0)) + + int_ts1 = csp.sample(even, csp.const(1)) + int_ts2 = csp.sample(even, csp.const(2)) + str_ts = csp.sample(even, csp.const("Hello")) + + s_ts = MyStrictStruct.fromts(req_int1=int_ts1, req_int2=int_ts2, req_default_str=str_ts) + csp.add_graph_output("output", s_ts) + + csp.run(g_ok_with_optional, starttime=datetime(2025, 1, 1), endtime=datetime(2025, 1, 5)) + + def test_strict_struct_inheritance_and_nested(self): + class BaseStrict(csp.Struct, allow_unset=False): + base_req: int + + class DerivedStrict(BaseStrict, allow_unset=False): + derived_req: int + + d_ok = DerivedStrict(base_req=1, derived_req=2) + self.assertEqual(d_ok.base_req, 1) + self.assertEqual(d_ok.derived_req, 2) + + with self.assertRaisesRegex( + ValueError, "Struct DerivedStrict is not valid; some required fields were not set on init" + ): + DerivedStrict(base_req=10) + with self.assertRaisesRegex( + ValueError, "Struct DerivedStrict is not valid; some required fields were not set on init" + ): + DerivedStrict(derived_req=20) + + # loose base & strict child: + class LooseBase(csp.Struct, allow_unset=True): + loose_req: int + + class StrictChild(LooseBase, allow_unset=False): + child_req: int + + sc_ok = StrictChild(child_req=5, loose_req=10) + self.assertEqual(sc_ok.child_req, 5) + + with self.assertRaisesRegex( + ValueError, "Struct StrictChild is not valid; some required fields were not set on init" + ): + StrictChild() + with self.assertRaisesRegex( + ValueError, "Struct StrictChild is not valid; some required fields were not set on init" + ): + StrictChild(loose_req=10) + with self.assertRaisesRegex( + ValueError, "Struct StrictChild is not valid; some required fields were not set on init" + ): + StrictChild(child_req=5) + + # nested struct fields: + class InnerStrict(csp.Struct, allow_unset=False): + val: int + + class OuterStrict(csp.Struct, allow_unset=False): + inner: InnerStrict + + os_ok = OuterStrict(inner=InnerStrict(val=42)) + self.assertEqual(os_ok.inner.val, 42) + + with self.assertRaisesRegex( + ValueError, "Struct InnerStrict is not valid; some required fields were not set on init" + ): + OuterStrict(inner=InnerStrict()) + + # nested loose struct inside strict: + class InnerLoose(csp.Struct, allow_unset=True): + val: int + + class OuterStrict2(csp.Struct, allow_unset=False): + inner: InnerLoose + + ol_ok = OuterStrict2(inner=InnerLoose()) + self.assertIsInstance(ol_ok.inner, InnerLoose) + + with self.assertRaisesRegex( + ValueError, "Struct OuterStrict2 is not valid; some required fields were not set on init" + ): + OuterStrict2() + + def test_nonstrict_cannot_inherit_strict(self): + """non-strict structs inheriting from strict bases should raise""" + + class StrictBase(csp.Struct, allow_unset=False): + base_val: int + + with self.assertRaisesRegex(ValueError, "non-strict inheritance of strict base"): + + class NonStrictChild1(StrictBase, allow_unset=True): + child_val1: Optional[int] = None + + def test_nonstrict_strict_nonstrict_inheritance_order(self): + """inheritance order NonStrict -> Strict -> NonStrict raises an error""" + + class NonStrictBase(csp.Struct, allow_unset=True): + base_val: int + + class StrictMiddle(NonStrictBase, allow_unset=False): + middle_val: int + + with self.assertRaisesRegex(ValueError, "non-strict inheritance of strict base"): + + class NonStrictChild(StrictMiddle, allow_unset=True): + child_val: Optional[int] = None + + def test_nested_struct_serialization(self): + """to_dict / from_dict work with nested strict & non-strict structs""" + + class InnerStrict(csp.Struct, allow_unset=False): + x: int + + class InnerLoose(csp.Struct, allow_unset=True): + y: Optional[int] = None + + class OuterStruct(csp.Struct, allow_unset=False): + strict_inner: InnerStrict + loose_inner: InnerLoose + + o = OuterStruct(strict_inner=InnerStrict(x=5), loose_inner=InnerLoose()) + expected_dict = {"strict_inner": {"x": 5}, "loose_inner": {"y": None}} + self.assertEqual(o.to_dict(), expected_dict) + + o2 = OuterStruct.from_dict({"strict_inner": {"x": 10}, "loose_inner": {"y": 20}}) + self.assertEqual(o2.strict_inner.x, 10) + self.assertIsNotNone(o2.loose_inner) + self.assertEqual(o2.loose_inner.y, 20) + + with self.assertRaisesRegex( + ValueError, "Struct OuterStruct is not valid; some required fields were not set on init" + ): + OuterStruct.from_dict({"loose_inner": {"y": 1}}) + + with self.assertRaisesRegex( + ValueError, "Struct InnerStrict is not valid; some required fields were not set on init" + ): + OuterStruct.from_dict({"strict_inner": {}, "loose_inner": {"y": None}}) + + def test_strict_struct_wiring_access_2(self): + class Test(csp.Struct, allow_unset=False): + name: str + age: int + is_active: Optional[bool] = None + + def greet(self): + return f"Hello, my name is {self.name} and I am {self.age} years old." + + @csp.node + def test() -> csp.ts[Test]: + return Test(name="John", age=30, is_active=True) + + @csp.graph + def main_graph(): + res = test().is_active + csp.print("", res) + + with self.assertRaisesRegex( + AttributeError, "Cannot access optional field 'is_active' on strict struct object 'Test' at graph time" + ): + csp.build_graph(main_graph) + + def test_strict_struct_optional_field_validation_no_default(self): + """test that strict structs cannot have Optional fields without defaults""" + + with self.assertRaisesRegex(TypeError, "Optional field bad_field must have a default value"): + + class InvalidStrictStruct(csp.Struct, allow_unset=False): + req_field: int + bad_field: Optional[str] + + +if __name__ == "__main__": + unittest.main()