diff --git a/examples/usage.py b/examples/usage.py index 6644cd3..70ee849 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -1,18 +1,12 @@ import glassgen config = { - "schema": { - "name": "$name", - "email": "$email", - "country": "$country", - "id": "$uuid", - "address": "$address", - "phone": "$phone_number", - "job": "$job", - "company": "$company", - }, - "sink": {"type": "csv", "params": {"path": "output.csv"}}, - "generator": {"rps": 1500, "num_records": 5000}, + "schema": {"name": "$name", "user": {"email": "$email", "id": "$uuid"}}, + "generator": {"num_records": 10}, } -# Start the generator -print(glassgen.generate(config=config)) +sink_csv = {"type": "csv", "params": {"path": "output.csv"}} +config["sink"] = sink_csv + +gen = glassgen.generate(config=config) +for row in gen: + print(row) diff --git a/glassgen/schema/schema.py b/glassgen/schema/schema.py index 9913a4d..bceb97e 100644 --- a/glassgen/schema/schema.py +++ b/glassgen/schema/schema.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Union from pydantic import BaseModel, Field @@ -13,71 +13,108 @@ class SchemaField(BaseModel): params: List[Any] = Field(default_factory=list) +class NestedSchemaField(BaseModel): + """Represents a nested schema field that contains other fields""" + + name: str + fields: Dict[str, Union[SchemaField, "NestedSchemaField"]] + + class ConfigSchema(BaseSchema, BaseModel): """Schema implementation that can be created from a configuration""" - fields: Dict[str, SchemaField] + fields: Dict[str, Union[SchemaField, NestedSchemaField]] @classmethod - def from_dict(cls, schema_dict: Dict[str, str]) -> "ConfigSchema": + def from_dict(cls, schema_dict: Dict[str, Any]) -> "ConfigSchema": """Create a schema from a configuration dictionary""" fields = cls._schema_dict_to_fields(schema_dict) return cls(fields=fields) @staticmethod - def _schema_dict_to_fields(schema_dict: Dict[str, str]) -> Dict[str, SchemaField]: - """Convert a schema dictionary to a dictionary of SchemaField objects""" + def _schema_dict_to_fields( + schema_dict: Dict[str, Any], + ) -> Dict[str, Union[SchemaField, NestedSchemaField]]: + """Convert a schema dictionary to a dictionary of SchemaField or + NestedSchemaField objects""" fields = {} - for name, generator_str in schema_dict.items(): - match = re.match(r"\$(\w+)(?:\((.*)\))?", generator_str) - if not match: - raise ValueError(f"Invalid generator format: {generator_str}") - - generator_name = match.group(1) - params_str = match.group(2) - - params = [] - if params_str: - # Handle choice generator specially - if generator_name == GeneratorType.CHOICE: - # Split by comma but preserve quoted strings - params = [p.strip().strip("\"'") for p in params_str.split(",")] - - else: - # Simple parameter parsing for other generators - params = [p.strip() for p in params_str.split(",")] - # Convert numeric parameters - params = [int(p) if p.isdigit() else p for p in params] - - fields[name] = SchemaField( - name=name, generator=generator_name, params=params - ) + for name, value in schema_dict.items(): + if isinstance(value, dict): + # Handle nested structure + nested_fields = ConfigSchema._schema_dict_to_fields(value) + fields[name] = NestedSchemaField(name=name, fields=nested_fields) + elif isinstance(value, str): + # Handle flat generator string + match = re.match(r"\$(\w+)(?:\((.*)\))?", value) + if not match: + raise ValueError(f"Invalid generator format: {value}") + + generator_name = match.group(1) + params_str = match.group(2) + + params = [] + if params_str: + # Handle choice generator specially + if generator_name == GeneratorType.CHOICE: + # Split by comma but preserve quoted strings + params = [p.strip().strip("\"'") for p in params_str.split(",")] + else: + # Simple parameter parsing for other generators + params = [p.strip() for p in params_str.split(",")] + # Convert numeric parameters + params = [int(p) if p.isdigit() else p for p in params] + + fields[name] = SchemaField( + name=name, generator=generator_name, params=params + ) + else: + raise ValueError( + f"Invalid schema value type for field '{name}': {type(value)}" + ) return fields def validate(self) -> None: """Validate that all generators are supported""" supported_generators = set(registry.get_supported_generators().keys()) - for field in self.fields.values(): - if field.generator not in supported_generators: - raise ValueError( - f"Unsupported generator: {field.generator}. " - f"Supported generators are: {', '.join(supported_generators)}" - ) + def validate_fields( + fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]], + ): + for field in fields_dict.values(): + if isinstance(field, SchemaField): + if field.generator not in supported_generators: + raise ValueError( + f"Unsupported generator: {field.generator}. " + f"Supported generators are: " + f"{', '.join(supported_generators)}" + ) + elif isinstance(field, NestedSchemaField): + validate_fields(field.fields) + + validate_fields(self.fields) def _generate_record(self) -> Dict[str, Any]: """Generate a single record based on the schema""" - record = {} - for field_name, field in self.fields.items(): - generator = registry.get_generator(field.generator) - # Pass parameters to the generator if they exist - if field.params: - if field.generator == GeneratorType.CHOICE: - # For choice generator, pass the list directly - record[field_name] = generator(field.params) - else: - # For other generators, unpack the parameters - record[field_name] = generator(*field.params) - else: - record[field_name] = generator() - return record + + def generate_nested_record( + fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]], + ) -> Dict[str, Any]: + record = {} + for field_name, field in fields_dict.items(): + if isinstance(field, SchemaField): + generator = registry.get_generator(field.generator) + # Pass parameters to the generator if they exist + if field.params: + if field.generator == GeneratorType.CHOICE: + # For choice generator, pass the list directly + record[field_name] = generator(field.params) + else: + # For other generators, unpack the parameters + record[field_name] = generator(*field.params) + else: + record[field_name] = generator() + elif isinstance(field, NestedSchemaField): + record[field_name] = generate_nested_record(field.fields) + return record + + return generate_nested_record(self.fields) diff --git a/glassgen/sinks/csv_sink.py b/glassgen/sinks/csv_sink.py index 42a25e3..8583009 100644 --- a/glassgen/sinks/csv_sink.py +++ b/glassgen/sinks/csv_sink.py @@ -19,23 +19,69 @@ def __init__(self, sink_params: Dict[str, Any]): self.file = None self.fieldnames = None + def _flatten_dict( + self, data: Dict[str, Any], parent_key: str = "", sep: str = "_" + ) -> Dict[str, Any]: + """ + Flatten a nested dictionary by concatenating keys with separator. + + Args: + data: The dictionary to flatten + parent_key: The parent key prefix + sep: Separator to use between nested keys + + Returns: + Flattened dictionary with concatenated keys + """ + items = [] + for key, value in data.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + if isinstance(value, dict): + # Recursively flatten nested dictionaries + items.extend(self._flatten_dict(value, new_key, sep=sep).items()) + else: + items.append((new_key, value)) + return dict(items) + + def _get_flattened_fieldnames(self, data: List[Dict[str, Any]]) -> List[str]: + """ + Get fieldnames from flattened data to ensure consistent column order. + + Args: + data: List of dictionaries to process + + Returns: + List of fieldnames for CSV header + """ + all_keys = set() + for item in data: + flattened = self._flatten_dict(item) + all_keys.update(flattened.keys()) + return sorted(list(all_keys)) + def publish(self, data: Dict[str, Any]) -> None: + # Flatten the data before writing + flattened_data = self._flatten_dict(data) + if self.writer is None: self.file = open(self.filepath, "w", newline="") - self.fieldnames = list(data.keys()) + self.fieldnames = list(flattened_data.keys()) self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames) self.writer.writeheader() - self.writer.writerow(data) + self.writer.writerow(flattened_data) def publish_bulk(self, data: List[Dict[str, Any]]) -> None: if self.writer is None: self.file = open(self.filepath, "w", newline="") - self.fieldnames = list(data[0].keys()) + # Get fieldnames from all data to ensure consistent columns + self.fieldnames = self._get_flattened_fieldnames(data) self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames) self.writer.writeheader() - self.writer.writerows(data) + # Flatten each record before writing + flattened_data = [self._flatten_dict(item) for item in data] + self.writer.writerows(flattened_data) def close(self) -> None: if self.file: diff --git a/tests/test_schema.py b/tests/test_schema.py index 5ed6d0e..2ff8e19 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -95,3 +95,196 @@ def test_predefined_schema(): assert "@" in record["email"] assert isinstance(record["age"], int) assert isinstance(record["phone"], str) + + +def test_flat_schema(): + """Test that flat schemas still work as before""" + schema_dict = { + "name": "$name", + "email": "$email", + "country": "$country", + "id": "$uuid", + "phone": "$phone_number", + "job": "$job", + "company": "$company", + "intrange": "$intrange(10,100)", + "choice": "$choice(apple,banana,cherry)", + "weburl": "$url", + "ts": "$datetime", + "tsunix": "$timestamp", + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check that all fields are present + assert "name" in record + assert "email" in record + assert "country" in record + assert "id" in record + assert "phone" in record + assert "job" in record + assert "company" in record + assert "intrange" in record + assert "choice" in record + assert "weburl" in record + assert "ts" in record + assert "tsunix" in record + + # Check that intrange is within the specified range + assert 10 <= record["intrange"] <= 100 + + # Check that choice is one of the specified values + assert record["choice"] in ["apple", "banana", "cherry"] + + +def test_nested_schema(): + """Test that nested schemas work correctly""" + schema_dict = { + "name": "$name", + "email": "$email", + "country": "$country", + "id": "$uuid", + "location": {"address": "$address", "city": "$city", "postal_code": "$zipcode"}, + "phone": "$phone_number", + "job": "$job", + "company": "$company", + "intrange": "$intrange(10,100)", + "choice": "$choice(apple,banana,cherry)", + "weburl": "$url", + "ts": "$datetime", + "tsunix": "$timestamp", + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check that flat fields are present + assert "name" in record + assert "email" in record + assert "country" in record + assert "id" in record + assert "phone" in record + assert "job" in record + assert "company" in record + assert "intrange" in record + assert "choice" in record + assert "weburl" in record + assert "ts" in record + assert "tsunix" in record + + # Check that nested location field is present and has the expected structure + assert "location" in record + assert isinstance(record["location"], dict) + assert "address" in record["location"] + assert "city" in record["location"] + assert "postal_code" in record["location"] + + # Check that intrange is within the specified range + assert 10 <= record["intrange"] <= 100 + + # Check that choice is one of the specified values + assert record["choice"] in ["apple", "banana", "cherry"] + + +def test_deeply_nested_schema(): + """Test deeply nested schemas work correctly""" + schema_dict = { + "user": { + "personal": {"name": "$name", "email": "$email", "phone": "$phone_number"}, + "address": { + "street": "$address", + "city": "$city", + "country": "$country", + "postal_code": "$zipcode", + }, + }, + "metadata": { + "id": "$uuid", + "created_at": "$datetime", + "tags": "$choice(tag1,tag2,tag3)", + }, + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check nested structure + assert "user" in record + assert isinstance(record["user"], dict) + + assert "personal" in record["user"] + assert isinstance(record["user"]["personal"], dict) + assert "name" in record["user"]["personal"] + assert "email" in record["user"]["personal"] + assert "phone" in record["user"]["personal"] + + assert "address" in record["user"] + assert isinstance(record["user"]["address"], dict) + assert "street" in record["user"]["address"] + assert "city" in record["user"]["address"] + assert "country" in record["user"]["address"] + assert "postal_code" in record["user"]["address"] + + assert "metadata" in record + assert isinstance(record["metadata"], dict) + assert "id" in record["metadata"] + assert "created_at" in record["metadata"] + assert "tags" in record["metadata"] + assert record["metadata"]["tags"] in ["tag1", "tag2", "tag3"] + + +def test_mixed_flat_and_nested_schema(): + """Test schemas with both flat and nested fields""" + schema_dict = { + "name": "$name", + "email": "$email", + "location": {"address": "$address", "city": "$city"}, + "phone": "$phone_number", + "preferences": { + "theme": "$choice(dark,light)", + "notifications": "$choice(true,false)", + }, + } + + schema = ConfigSchema.from_dict(schema_dict) + schema.validate() + + record = schema._generate_record() + + # Check flat fields + assert "name" in record + assert "email" in record + assert "phone" in record + + # Check nested fields + assert "location" in record + assert isinstance(record["location"], dict) + assert "address" in record["location"] + assert "city" in record["location"] + + assert "preferences" in record + assert isinstance(record["preferences"], dict) + assert "theme" in record["preferences"] + assert "notifications" in record["preferences"] + assert record["preferences"]["theme"] in ["dark", "light"] + assert record["preferences"]["notifications"] in ["true", "false"] + + +def test_invalid_schema_value_type(): + """Test that invalid schema value types raise appropriate errors""" + schema_dict = { + "name": "$name", + "invalid_field": 123, # Invalid type - should be string or dict + } + + with pytest.raises( + ValueError, match="Invalid schema value type for field 'invalid_field'" + ): + ConfigSchema.from_dict(schema_dict) diff --git a/tests/test_sinks.py b/tests/test_sinks.py index f9649f5..058b5a3 100644 --- a/tests/test_sinks.py +++ b/tests/test_sinks.py @@ -64,7 +64,8 @@ def test_csv_sink_bulk_publish(temp_csv_file): with open(temp_csv_file, "r") as f: lines = f.readlines() assert len(lines) == 4 # Header + 3 data rows - assert "name,age" in lines[0] # Header + header = lines[0].strip().split(",") + assert set(header) == {"name", "age"} # Header fields, order-agnostic # Clean up os.unlink(temp_csv_file)