Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions examples/usage.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
# Example: generate records from a nested schema and stream them to a CSV sink.
import glassgen

config = {
    # "$name", "$email", "$uuid" are generator references; the "user" dict
    # declares a nested field whose sub-fields are generated recursively.
    "schema": {"name": "$name", "user": {"email": "$email", "id": "$uuid"}},
    "generator": {"num_records": 10},
}
# The sink can be attached after the base config is built.
sink_csv = {"type": "csv", "params": {"path": "output.csv"}}
config["sink"] = sink_csv

# generate() yields records one at a time, so they can be consumed lazily.
gen = glassgen.generate(config=config)
for row in gen:
    print(row)
135 changes: 86 additions & 49 deletions glassgen/schema/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

from pydantic import BaseModel, Field

Expand All @@ -13,71 +13,108 @@ class SchemaField(BaseModel):
params: List[Any] = Field(default_factory=list)


class NestedSchemaField(BaseModel):
    """Represents a nested schema field that contains other fields"""

    # Name of this field as it appears in the parent schema dict.
    name: str
    # Child fields keyed by name; values are leaf SchemaField generators
    # or further NestedSchemaField levels (arbitrary nesting depth).
    fields: Dict[str, Union[SchemaField, "NestedSchemaField"]]


class ConfigSchema(BaseSchema, BaseModel):
    """Schema implementation that can be created from a configuration.

    A schema dict maps field names either to generator strings such as
    "$email" or "$choice(a, b)" (leaf fields) or to nested dicts of the
    same shape (nested fields).
    """

    fields: Dict[str, Union[SchemaField, NestedSchemaField]]

    @classmethod
    def from_dict(cls, schema_dict: Dict[str, Any]) -> "ConfigSchema":
        """Create a schema from a configuration dictionary"""
        fields = cls._schema_dict_to_fields(schema_dict)
        return cls(fields=fields)

    @staticmethod
    def _schema_dict_to_fields(
        schema_dict: Dict[str, Any],
    ) -> Dict[str, Union[SchemaField, NestedSchemaField]]:
        """Convert a schema dictionary to a dictionary of SchemaField or
        NestedSchemaField objects.

        Dict values recurse into NestedSchemaField; string values must match
        the "$generator" or "$generator(params)" form.

        Raises:
            ValueError: if a string value does not match the generator
                format, or if a value is neither a dict nor a string.
        """
        fields: Dict[str, Union[SchemaField, NestedSchemaField]] = {}
        for name, value in schema_dict.items():
            if isinstance(value, dict):
                # Handle nested structure by recursing into the sub-dict.
                nested_fields = ConfigSchema._schema_dict_to_fields(value)
                fields[name] = NestedSchemaField(name=name, fields=nested_fields)
            elif isinstance(value, str):
                # Handle flat generator string, e.g. "$choice(a, b)".
                match = re.match(r"\$(\w+)(?:\((.*)\))?", value)
                if not match:
                    raise ValueError(f"Invalid generator format: {value}")

                generator_name = match.group(1)
                params_str = match.group(2)

                params = []
                if params_str:
                    # Handle choice generator specially
                    if generator_name == GeneratorType.CHOICE:
                        # Split by comma but preserve quoted strings
                        params = [p.strip().strip("\"'") for p in params_str.split(",")]
                    else:
                        # Simple parameter parsing for other generators
                        params = [p.strip() for p in params_str.split(",")]
                        # Convert numeric parameters
                        params = [int(p) if p.isdigit() else p for p in params]

                fields[name] = SchemaField(
                    name=name, generator=generator_name, params=params
                )
            else:
                raise ValueError(
                    f"Invalid schema value type for field '{name}': {type(value)}"
                )
        return fields

    def validate(self) -> None:
        """Validate that all generators (including nested ones) are supported.

        Raises:
            ValueError: if any leaf field references an unknown generator.
        """
        supported_generators = set(registry.get_supported_generators().keys())

        def validate_fields(
            fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]],
        ):
            for field in fields_dict.values():
                if isinstance(field, SchemaField):
                    if field.generator not in supported_generators:
                        raise ValueError(
                            f"Unsupported generator: {field.generator}. "
                            f"Supported generators are: "
                            f"{', '.join(supported_generators)}"
                        )
                elif isinstance(field, NestedSchemaField):
                    # Recurse into nested schema levels.
                    validate_fields(field.fields)

        validate_fields(self.fields)

    def _generate_record(self) -> Dict[str, Any]:
        """Generate a single record based on the schema"""

        def generate_nested_record(
            fields_dict: Dict[str, Union[SchemaField, NestedSchemaField]],
        ) -> Dict[str, Any]:
            record = {}
            for field_name, field in fields_dict.items():
                if isinstance(field, SchemaField):
                    generator = registry.get_generator(field.generator)
                    # Pass parameters to the generator if they exist
                    if field.params:
                        if field.generator == GeneratorType.CHOICE:
                            # For choice generator, pass the list directly
                            record[field_name] = generator(field.params)
                        else:
                            # For other generators, unpack the parameters
                            record[field_name] = generator(*field.params)
                    else:
                        record[field_name] = generator()
                elif isinstance(field, NestedSchemaField):
                    # Nested field: build a sub-record recursively.
                    record[field_name] = generate_nested_record(field.fields)
            return record

        return generate_nested_record(self.fields)
54 changes: 50 additions & 4 deletions glassgen/sinks/csv_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,69 @@ def __init__(self, sink_params: Dict[str, Any]):
self.file = None
self.fieldnames = None

def _flatten_dict(
self, data: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
"""
Flatten a nested dictionary by concatenating keys with separator.

Args:
data: The dictionary to flatten
parent_key: The parent key prefix
sep: Separator to use between nested keys

Returns:
Flattened dictionary with concatenated keys
"""
items = []
for key, value in data.items():
new_key = f"{parent_key}{sep}{key}" if parent_key else key
if isinstance(value, dict):
# Recursively flatten nested dictionaries
items.extend(self._flatten_dict(value, new_key, sep=sep).items())
else:
items.append((new_key, value))
return dict(items)

def _get_flattened_fieldnames(self, data: List[Dict[str, Any]]) -> List[str]:
"""
Get fieldnames from flattened data to ensure consistent column order.

Args:
data: List of dictionaries to process

Returns:
List of fieldnames for CSV header
"""
all_keys = set()
for item in data:
flattened = self._flatten_dict(item)
all_keys.update(flattened.keys())
return sorted(list(all_keys))

def publish(self, data: Dict[str, Any]) -> None:
    """Write one record to the CSV file, flattening nested keys first.

    The output file and DictWriter are created lazily on the first call;
    the header columns are fixed from that first flattened record.
    """
    row = self._flatten_dict(data)

    if self.writer is None:
        # First record: open the file and emit the header once.
        self.file = open(self.filepath, "w", newline="")
        self.fieldnames = list(row)
        self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames)
        self.writer.writeheader()

    self.writer.writerow(row)

def publish_bulk(self, data: List[Dict[str, Any]]) -> None:
    """Write many records at once, flattening nested keys first.

    On first use the header is computed from the union of keys across
    ALL records so every column appears even if some records omit it.
    """
    if self.writer is None:
        # Open lazily and derive a consistent header from the whole batch.
        self.file = open(self.filepath, "w", newline="")
        self.fieldnames = self._get_flattened_fieldnames(data)
        self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames)
        self.writer.writeheader()

    rows = [self._flatten_dict(record) for record in data]
    self.writer.writerows(rows)

def close(self) -> None:
if self.file:
Expand Down
Loading