From ad75ce353de765010f3df397b21afb41303d56b9 Mon Sep 17 00:00:00 2001 From: Ashish Bagri Date: Wed, 25 Jun 2025 14:31:40 +0200 Subject: [PATCH 1/3] Add support for nested data generation --- glassgen/schema/schema.py | 123 +++++++++++++--------- glassgen/sinks/csv_sink.py | 52 +++++++++- tests/test_schema.py | 202 +++++++++++++++++++++++++++++++++++++ tests/test_sinks.py | 3 +- 4 files changed, 326 insertions(+), 54 deletions(-) diff --git a/glassgen/schema/schema.py b/glassgen/schema/schema.py index 9913a4d..a6211f6 100644 --- a/glassgen/schema/schema.py +++ b/glassgen/schema/schema.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from pydantic import BaseModel, Field @@ -13,71 +13,96 @@ class SchemaField(BaseModel): params: List[Any] = Field(default_factory=list) +class NestedSchemaField(BaseModel): + """Represents a nested schema field that contains other fields""" + name: str + fields: Dict[str, Union[SchemaField, 'NestedSchemaField']] + + class ConfigSchema(BaseSchema, BaseModel): """Schema implementation that can be created from a configuration""" - fields: Dict[str, SchemaField] + fields: Dict[str, Union[SchemaField, NestedSchemaField]] @classmethod - def from_dict(cls, schema_dict: Dict[str, str]) -> "ConfigSchema": + def from_dict(cls, schema_dict: Dict[str, Any]) -> "ConfigSchema": """Create a schema from a configuration dictionary""" fields = cls._schema_dict_to_fields(schema_dict) return cls(fields=fields) @staticmethod - def _schema_dict_to_fields(schema_dict: Dict[str, str]) -> Dict[str, SchemaField]: - """Convert a schema dictionary to a dictionary of SchemaField objects""" + def _schema_dict_to_fields(schema_dict: Dict[str, Any]) -> Dict[str, Union[SchemaField, NestedSchemaField]]: + """Convert a schema dictionary to a dictionary of SchemaField or NestedSchemaField objects""" fields = {} - for name, generator_str in schema_dict.items(): - match = re.match(r"\$(\w+)(?:\((.*)\))?", generator_str) - if not match: - raise ValueError(f"Invalid generator format: {generator_str}") - - generator_name = match.group(1) - params_str = match.group(2) - - params = [] - if params_str: - # Handle choice generator specially - if generator_name == GeneratorType.CHOICE: - # Split by comma but preserve quoted strings - params = [p.strip().strip("\"'") for p in params_str.split(",")] - - else: - # Simple parameter parsing for other generators - params = [p.strip() for p in params_str.split(",")] - # Convert numeric parameters - params = [int(p) if p.isdigit() else p for p in params] - - fields[name] = SchemaField( - name=name, generator=generator_name, params=params - ) + for name, value in schema_dict.items(): + if isinstance(value, dict): + # Handle nested structure + nested_fields = ConfigSchema._schema_dict_to_fields(value) + fields[name] = NestedSchemaField(name=name, fields=nested_fields) + elif isinstance(value, str): + # Handle flat generator string + match = re.match(r"\$(\w+)(?:\((.*)\))?", value) + if not match: + raise ValueError(f"Invalid generator format: {value}") + + generator_name = match.group(1) + params_str = match.group(2) + + params = [] + if params_str: + # Handle choice generator specially + if generator_name == GeneratorType.CHOICE: + # Split by comma but preserve quoted strings + params = [p.strip().strip("\"'") for p in params_str.split(",")] + else: + # Simple parameter parsing for other generators + params = [p.strip() for p in params_str.split(",")] + # Convert numeric parameters + params = [int(p) if p.isdigit() else p for p in params] + + fields[name] = SchemaField( + name=name, generator=generator_name, params=params + ) + else: + raise ValueError(f"Invalid schema value type for field '{name}': {type(value)}") return fields def validate(self) -> None: """Validate that all generators are supported""" supported_generators = set(registry.get_supported_generators().keys()) - for field in self.fields.values(): - if field.generator not in supported_generators: - raise ValueError( - f"Unsupported generator: {field.generator}. " - f"Supported generators are: {', '.join(supported_generators)}" - ) + def validate_fields(fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]]): + for field in fields_dict.values(): + if isinstance(field, SchemaField): + if field.generator not in supported_generators: + raise ValueError( + f"Unsupported generator: {field.generator}. " + f"Supported generators are: {', '.join(supported_generators)}" + ) + elif isinstance(field, NestedSchemaField): + validate_fields(field.fields) + + validate_fields(self.fields) def _generate_record(self) -> Dict[str, Any]: """Generate a single record based on the schema""" - record = {} - for field_name, field in self.fields.items(): - generator = registry.get_generator(field.generator) - # Pass parameters to the generator if they exist - if field.params: - if field.generator == GeneratorType.CHOICE: - # For choice generator, pass the list directly - record[field_name] = generator(field.params) - else: - # For other generators, unpack the parameters - record[field_name] = generator(*field.params) - else: - record[field_name] = generator() - return record + def generate_nested_record(fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]]) -> Dict[str, Any]: + record = {} + for field_name, field in fields_dict.items(): + if isinstance(field, SchemaField): + generator = registry.get_generator(field.generator) + # Pass parameters to the generator if they exist + if field.params: + if field.generator == GeneratorType.CHOICE: + # For choice generator, pass the list directly + record[field_name] = generator(field.params) + else: + # For other generators, unpack the parameters + record[field_name] = generator(*field.params) + else: + record[field_name] = generator() + elif isinstance(field, NestedSchemaField): + record[field_name] = generate_nested_record(field.fields) + return record + + return generate_nested_record(self.fields) diff --git a/glassgen/sinks/csv_sink.py b/glassgen/sinks/csv_sink.py index 42a25e3..55cbbca 100644 --- a/glassgen/sinks/csv_sink.py +++ b/glassgen/sinks/csv_sink.py @@ -19,23 +19,67 @@ def __init__(self, sink_params: Dict[str, Any]): self.file = None self.fieldnames = None + def _flatten_dict(self, data: Dict[str, Any], parent_key: str = '', sep: str = '_') -> Dict[str, Any]: + """ + Flatten a nested dictionary by concatenating keys with separator. + + Args: + data: The dictionary to flatten + parent_key: The parent key prefix + sep: Separator to use between nested keys + + Returns: + Flattened dictionary with concatenated keys + """ + items = [] + for key, value in data.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + if isinstance(value, dict): + # Recursively flatten nested dictionaries + items.extend(self._flatten_dict(value, new_key, sep=sep).items()) + else: + items.append((new_key, value)) + return dict(items) + + def _get_flattened_fieldnames(self, data: List[Dict[str, Any]]) -> List[str]: + """ + Get fieldnames from flattened data to ensure consistent column order. + + Args: + data: List of dictionaries to process + + Returns: + List of fieldnames for CSV header + """ + all_keys = set() + for item in data: + flattened = self._flatten_dict(item) + all_keys.update(flattened.keys()) + return sorted(list(all_keys)) + def publish(self, data: Dict[str, Any]) -> None: + # Flatten the data before writing + flattened_data = self._flatten_dict(data) + if self.writer is None: self.file = open(self.filepath, "w", newline="") - self.fieldnames = list(data.keys()) + self.fieldnames = list(flattened_data.keys()) self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames) self.writer.writeheader() - self.writer.writerow(data) + self.writer.writerow(flattened_data) def publish_bulk(self, data: List[Dict[str, Any]]) -> None: if self.writer is None: self.file = open(self.filepath, "w", newline="") - self.fieldnames = list(data[0].keys()) + # Get fieldnames from all data to ensure consistent columns + self.fieldnames = self._get_flattened_fieldnames(data) self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames) self.writer.writeheader() - self.writer.writerows(data) + # Flatten each record before writing + flattened_data = [self._flatten_dict(item) for item in data] + self.writer.writerows(flattened_data) def close(self) -> None: if self.file: diff --git a/tests/test_schema.py b/tests/test_schema.py index 5ed6d0e..903e052 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -95,3 +95,205 @@ def test_predefined_schema(): assert "@" in record["email"] assert isinstance(record["age"], int) assert isinstance(record["phone"], str) + + +def test_flat_schema(): + """Test that flat schemas still work as before""" + schema_dict = { + "name": "$name", + "email": "$email", + "country": "$country", + "id": "$uuid", + "phone": "$phone_number", + "job": "$job", + "company": "$company", + "intrange": "$intrange(10,100)", + "choice": "$choice(apple,banana,cherry)", + "weburl": "$url", + "ts": "$datetime", + "tsunix": "$timestamp" + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check that all fields are present + assert "name" in record + assert "email" in record + assert "country" in record + assert "id" in record + assert "phone" in record + assert "job" in record + assert "company" in record + assert "intrange" in record + assert "choice" in record + assert "weburl" in record + assert "ts" in record + assert "tsunix" in record + + # Check that intrange is within the specified range + assert 10 <= record["intrange"] <= 100 + + # Check that choice is one of the specified values + assert record["choice"] in ["apple", "banana", "cherry"] + + +def test_nested_schema(): + """Test that nested schemas work correctly""" + schema_dict = { + "name": "$name", + "email": "$email", + "country": "$country", + "id": "$uuid", + "location": { + "address": "$address", + "city": "$city", + "postal_code": "$zipcode" + }, + "phone": "$phone_number", + "job": "$job", + "company": "$company", + "intrange": "$intrange(10,100)", + "choice": "$choice(apple,banana,cherry)", + "weburl": "$url", + "ts": "$datetime", + "tsunix": "$timestamp" + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check that flat fields are present + assert "name" in record + assert "email" in record + assert "country" in record + assert "id" in record + assert "phone" in record + assert "job" in record + assert "company" in record + assert "intrange" in record + assert "choice" in record + assert "weburl" in record + assert "ts" in record + assert "tsunix" in record + + # Check that nested location field is present and has the expected structure + assert "location" in record + assert isinstance(record["location"], dict) + assert "address" in record["location"] + assert "city" in record["location"] + assert "postal_code" in record["location"] + + # Check that intrange is within the specified range + assert 10 <= record["intrange"] <= 100 + + # Check that choice is one of the specified values + assert record["choice"] in ["apple", "banana", "cherry"] + + +def test_deeply_nested_schema(): + """Test deeply nested schemas work correctly""" + schema_dict = { + "user": { + "personal": { + "name": "$name", + "email": "$email", + "phone": "$phone_number" + }, + "address": { + "street": "$address", + "city": "$city", + "country": "$country", + "postal_code": "$zipcode" + } + }, + "metadata": { + "id": "$uuid", + "created_at": "$datetime", + "tags": "$choice(tag1,tag2,tag3)" + } + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check nested structure + assert "user" in record + assert isinstance(record["user"], dict) + + assert "personal" in record["user"] + assert isinstance(record["user"]["personal"], dict) + assert "name" in record["user"]["personal"] + assert "email" in record["user"]["personal"] + assert "phone" in record["user"]["personal"] + + assert "address" in record["user"] + assert isinstance(record["user"]["address"], dict) + assert "street" in record["user"]["address"] + assert "city" in record["user"]["address"] + assert "country" in record["user"]["address"] + assert "postal_code" in record["user"]["address"] + + assert "metadata" in record + assert isinstance(record["metadata"], dict) + assert "id" in record["metadata"] + assert "created_at" in record["metadata"] + assert "tags" in record["metadata"] + assert record["metadata"]["tags"] in ["tag1", "tag2", "tag3"] + + +def test_mixed_flat_and_nested_schema(): + """Test schemas with both flat and nested fields""" + schema_dict = { + "name": "$name", + "email": "$email", + "location": { + "address": "$address", + "city": "$city" + }, + "phone": "$phone_number", + "preferences": { + "theme": "$choice(dark,light)", + "notifications": "$choice(true,false)" + } + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check flat fields + assert "name" in record + assert "email" in record + assert "phone" in record + + # Check nested fields + assert "location" in record + assert isinstance(record["location"], dict) + assert "address" in record["location"] + assert "city" in record["location"] + + assert "preferences" in record + assert isinstance(record["preferences"], dict) + assert "theme" in record["preferences"] + assert "notifications" in record["preferences"] + assert record["preferences"]["theme"] in ["dark", "light"] + assert record["preferences"]["notifications"] in ["true", "false"] + + +def test_invalid_schema_value_type(): + """Test that invalid schema value types raise appropriate errors""" + schema_dict = { + "name": "$name", + "invalid_field": 123 # Invalid type - should be string or dict + } + + with pytest.raises(ValueError, match="Invalid schema value type for field 'invalid_field'"): + ConfigSchema.from_dict(schema_dict) diff --git a/tests/test_sinks.py b/tests/test_sinks.py index f9649f5..6f313ca 100644 --- a/tests/test_sinks.py +++ b/tests/test_sinks.py @@ -64,7 +64,8 @@ def test_csv_sink_bulk_publish(temp_csv_file): with open(temp_csv_file, "r") as f: lines = f.readlines() assert len(lines) == 4 # Header + 3 data rows - assert "name,age" in lines[0] # Header + header = lines[0].strip().split(',') + assert set(header) == {"name", "age"} # Header fields, order-agnostic # Clean up os.unlink(temp_csv_file) From 0724844722b8a56a182d23b8890953ca797ad55a Mon Sep 17 00:00:00 2001 From: Ashish Bagri Date: Wed, 25 Jun 2025 14:49:51 +0200 Subject: [PATCH 2/3] Fix ruff checks --- examples/usage.py | 27 +++++++++++--------- glassgen/schema/schema.py | 22 ++++++++++++----- glassgen/sinks/csv_sink.py | 22 +++++++++++------ tests/test_schema.py | 50 ++++++++++++++++++++------------------ 4 files changed, 72 insertions(+), 49 deletions(-) diff --git a/examples/usage.py b/examples/usage.py index 6644cd3..f399499 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -3,16 +3,21 @@ config = { "schema": { "name": "$name", - "email": "$email", - "country": "$country", - "id": "$uuid", - "address": "$address", - "phone": "$phone_number", - "job": "$job", - "company": "$company", + "user": { + "email": "$email", + "id": "$uuid" + } }, - "sink": {"type": "csv", "params": {"path": "output.csv"}}, - "generator": {"rps": 1500, "num_records": 5000}, + "generator": { + "num_records": 10 + } } -# Start the generator -print(glassgen.generate(config=config)) +sink_csv = { + "type": "csv", + "params": {"path": "output.csv"} +} +config["sink"] = sink_csv + +gen = glassgen.generate(config=config) +for row in gen: + print(row) diff --git a/glassgen/schema/schema.py b/glassgen/schema/schema.py index a6211f6..75efe83 100644 --- a/glassgen/schema/schema.py +++ b/glassgen/schema/schema.py @@ -31,8 +31,11 @@ def from_dict(cls, schema_dict: Dict[str, Any]) -> "ConfigSchema": return cls(fields=fields) @staticmethod - def _schema_dict_to_fields(schema_dict: Dict[str, Any]) -> Dict[str, Union[SchemaField, NestedSchemaField]]: - """Convert a schema dictionary to a dictionary of SchemaField or NestedSchemaField objects""" + def _schema_dict_to_fields( + schema_dict: Dict[str, Any], + ) -> Dict[str, Union[SchemaField, NestedSchemaField]]: + """Convert a schema dictionary to a dictionary of SchemaField or + NestedSchemaField objects""" fields = {} for name, value in schema_dict.items(): if isinstance(value, dict): @@ -64,20 +67,25 @@ def _schema_dict_to_fields(schema_dict: Dict[str, Any]) -> Dict[str, Union[Schem name=name, generator=generator_name, params=params ) else: - raise ValueError(f"Invalid schema value type for field '{name}': {type(value)}") + raise ValueError( + f"Invalid schema value type for field '{name}': {type(value)}" + ) return fields def validate(self) -> None: """Validate that all generators are supported""" supported_generators = set(registry.get_supported_generators().keys()) - def validate_fields(fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]]): + def validate_fields( + fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]] + ): for field in fields_dict.values(): if isinstance(field, SchemaField): if field.generator not in supported_generators: raise ValueError( f"Unsupported generator: {field.generator}. " - f"Supported generators are: {', '.join(supported_generators)}" + f"Supported generators are: " + f"{', '.join(supported_generators)}" ) elif isinstance(field, NestedSchemaField): validate_fields(field.fields) @@ -86,7 +94,9 @@ def validate_fields(fields_dict: Dict[str, Union[SchemaField, NestedSchemaField] def _generate_record(self) -> Dict[str, Any]: """Generate a single record based on the schema""" - def generate_nested_record(fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]]) -> Dict[str, Any]: + def generate_nested_record( + fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]] + ) -> Dict[str, Any]: record = {} for field_name, field in fields_dict.items(): if isinstance(field, SchemaField): diff --git a/glassgen/sinks/csv_sink.py b/glassgen/sinks/csv_sink.py index 55cbbca..83dcc7f 100644 --- a/glassgen/sinks/csv_sink.py +++ b/glassgen/sinks/csv_sink.py @@ -19,15 +19,17 @@ def __init__(self, sink_params: Dict[str, Any]): self.file = None self.fieldnames = None - def _flatten_dict(self, data: Dict[str, Any], parent_key: str = '', sep: str = '_') -> Dict[str, Any]: + def _flatten_dict( + self, data: Dict[str, Any], parent_key: str = '', sep: str = '_' + ) -> Dict[str, Any]: """ Flatten a nested dictionary by concatenating keys with separator. - + Args: data: The dictionary to flatten parent_key: The parent key prefix sep: Separator to use between nested keys - + Returns: Flattened dictionary with concatenated keys """ @@ -36,18 +38,22 @@ def _flatten_dict(self, data: Dict[str, Any], parent_key: str = '', sep: str = ' new_key = f"{parent_key}{sep}{key}" if parent_key else key if isinstance(value, dict): # Recursively flatten nested dictionaries - items.extend(self._flatten_dict(value, new_key, sep=sep).items()) + items.extend( + self._flatten_dict(value, new_key, sep=sep).items() + ) else: items.append((new_key, value)) return dict(items) - def _get_flattened_fieldnames(self, data: List[Dict[str, Any]]) -> List[str]: + def _get_flattened_fieldnames( + self, data: List[Dict[str, Any]] + ) -> List[str]: """ Get fieldnames from flattened data to ensure consistent column order. - + Args: data: List of dictionaries to process - + Returns: List of fieldnames for CSV header """ @@ -60,7 +66,7 @@ def _get_flattened_fieldnames(self, data: List[Dict[str, Any]]) -> List[str]: def publish(self, data: Dict[str, Any]) -> None: # Flatten the data before writing flattened_data = self._flatten_dict(data) - + if self.writer is None: self.file = open(self.filepath, "w", newline="") self.fieldnames = list(flattened_data.keys()) diff --git a/tests/test_schema.py b/tests/test_schema.py index 903e052..da29846 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -113,12 +113,12 @@ def test_flat_schema(): "ts": "$datetime", "tsunix": "$timestamp" } - + schema = ConfigSchema.from_dict(schema_dict) schema.validate() - + record = schema._generate_record() - + # Check that all fields are present assert "name" in record assert "email" in record @@ -132,10 +132,10 @@ def test_flat_schema(): assert "weburl" in record assert "ts" in record assert "tsunix" in record - + # Check that intrange is within the specified range assert 10 <= record["intrange"] <= 100 - + # Check that choice is one of the specified values assert record["choice"] in ["apple", "banana", "cherry"] @@ -161,12 +161,12 @@ def test_nested_schema(): "ts": "$datetime", "tsunix": "$timestamp" } - + schema = ConfigSchema.from_dict(schema_dict) schema.validate() - + record = schema._generate_record() - + # Check that flat fields are present assert "name" in record assert "email" in record @@ -180,17 +180,17 @@ def test_nested_schema(): assert "weburl" in record assert "ts" in record assert "tsunix" in record - + # Check that nested location field is present and has the expected structure assert "location" in record assert isinstance(record["location"], dict) assert "address" in record["location"] assert "city" in record["location"] assert "postal_code" in record["location"] - + # Check that intrange is within the specified range assert 10 <= record["intrange"] <= 100 - + # Check that choice is one of the specified values assert record["choice"] in ["apple", "banana", "cherry"] @@ -217,29 +217,29 @@ def test_deeply_nested_schema(): "tags": "$choice(tag1,tag2,tag3)" } } - + schema = ConfigSchema.from_dict(schema_dict) schema.validate() - + record = schema._generate_record() - + # Check nested structure assert "user" in record assert isinstance(record["user"], dict) - + assert "personal" in record["user"] assert isinstance(record["user"]["personal"], dict) assert "name" in record["user"]["personal"] assert "email" in record["user"]["personal"] assert "phone" in record["user"]["personal"] - + assert "address" in record["user"] assert isinstance(record["user"]["address"], dict) assert "street" in record["user"]["address"] assert "city" in record["user"]["address"] assert "country" in record["user"]["address"] assert "postal_code" in record["user"]["address"] - + assert "metadata" in record assert isinstance(record["metadata"], dict) assert "id" in record["metadata"] @@ -263,23 +263,23 @@ def test_mixed_flat_and_nested_schema(): "notifications": "$choice(true,false)" } } - + schema = ConfigSchema.from_dict(schema_dict) schema.validate() - + record = schema._generate_record() - + # Check flat fields assert "name" in record assert "email" in record assert "phone" in record - + # Check nested fields assert "location" in record assert isinstance(record["location"], dict) assert "address" in record["location"] assert "city" in record["location"] - + assert "preferences" in record assert isinstance(record["preferences"], dict) assert "theme" in record["preferences"] @@ -294,6 +294,8 @@ def test_invalid_schema_value_type(): "name": "$name", "invalid_field": 123 # Invalid type - should be string or dict } - - with pytest.raises(ValueError, match="Invalid schema value type for field 'invalid_field'"): + + with pytest.raises( + ValueError, match="Invalid schema value type for field 'invalid_field'" + ): ConfigSchema.from_dict(schema_dict) From 0ffd7f62adfa991eae9502846f70d6b653e260b6 Mon Sep 17 00:00:00 2001 From: Ashish Bagri Date: Wed, 25 Jun 2025 14:55:26 +0200 Subject: [PATCH 3/3] Update ruff --- examples/usage.py | 17 +++-------------- glassgen/schema/schema.py | 10 ++++++---- glassgen/sinks/csv_sink.py | 10 +++------- tests/test_schema.py | 35 ++++++++++++----------------------- tests/test_sinks.py | 2 +- 5 files changed, 25 insertions(+), 49 deletions(-) diff --git a/examples/usage.py b/examples/usage.py index f399499..70ee849 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -1,21 +1,10 @@ import glassgen config = { - "schema": { - "name": "$name", - "user": { - "email": "$email", - "id": "$uuid" - } - }, - "generator": { - "num_records": 10 - } -} -sink_csv = { - "type": "csv", - "params": {"path": "output.csv"} + "schema": {"name": "$name", "user": {"email": "$email", "id": "$uuid"}}, + "generator": {"num_records": 10}, } +sink_csv = {"type": "csv", "params": {"path": "output.csv"}} config["sink"] = sink_csv gen = glassgen.generate(config=config) diff --git a/glassgen/schema/schema.py b/glassgen/schema/schema.py index 75efe83..bceb97e 100644 --- a/glassgen/schema/schema.py +++ b/glassgen/schema/schema.py @@ -15,8 +15,9 @@ class SchemaField(BaseModel): class NestedSchemaField(BaseModel): """Represents a nested schema field that contains other fields""" + name: str - fields: Dict[str, Union[SchemaField, 'NestedSchemaField']] + fields: Dict[str, Union[SchemaField, "NestedSchemaField"]] class ConfigSchema(BaseSchema, BaseModel): @@ -35,7 +36,7 @@ def _schema_dict_to_fields( schema_dict: Dict[str, Any], ) -> Dict[str, Union[SchemaField, NestedSchemaField]]: """Convert a schema dictionary to a dictionary of SchemaField or - NestedSchemaField objects""" + NestedSchemaField objects""" fields = {} for name, value in schema_dict.items(): if isinstance(value, dict): @@ -77,7 +78,7 @@ def validate(self) -> None: supported_generators = set(registry.get_supported_generators().keys()) def validate_fields( - fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]] + fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]], ): for field in fields_dict.values(): if isinstance(field, SchemaField): @@ -94,8 +95,9 @@ def validate_fields( def _generate_record(self) -> Dict[str, Any]: """Generate a single record based on the schema""" + def generate_nested_record( - fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]] + fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]], ) -> Dict[str, Any]: record = {} for field_name, field in fields_dict.items(): diff --git a/glassgen/sinks/csv_sink.py b/glassgen/sinks/csv_sink.py index 83dcc7f..8583009 100644 --- a/glassgen/sinks/csv_sink.py +++ b/glassgen/sinks/csv_sink.py @@ -20,7 +20,7 @@ def __init__(self, sink_params: Dict[str, Any]): self.fieldnames = None def _flatten_dict( - self, data: Dict[str, Any], parent_key: str = '', sep: str = '_' + self, data: Dict[str, Any], parent_key: str = "", sep: str = "_" ) -> Dict[str, Any]: """ Flatten a nested dictionary by concatenating keys with separator. @@ -38,16 +38,12 @@ def _flatten_dict( new_key = f"{parent_key}{sep}{key}" if parent_key else key if isinstance(value, dict): # Recursively flatten nested dictionaries - items.extend( - self._flatten_dict(value, new_key, sep=sep).items() - ) + items.extend(self._flatten_dict(value, new_key, sep=sep).items()) else: items.append((new_key, value)) return dict(items) - def _get_flattened_fieldnames( - self, data: List[Dict[str, Any]] - ) -> List[str]: + def _get_flattened_fieldnames(self, data: List[Dict[str, Any]]) -> List[str]: """ Get fieldnames from flattened data to ensure consistent column order. diff --git a/tests/test_schema.py b/tests/test_schema.py index da29846..2ff8e19 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -111,7 +111,7 @@ def test_flat_schema(): "choice": "$choice(apple,banana,cherry)", "weburl": "$url", "ts": "$datetime", - "tsunix": "$timestamp" + "tsunix": "$timestamp", } schema = ConfigSchema.from_dict(schema_dict) @@ -147,11 +147,7 @@ def test_nested_schema(): "email": "$email", "country": "$country", "id": "$uuid", - "location": { - "address": "$address", - "city": "$city", - "postal_code": "$zipcode" - }, + "location": {"address": "$address", "city": "$city", "postal_code": "$zipcode"}, "phone": "$phone_number", "job": "$job", "company": "$company", @@ -159,7 +155,7 @@ def test_nested_schema(): "choice": "$choice(apple,banana,cherry)", "weburl": "$url", "ts": "$datetime", - "tsunix": "$timestamp" + "tsunix": "$timestamp", } schema = ConfigSchema.from_dict(schema_dict) @@ -199,23 +195,19 @@ def test_deeply_nested_schema(): """Test deeply nested schemas work correctly""" schema_dict = { "user": { - "personal": { - "name": "$name", - "email": "$email", - "phone": "$phone_number" - }, + "personal": {"name": "$name", "email": "$email", "phone": "$phone_number"}, "address": { "street": "$address", "city": "$city", "country": "$country", - "postal_code": "$zipcode" - } + "postal_code": "$zipcode", + }, }, "metadata": { "id": "$uuid", "created_at": "$datetime", - "tags": "$choice(tag1,tag2,tag3)" - } + "tags": "$choice(tag1,tag2,tag3)", + }, } schema = ConfigSchema.from_dict(schema_dict) @@ -253,15 +245,12 @@ def test_mixed_flat_and_nested_schema(): schema_dict = { "name": "$name", "email": "$email", - "location": { - "address": "$address", - "city": "$city" - }, + "location": {"address": "$address", "city": "$city"}, "phone": "$phone_number", "preferences": { "theme": "$choice(dark,light)", - "notifications": "$choice(true,false)" - } + "notifications": "$choice(true,false)", + }, } schema = ConfigSchema.from_dict(schema_dict) @@ -292,7 +281,7 @@ def test_invalid_schema_value_type(): """Test that invalid schema value types raise appropriate errors""" schema_dict = { "name": "$name", - "invalid_field": 123 # Invalid type - should be string or dict + "invalid_field": 123, # Invalid type - should be string or dict } with pytest.raises( diff --git a/tests/test_sinks.py b/tests/test_sinks.py index 6f313ca..058b5a3 100644 --- a/tests/test_sinks.py +++ b/tests/test_sinks.py @@ -64,7 +64,7 @@ def test_csv_sink_bulk_publish(temp_csv_file): with open(temp_csv_file, "r") as f: lines = f.readlines() assert len(lines) == 4 # Header + 3 data rows - header = lines[0].strip().split(',') + header = lines[0].strip().split(",") assert set(header) == {"name", "age"} # Header fields, order-agnostic # Clean up