From b155522fcbed91d4e7e19424e50796fccea53524 Mon Sep 17 00:00:00 2001 From: Frank Dekervel Date: Fri, 20 Jun 2025 14:36:50 +0200 Subject: [PATCH] Add docstrings and slot mode helper --- linkml/generators/rustgen/rustgen.py | 223 ++++++++++++++++---------- linkml/generators/rustgen/template.py | 84 +++++----- tests/test_generators/test_rustgen.py | 34 ++++ 3 files changed, 224 insertions(+), 117 deletions(-) diff --git a/linkml/generators/rustgen/rustgen.py b/linkml/generators/rustgen/rustgen.py index fff886e93..0b4bdca57 100644 --- a/linkml/generators/rustgen/rustgen.py +++ b/linkml/generators/rustgen/rustgen.py @@ -11,12 +11,12 @@ SlotDefinition, TypeDefinition, ) -from linkml_runtime.utils.formatutils import camelcase, underscore, uncamelcase +from linkml_runtime.utils.formatutils import camelcase, uncamelcase, underscore from linkml_runtime.utils.schemaview import OrderedBy, SchemaView from linkml.generators.common.lifecycle import LifecycleMixin from linkml.generators.common.template import ObjectImport -from linkml.generators.common.type_designators import get_accepted_type_designator_values, get_type_designator_value +from linkml.generators.common.type_designators import get_accepted_type_designator_values from linkml.generators.rustgen.build import ( AttributeResult, ClassResult, @@ -26,34 +26,32 @@ SlotResult, TypeResult, ) - - - from linkml.generators.rustgen.template import ( + AsKeyValue, + ContainerType, Import, Imports, - RustCargo, - RustEnum, - RustFile, + PolyContainersFile, PolyFile, PolyTrait, PolyTraitImpl, - SerdeUtilsFile, + PolyTraitImplForSubtypeEnum, + PolyTraitProperty, + PolyTraitPropertyImpl, + PolyTraitPropertyMatch, + RustCargo, + RustClassModule, + RustEnum, + RustFile, RustProperty, - AsKeyValue, RustPyProject, + RustRange, RustStruct, + RustStructOrSubtypeEnum, RustTemplateModel, RustTypeAlias, - RustStructOrSubtypeEnum, + SerdeUtilsFile, SlotRangeAsUnion, - RustClassModule, - PolyTraitProperty, - PolyTraitPropertyImpl, - PolyTraitImplForSubtypeEnum, - PolyTraitPropertyMatch, - PolyContainersFile, - RustRange, ContainerType, ) from linkml.utils.generator import Generator @@ -91,8 +89,12 @@ RUST_IMPORTS = { "dec": Import(module="rust_decimal", version="1.36", objects=[ObjectImport(name="dec")]), - "NaiveDate": Import(module="chrono", features=["serde"], version="0.4.41", objects= [ObjectImport(name="NaiveDate")]), - "NaiveDateTime": Import(module="chrono", features=["serde"], version="0.4.41", objects= [ObjectImport(name="NaiveDateTime")]), + "NaiveDate": Import( + module="chrono", features=["serde"], version="0.4.41", objects=[ObjectImport(name="NaiveDate")] + ), + "NaiveDateTime": Import( + module="chrono", features=["serde"], version="0.4.41", objects=[ObjectImport(name="NaiveDateTime")] + ), } DEFAULT_IMPORTS = Imports( @@ -108,28 +110,39 @@ module="serde", version="1.0", features=["derive"], - objects=[ObjectImport(name="Serialize"), ObjectImport(name="Deserialize"), ObjectImport(name="de::IntoDeserializer")], + objects=[ + ObjectImport(name="Serialize"), + ObjectImport(name="Deserialize"), + ObjectImport(name="de::IntoDeserializer"), + ], feature_flag="serde", ), Import(module="serde-value", version="0.7.0", objects=[ObjectImport(name="Value")]), Import(module="serde_yml", version="0.0.12", feature_flag="serde", alias="_"), - Import(module="serde_path_to_error", version = "0.1.17", objects=[], feature_flag="serde"), - + Import(module="serde_path_to_error", version="0.1.17", objects=[], feature_flag="serde"), ] ) PYTHON_IMPORTS = Imports( - imports = [ - Import(module="pyo3", version="0.25.0", objects=[ObjectImport(name="prelude::*"), ObjectImport(name="FromPyObject")], feature_flag="pyo3", features=["chrono"]), + imports=[ + Import( + module="pyo3", + version="0.25.0", + objects=[ObjectImport(name="prelude::*"), ObjectImport(name="FromPyObject")], + feature_flag="pyo3", + features=["chrono"], + ), # Import(module="serde_pyobject", version="0.6.1", objects=[], feature_flag="pyo3", features=[]), ] ) + class SlotContainerMode(Enum): SINGLE_VALUE = "single_value" MAPPING = "mapping" LIST = "list" + class SlotInlineMode(Enum): INLINE = "inline" PRIMITIVE = "primitive" @@ -152,29 +165,39 @@ def get_identifier_slot(cls: ClassDefinition, sv: SchemaView) -> Optional[SlotDe return None -def get_inline_mode(s: SlotDefinition, sv: SchemaView) -> tuple[SlotContainerMode, SlotInlineMode]: +def determine_slot_mode(s: SlotDefinition, sv: SchemaView) -> tuple[SlotContainerMode, SlotInlineMode]: + """Return container and inline modes for a slot.""" + class_range = s.range in sv.all_classes() if not class_range: - return (SlotContainerMode.LIST if s.multivalued else SlotContainerMode.SINGLE_VALUE, SlotInlineMode.PRIMITIVE) + return ( + SlotContainerMode.LIST if s.multivalued else SlotContainerMode.SINGLE_VALUE, + SlotInlineMode.PRIMITIVE, + ) + if s.multivalued and s.inlined_as_list: return (SlotContainerMode.LIST, SlotInlineMode.INLINE) + key_slot = get_key_or_identifier_slot(sv.get_class(s.range), sv) identifier_slot = get_identifier_slot(sv.get_class(s.range), sv) inlined = s.inlined if identifier_slot is None: - inlined = True ## can only inline if identifier slot is none + # can only inline if identifier slot is none + inlined = True if not s.multivalued: - return (SlotContainerMode.SINGLE_VALUE, SlotInlineMode.INLINE if inlined else SlotInlineMode.REFERENCE) - + return ( + SlotContainerMode.SINGLE_VALUE, + SlotInlineMode.INLINE if inlined else SlotInlineMode.REFERENCE, + ) + if not inlined: return (SlotContainerMode.LIST, SlotInlineMode.REFERENCE) - + if key_slot is not None: return (SlotContainerMode.MAPPING, SlotInlineMode.INLINE) else: return (SlotContainerMode.LIST, SlotInlineMode.INLINE) - def can_contain_reference_to_class(s: SlotDefinition, cls: ClassDefinition, sv: SchemaView) -> bool: @@ -184,7 +207,7 @@ def can_contain_reference_to_class(s: SlotDefinition, cls: ClassDefinition, sv: while len(classes_to_check) > 0: a_class = classes_to_check.pop() seen_classes.add(a_class) - if not a_class in sv.all_classes(): + if a_class not in sv.all_classes(): continue if a_class == ref_name: return True @@ -195,7 +218,9 @@ def can_contain_reference_to_class(s: SlotDefinition, cls: ClassDefinition, sv: return False -def get_rust_type(t: Union[TypeDefinition, type, str], sv: SchemaView, pyo3: bool = False, crate_ref: Optional[str] = None) -> str: +def get_rust_type( + t: Union[TypeDefinition, type, str], sv: SchemaView, pyo3: bool = False, crate_ref: Optional[str] = None +) -> str: """ Get the rust type from a given linkml type """ @@ -220,9 +245,6 @@ def get_rust_type(t: Union[TypeDefinition, type, str], sv: SchemaView, pyo3: boo elif t in sv.all_classes(): c = sv.get_class(t) rsrange = get_name(c) - descendants = sv.class_descendants(t) - #if len(descendants) > 1: - # rsrange += "OrSubtype" # FIXME: Raise here once we have implemented all base types if rsrange is None: @@ -234,31 +256,45 @@ def get_rust_type(t: Union[TypeDefinition, type, str], sv: SchemaView, pyo3: boo return rsrange - -def get_rust_range_info(cls: ClassDefinition, s: SlotDefinition, sv: SchemaView, crate_ref: Optional[str] = None) -> RustRange: - (container_mode, inline_mode) = get_inline_mode(s, sv) +def get_rust_range_info( + cls: ClassDefinition, s: SlotDefinition, sv: SchemaView, crate_ref: Optional[str] = None +) -> RustRange: + (container_mode, inline_mode) = determine_slot_mode(s, sv) all_ranges = sv.slot_range_as_union(s) sub_ranges = [ - RustRange(type_="String" if inline_mode == SlotInlineMode.REFERENCE else get_rust_type(r, sv, True, crate_ref), - is_class_range = r in sv.all_classes(), - has_class_subtypes= len(sv.class_descendants(r)) > 1 if r in sv.all_classes() else False, - ) + RustRange( + type_="String" if inline_mode == SlotInlineMode.REFERENCE else get_rust_type(r, sv, True, crate_ref), + is_class_range=r in sv.all_classes(), + has_class_subtypes=len(sv.class_descendants(r)) > 1 if r in sv.all_classes() else False, + ) for r in all_ranges ] - + res = RustRange( optional=not s.required, - has_default = not (s.required or False) or (s.multivalued or False), - containerType= ContainerType.LIST if container_mode == SlotContainerMode.LIST else ContainerType.MAPPING if container_mode == SlotContainerMode.MAPPING else None, + has_default=not (s.required or False) or (s.multivalued or False), + containerType=( + ContainerType.LIST + if container_mode == SlotContainerMode.LIST + else ContainerType.MAPPING if container_mode == SlotContainerMode.MAPPING else None + ), child_ranges=sub_ranges if len(sub_ranges) > 1 else None, box_needed=inline_mode == SlotInlineMode.INLINE and can_contain_reference_to_class(s, cls, sv), is_class_range=all_ranges[0] in sv.all_classes() if len(all_ranges) == 1 else False, is_reference=inline_mode == SlotInlineMode.REFERENCE, - has_class_subtypes= len(sv.class_descendants(all_ranges[0])) > 1 if len(all_ranges) == 1 and all_ranges[0] in sv.all_classes() else False, - type_ = underscore(uncamelcase(cls.name)) + "_utl::" + get_name(s) + "_range" if len(sub_ranges) > 1 else ("String" if inline_mode == SlotInlineMode.REFERENCE else get_rust_type(s.range, sv, True, crate_ref)) + has_class_subtypes=( + len(sv.class_descendants(all_ranges[0])) > 1 + if len(all_ranges) == 1 and all_ranges[0] in sv.all_classes() + else False + ), + type_=( + underscore(uncamelcase(cls.name)) + "_utl::" + get_name(s) + "_range" + if len(sub_ranges) > 1 + else ("String" if inline_mode == SlotInlineMode.REFERENCE else get_rust_type(s.range, sv, True, crate_ref)) + ), ) return res - + def protect_name(v: str) -> str: """ @@ -306,7 +342,7 @@ class RustGenerator(Generator, LifecycleMixin): """ * If ``mode == "crate"`` , a directory to contain the generated crate * If ``mode == "file"`` , a file with a ``.rs`` extension - + If output is not provided at object instantiation, it must be provided on a call to :meth:`.serialize` """ @@ -352,7 +388,7 @@ def generate_slot(self, slot: SlotDefinition) -> SlotResult: slot = self.before_generate_slot(slot, self.schemaview) class_range = slot.range in self.schemaview.all_classes() type_ = get_rust_type(slot.range, self.schemaview, self.pyo3) - + slot = SlotResult( source=slot, slot=RustTypeAlias( @@ -378,11 +414,15 @@ def generate_class(self, cls: ClassDefinition) -> ClassResult: for a in induced_attrs: ranges = self.schemaview.slot_range_as_union(a) if len(ranges) > 1: - slot_range_unions.append(SlotRangeAsUnion(slot_name=get_name(a), ranges=[get_rust_type(r, self.schemaview, True) for r in ranges])) - + slot_range_unions.append( + SlotRangeAsUnion( + slot_name=get_name(a), ranges=[get_rust_type(r, self.schemaview, True) for r in ranges] + ) + ) + cls_mod = RustClassModule( class_name=get_name(cls), - class_name_snakecase= underscore(uncamelcase(cls.name)), + class_name_snakecase=underscore(uncamelcase(cls.name)), slot_ranges=slot_range_unions, ) @@ -410,7 +450,7 @@ def generate_class(self, cls: ClassDefinition) -> ClassResult: res = self.after_generate_class(res, self.schemaview) return res - + def gen_struct_or_subtype_enum(self, cls: ClassDefinition) -> Optional[RustStructOrSubtypeEnum]: descendants = self.schemaview.class_descendants(cls.name) td = self.schemaview.get_type_designator_slot(cls.name) @@ -426,12 +466,10 @@ def gen_struct_or_subtype_enum(self, cls: ClassDefinition) -> Optional[RustStruc struct_names=[get_name(self.schemaview.get_class(d)) for d in descendants], type_designator_name=get_name(td) if td else None, as_key_value=get_key_or_identifier_slot(cls, self.schemaview) is not None, - type_designators = td_mapping, + type_designators=td_mapping, ) return None - - def generate_class_as_key_value(self, cls: ClassDefinition) -> Optional[AsKeyValue]: induced_attrs = [self.schemaview.induced_slot(sn, cls.name) for sn in self.schemaview.class_slots(cls.name)] key_attr = None @@ -455,18 +493,17 @@ def generate_class_as_key_value(self, cls: ClassDefinition) -> Optional[AsKeyVal value_args_no_default.append(attr) if key_attr is not None: return AsKeyValue( - name=get_name(cls), - key_property_name=get_name(key_attr), - key_property_type=get_rust_type(key_attr.range, self.schemaview, self.pyo3), - value_property_name=get_name(value_attrs[0]), - value_property_type=get_rust_type(value_attrs[0].range, self.schemaview, self.pyo3), - can_convert_from_primitive = len(value_args_no_default) <= 1, - can_convert_from_empty = len(value_args_no_default) == 0, - serde=self.serde, - pyo3=self.pyo3, + name=get_name(cls), + key_property_name=get_name(key_attr), + key_property_type=get_rust_type(key_attr.range, self.schemaview, self.pyo3), + value_property_name=get_name(value_attrs[0]), + value_property_type=get_rust_type(value_attrs[0].range, self.schemaview, self.pyo3), + can_convert_from_primitive=len(value_args_no_default) <= 1, + can_convert_from_empty=len(value_args_no_default) == 0, + serde=self.serde, + pyo3=self.pyo3, ) return None - def generate_attribute(self, attr: SlotDefinition, cls: ClassDefinition) -> AttributeResult: """ @@ -474,7 +511,7 @@ def generate_attribute(self, attr: SlotDefinition, cls: ClassDefinition) -> Attr """ attr = self.before_generate_slot(attr, self.schemaview) is_class_range = attr.range in self.schemaview.all_classes() - (container_mode, inline_mode) = get_inline_mode(attr, self.schemaview) + (container_mode, inline_mode) = determine_slot_mode(attr, self.schemaview) range = get_rust_range_info(cls, attr, self.schemaview) res = AttributeResult( source=attr, @@ -485,7 +522,8 @@ def generate_attribute(self, attr: SlotDefinition, cls: ClassDefinition) -> Attr type_=range, required=bool(attr.required), multivalued=True if attr.multivalued else False, - is_key_value=is_class_range and self.generate_class_as_key_value(self.schemaview.get_class(attr.range)) is not None, + is_key_value=is_class_range + and self.generate_class_as_key_value(self.schemaview.get_class(attr.range)) is not None, pyo3=self.pyo3, serde=self.serde, ), @@ -610,37 +648,60 @@ def gen_poly_trait(self, cls: ClassDefinition) -> PolyTrait: superclass_names.append(cls.is_a) for m in cls.mixins: superclass_names.append(m) - + superclasses = [self.schemaview.get_class(sn) for sn in superclass_names if sn is not None] for superclass in superclasses: attribs_sc = self.schemaview.class_induced_slots(superclass.name) attribs = [a for a in attribs if a.name not in [sc.name for sc in attribs_sc]] - + rust_attribs = [] for a in attribs: n = get_name(a) ri = get_rust_range_info(cls, a, self.schemaview) rust_attribs.append(PolyTraitProperty(name=n, range=ri)) - - + subtype_impls = [] for sc in self.schemaview.class_descendants(cls.name): sco = self.schemaview.get_class(sc) induced_slots = self.schemaview.class_induced_slots(cls.name) + def find_slot(n: str): for s in induced_slots: if s.name == n: return s return None - ptis = [PolyTraitPropertyImpl(name = get_name(a), range = get_rust_range_info(sco, find_slot(a.name), self.schemaview), struct_name=get_name(sco)) for a in attribs] + + ptis = [ + PolyTraitPropertyImpl( + name=get_name(a), + range=get_rust_range_info(sco, find_slot(a.name), self.schemaview), + struct_name=get_name(sco), + ) + for a in attribs + ] impls.append(PolyTraitImpl(name=class_name, struct_name=get_name(sco), attrs=ptis)) has_subtypes = len(self.schemaview.class_descendants(sc)) > 1 if has_subtypes: cases = [get_name(self.schemaview.get_class(x)) for x in self.schemaview.class_descendants(sc)] - matches = [PolyTraitPropertyMatch(name=get_name(a), range = get_rust_range_info(sco, find_slot(a.name), self.schemaview), cases=cases, struct_name=f"{get_name(sco)}OrSubtype") for a in attribs] - subtype_impls.append(PolyTraitImplForSubtypeEnum(name=class_name, enum_name=f"{get_name(sco)}OrSubtype",attrs=matches)) - return PolyTrait(name=class_name, impls=impls, attrs=rust_attribs, superclass_names = [get_name(scla) for scla in superclasses], subtypes=subtype_impls) - + matches = [ + PolyTraitPropertyMatch( + name=get_name(a), + range=get_rust_range_info(sco, find_slot(a.name), self.schemaview), + cases=cases, + struct_name=f"{get_name(sco)}OrSubtype", + ) + for a in attribs + ] + subtype_impls.append( + PolyTraitImplForSubtypeEnum(name=class_name, enum_name=f"{get_name(sco)}OrSubtype", attrs=matches) + ) + return PolyTrait( + name=class_name, + impls=impls, + attrs=rust_attribs, + superclass_names=[get_name(scla) for scla in superclasses], + subtypes=subtype_impls, + ) def serialize(self, output: Optional[Path] = None, mode: Optional[RUST_MODES] = None, force: bool = False) -> str: """ @@ -690,7 +751,7 @@ def write_crate( with open(lib_file, "w") as lfile: lfile.write(rust_file) - for (k,f) in rendered.extra_files.items(): + for k, f in rendered.extra_files.items(): extra_file = f.render(self.template_environment) extra_file_name = f"{k}.rs" extra_file_path = src_dir / extra_file_name diff --git a/linkml/generators/rustgen/template.py b/linkml/generators/rustgen/template.py index da0ea11d8..b1339d626 100644 --- a/linkml/generators/rustgen/template.py +++ b/linkml/generators/rustgen/template.py @@ -1,9 +1,9 @@ -from typing import ClassVar, Optional, List - from enum import Enum +from typing import ClassVar, List, Optional + from jinja2 import Environment, PackageLoader from linkml_runtime.utils.formatutils import underscore -from pydantic import Field, computed_field, field_validator, BaseModel +from pydantic import BaseModel, Field, computed_field, field_validator from linkml.generators.common.template import Import as Import_ from linkml.generators.common.template import Imports as Imports_ @@ -14,6 +14,7 @@ class ContainerType(Enum): LIST = "list" MAPPING = "mapping" + class RustRange(BaseModel): optional: bool = False containerType: Optional[ContainerType] = None @@ -24,7 +25,7 @@ class RustRange(BaseModel): has_class_subtypes: bool = False child_ranges: Optional[List["RustRange"]] = None type_: str - + def type_for_field(self): tp = self.type_ if self.has_class_subtypes: @@ -39,8 +40,8 @@ def type_for_field(self): if self.optional and self.containerType is None: tp = f"Option<{tp}>" return tp - - def type_for_trait(self, crateref: Optional[str], setter: bool=False): + + def type_for_trait(self, crateref: Optional[str], setter: bool = False): tp = self.type_ if self.is_class_range and not self.has_class_subtypes and not setter: if crateref and not self.is_reference: @@ -52,7 +53,7 @@ def type_for_trait(self, crateref: Optional[str], setter: bool=False): convert_ref = False if self.containerType == ContainerType.LIST: if not setter: - #tp = f"poly_containers::ListView<{tp}>" + # tp = f"poly_containers::ListView<{tp}>" tp = f"impl poly_containers::SeqRef<{tp}>" else: tp = f"&Vec<{tp}>" @@ -70,17 +71,15 @@ def type_for_trait(self, crateref: Optional[str], setter: bool=False): tp = f"Option<{tp}>" else: tp = f"&Option<{tp}>" - + return tp - - + def type_bound_for_setter(self, crateref: Optional[str]) -> Optional[str]: if self.is_class_range: tp = self.type_ return f"Into<{tp}>" return None - - + class RustTemplateModel(TemplateModel): """ @@ -103,6 +102,7 @@ class RustTemplateModel(TemplateModel): """ attributes: dict[str, str] = Field(default_factory=dict) + class PolyContainersFile(RustTemplateModel): template: ClassVar[str] = "poly_containers.rs.jinja" @@ -132,11 +132,11 @@ class RustProperty(RustTemplateModel): inline_mode: str container_mode: str name: str - type_: RustRange # might be a union type, so list length > 1 + type_: RustRange # might be a union type, so list length > 1 required: bool multivalued: bool = False is_key_value: bool = False - + @computed_field def type_for_field(self) -> str: """ @@ -147,12 +147,13 @@ def type_for_field(self) -> str: @computed_field def hasdefault(self) -> bool: return self.multivalued or not self.required - + class AsKeyValue(RustTemplateModel): """ A key-value representation for this struct """ + template: ClassVar[str] = "as_key_value.rs.jinja" name: str key_property_name: str @@ -162,34 +163,40 @@ class AsKeyValue(RustTemplateModel): can_convert_from_primitive: bool = False can_convert_from_empty: bool = False + class RustStructOrSubtypeEnum(RustTemplateModel): template: ClassVar[str] = "struct_or_subtype_enum.rs.jinja" enum_name: str - struct_names: list[str] + struct_names: list[str] as_key_value: bool = False type_designator_field: Optional[str] = None type_designators: dict[str, str] + class SlotRangeAsUnion(RustTemplateModel): """ A union of ranges! """ + template: ClassVar[str] = "slot_range_as_union.rs.jinja" slot_name: str ranges: list[str] + class RustClassModule(RustTemplateModel): class_name: str class_name_snakecase: str template: ClassVar[str] = "class_module.rs.jinja" slot_ranges: List[SlotRangeAsUnion] + + class RustStruct(RustTemplateModel): """ A struct! """ template: ClassVar[str] = "struct.rs.jinja" - class_module : Optional[RustClassModule] = None + class_module: Optional[RustClassModule] = None name: str bases: Optional[list[str]] = None @@ -221,7 +228,6 @@ class RustEnum(RustTemplateModel): items: list[str] - class RustTypeAlias(RustTemplateModel): """ A type alias used to represent slots @@ -250,27 +256,26 @@ class SerdeUtilsFile(RustTemplateModel): template: ClassVar[str] = "serde_utils.rs.jinja" - class PolyTraitProperty(RustTemplateModel): template: ClassVar[str] = "poly_trait_property.rs.jinja" name: str range: RustRange - + @computed_field def class_range(self) -> bool: """ Whether this range is a class range """ return self.range.is_class_range - + @computed_field def type_getter(self) -> str: return self.range.type_for_trait(setter=False, crateref="crate") - + @computed_field def type_setter(self) -> str: return self.range.type_for_trait(setter=True, crateref="crate") - + @computed_field def type_bound(self) -> Optional[str]: """ @@ -283,7 +288,8 @@ class PolyTraitPropertyImpl(RustTemplateModel): template: ClassVar[str] = "poly_trait_property_impl.rs.jinja" name: str range: RustRange - struct_name: str + struct_name: str + @computed_field def class_range(self) -> bool: """ @@ -297,15 +303,15 @@ def ct(self) -> str: The container type for this range, if any """ return self.range.containerType.value if self.range.containerType else "None" - + @computed_field def type_getter(self) -> str: return self.range.type_for_trait(setter=False, crateref="crate") - + @computed_field def type_setter(self) -> str: return self.range.type_for_trait(setter=True, crateref="crate") - + @computed_field def type_bound(self) -> Optional[str]: """ @@ -315,11 +321,12 @@ def type_bound(self) -> Optional[str]: class PolyTraitImpl(RustTemplateModel): - template : ClassVar[str] = "poly_trait_impl.rs.jinja" + """Implementation of a :class:`PolyTrait` for a particular struct.""" + + template: ClassVar[str] = "poly_trait_impl.rs.jinja" name: str struct_name: str attrs: List[PolyTraitPropertyImpl] - class PolyTraitPropertyMatch(RustTemplateModel): @@ -340,14 +347,20 @@ def is_container(self) -> bool: def type_getter(self) -> str: return self.range.type_for_trait(setter=False, crateref="crate") + class PolyTraitImplForSubtypeEnum(RustTemplateModel): + """Trait implementation that dispatches based on subtype enums.""" + template: ClassVar[str] = "poly_trait_impl_orsubtype.rs.jinja" enum_name: str name: str attrs: List[PolyTraitPropertyMatch] + class PolyTrait(RustTemplateModel): - template : ClassVar[str] = "poly_trait.rs.jinja" + """Definition of a polymorphic trait generated from a class hierarchy.""" + + template: ClassVar[str] = "poly_trait.rs.jinja" name: str attrs: List[PolyTraitProperty] superclass_names: List[str] @@ -355,14 +368,14 @@ class PolyTrait(RustTemplateModel): subtypes: List[PolyTraitImplForSubtypeEnum] - class PolyFile(RustTemplateModel): + """Rust file aggregating polymorphic traits.""" + template: ClassVar[str] = "poly.rs.jinja" imports: Imports = Imports() traits: List[PolyTrait] - class RustFile(RustTemplateModel): """ A whole rust file! @@ -386,11 +399,13 @@ def struct_names(self) -> list[str]: class RangeEnum(RustTemplateModel): """ A range enum! - """ + """ + template: ClassVar[str] = "range_enum.rs.jinja" name: str type_: List[str] + class RustCargo(RustTemplateModel): """ A Cargo.toml file @@ -416,14 +431,11 @@ def cratefeatures(self) -> dict[str, list[str]]: feature_flags[i.feature_flag].append(i.module) return feature_flags - - @field_validator("name", mode="after") @classmethod def snake_case_name(cls, value: str) -> str: return underscore(value) - def render(self, environment: Optional[Environment] = None, **kwargs) -> str: if environment is None: environment = RustTemplateModel.environment() diff --git a/tests/test_generators/test_rustgen.py b/tests/test_generators/test_rustgen.py index 802c947e0..b51778e08 100644 --- a/tests/test_generators/test_rustgen.py +++ b/tests/test_generators/test_rustgen.py @@ -1,9 +1,43 @@ import pytest +from linkml_runtime.utils.schemaview import SchemaView from linkml.generators.rustgen import RustGenerator +from linkml.generators.rustgen.rustgen import ( + SlotContainerMode, + SlotInlineMode, + determine_slot_mode, +) @pytest.mark.rustgen def test_generate_crate(kitchen_sink_path, temp_dir): gen = RustGenerator(kitchen_sink_path, mode="crate", output=temp_dir) _ = gen.serialize(force=True) + + +def test_determine_slot_mode(kitchen_sink_path): + sv = SchemaView(kitchen_sink_path) + + age_slot = sv.get_slot("age in years") + assert determine_slot_mode(age_slot, sv) == ( + SlotContainerMode.SINGLE_VALUE, + SlotInlineMode.PRIMITIVE, + ) + + aliases_slot = sv.get_slot("aliases") + assert determine_slot_mode(aliases_slot, sv) == ( + SlotContainerMode.LIST, + SlotInlineMode.PRIMITIVE, + ) + + hist_slot = sv.get_slot("has employment history") + assert determine_slot_mode(hist_slot, sv) == ( + SlotContainerMode.LIST, + SlotInlineMode.INLINE, + ) + + employed_slot = sv.get_slot("employed at") + assert determine_slot_mode(employed_slot, sv) == ( + SlotContainerMode.SINGLE_VALUE, + SlotInlineMode.REFERENCE, + )