⚠️ Early Development - API Unstable

Datarax is currently in early development and undergoing rapid iteration. Please be aware of the following implications:
| Area | Status | Impact |
| --- | --- | --- |
| API | 🔄 Unstable | Breaking changes are expected. Public interfaces may change without deprecation warnings. Pin to specific commits if stability is required. |
| Tests | 🔄 In Flux | Test suite is being expanded. Some tests may fail or be skipped. Coverage metrics are improving but not yet comprehensive. |
| Documentation | 🔄 Evolving | Docs may not reflect current implementation. Code examples might be outdated. Refer to source code and tests for accurate usage. |

We recommend waiting for a stable release (v1.0) before using Datarax in production. For research and experimentation, proceed with the understanding that APIs will evolve.
Datarax (Data + Array/JAX) is a high-performance, extensible data pipeline framework engineered for JAX-based machine learning workflows. It simplifies and accelerates the development of efficient, scalable data loading, preprocessing, and augmentation pipelines, leveraging JAX's Just-In-Time (JIT) compilation, automatic differentiation, and hardware acceleration capabilities.
- High Performance: Leverages JAX's JIT compilation and XLA backend to achieve near-optimal data processing speeds on CPUs, GPUs, and TPUs.
- JAX-Native Design: All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
- Scalability: Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
- Extensibility: Easily define and integrate custom data sources, transformations, and augmentation operations.
- Usability: Provides a clear, intuitive Python API and a flexible configuration system (TOML-based) for defining and managing pipelines.
- Determinism: Pipeline runs are deterministic by default, which is crucial for reproducibility in research and production (see the sketch after this list).
- Caching Optimization: Multiple caching strategies for performance improvement, including function caching, transformer caching, and checkpointing.
- Complete Feature Set: Supports common data pipeline operations including diverse data source handling, advanced transformations, data augmentation, batching, sharding, checkpointing, and caching.
- Ecosystem Integration: Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.
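To illustrate the determinism guarantee, the sketch below builds the same pipeline twice with identical seeds and checks that the first batches match. It reuses the `MemorySource`, `ElementOperator`, and `from_source` API shown in the Quick Start further down; treat it as a hedged illustration rather than a verbatim test from the repository (`make_pipeline` and `double` are local helpers, not part of Datarax).

```python
# Minimal reproducibility sketch (assumes the Quick Start API shown below).
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig

data = {"x": np.arange(64, dtype=np.float32).reshape(16, 4)}

def double(element, key=None):
    # Deterministic element-wise transform (no random key used).
    return element.update_data({"x": element.data["x"] * 2.0})

def make_pipeline(seed: int):
    # Build a fresh source and operator per run so no state is shared between runs.
    source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(seed))
    doubler = ElementOperator(ElementOperatorConfig(stochastic=False), fn=double, rngs=nnx.Rngs(seed))
    return from_source(source, batch_size=8) >> OperatorNode(doubler)

first = next(iter(make_pipeline(seed=0)))
second = next(iter(make_pipeline(seed=0)))
assert np.allclose(first["x"], second["x"])  # same seed -> same first batch
```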
Install Datarax using pip:
# Basic installation
pip install datarax
# With data loading support (HuggingFace, TFDS, Audio/Image libs)
pip install datarax[data]
# With GPU support (CUDA 12)
pip install datarax[gpu]
# Full installation with all optional dependencies
pip install datarax[all]

Datarax supports macOS on both Intel and Apple Silicon (M1/M2/M3) processors.
# Install for macOS (CPU mode - recommended)
pip install datarax[all-cpu]
# Run with explicit CPU backend
JAX_PLATFORMS=cpu python your_script.py

Metal GPU Acceleration (Experimental): JAX supports Apple's Metal backend for GPU acceleration on M1+ chips:
pip install jax-metal
JAX_PLATFORMS=metal python your_script.py

Note: Metal GPU acceleration is community-tested. CI runs on macOS with CPU only. If you can help validate Metal support, please open an issue with your results.
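Whichever backend you target, you can confirm what JAX actually selected before running a pipeline. This is a generic JAX snippet, not Datarax-specific:

```python
# Quick sanity check of the active JAX backend (CPU, Metal, GPU, or TPU).
import jax

print(jax.default_backend())  # e.g. "cpu" or the Metal/GPU backend name
print(jax.devices())          # list of devices JAX can see
```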
Known Limitations:
- TensorFlow Datasets: May hang on import on macOS ARM64. Datarax uses lazy imports to handle this.
- JAX Profiler: Tracing is automatically disabled on macOS due to a TensorBoard compatibility issue.
For contributors and developers, Datarax uses uv as its package manager:
# Install uv
pip install uv
# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax
# Run automatic setup (creates venv & installs dependencies)
./setup.sh
# Activate the environment
source activate.sh
# Or install manually with uv
uv pip install -e ".[dev]"

See the Development Environment Guide for detailed instructions on:
- Creating a virtual environment with `uv venv`
- Installing dependencies through `pyproject.toml`
- Running tests through pytest
- Building and packaging Datarax
Datarax is currently in active development with a Flax NNX-based architecture that provides robust state management and checkpointing:
- NNX Module System: All components are built on `flax.nnx.Module` for robust state management
- Integrated Checkpointing: Seamless Orbax integration for stateful pipeline persistence
- Type Safety: Complete type annotations and runtime validation
- Composability: Modular design enabling flexible pipeline construction
- Core NNX Modules: DataraxModule, OperatorModule, StructuralModule
- Data Sources: MemorySource, TFDSEagerSource/TFDSStreamingSource, HFEagerSource/HFStreamingSource (inheriting from StructuralModule/DataSourceModule)
- Operators: Element-wise operators, MapOperator (inheriting from OperatorModule)
- DAG Execution Engine: `DAGExecutor` and `pipeline` API for constructing flexible, graph-based data processing flows.
- Node System: A rich set of nodes for building pipelines:
  - `DataSourceNode`: entry point for data
  - `OperatorNode`: for transformations
  - `BatchNode` & `RebatchNode`: for batching control
  - `ShuffleNode`, `CacheNode`: for data management
  - Control flow nodes: `Sequential`, `Parallel`, `Branch`, `Merge`
- Stateful Components:
  - MemorySource/ArrayRecordSourceModule/HFEagerSource/HFStreamingSource for diverse data ingestion
  - Range/ShuffleSamplerModule with reproducible random state
  - DefaultBatcherModule with buffer state management
  - ArraySharderModule for device-aware data distribution
- External Integrations:
  - HuggingFace Datasets with stateful iteration
  - TensorFlow Datasets with checkpoint support
- Advanced Features:
  - Complete caching strategies with state preservation
  - Differentiable rebatching (`DifferentiableRebatchImpl`)
  - `NNXCheckpointHandler` for production-grade checkpointing (see the sketch after this list)
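As a rough illustration of how the stateful, NNX-based design enables persistence, the sketch below saves and restores a module's state using plain Flax NNX and Orbax calls. It assumes the Datarax components are `nnx.Module` instances as described above; `NNXCheckpointHandler` presumably packages a pattern like this, but its exact interface is not reproduced here.

```python
# Minimal sketch: persisting NNX state with Orbax (not the NNXCheckpointHandler API).
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx
from datarax.sources import MemorySource, MemorySourceConfig

# A small stateful component to checkpoint; any Datarax nnx.Module works the same way.
source = MemorySource(
    MemorySourceConfig(),
    data={"x": np.arange(8, dtype=np.float32)},
    rngs=nnx.Rngs(0),
)

checkpointer = ocp.StandardCheckpointer()

# Save: split the module into its static graph definition and mutable state,
# then write the state (a PyTree of arrays) with a standard Orbax checkpointer.
graphdef, state = nnx.split(source)
checkpointer.save("/tmp/datarax_source_state", state)

# Restore: read the state back (using the current state as the target structure)
# and merge it with the graph definition to rebuild an equivalent module.
restored_state = checkpointer.restore("/tmp/datarax_source_state", state)
restored_source = nnx.merge(graphdef, restored_state)
```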
Upcoming features include:
- Image transformation library with JAX-native operations
- Advanced sharding strategies for multi-device and multi-host scenarios
- Performance optimization suite with benchmarking tools
- Extended monitoring and metrics capabilities
- Additional external data source integrations
Here's a simple example of using Datarax's DAG-based architecture to create a data pipeline for image classification:
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element
def normalize(element: Element, key: jax.Array | None = None) -> Element:
"""Normalize the image in the element."""
# element.data is a dict-like PyTree containing the actual data arrays
# Return a new Element with updated data (immutable update)
# update_data provides a clean API for partial updates
return element.update_data({"image": element.data["image"] / 255.0})
def apply_augmentation(element: Element, key: jax.Array) -> Element:
"""Apply a simple augmentation (flipping) to the image."""
key1, _ = jax.random.split(key)
flip_horizontal = jax.random.bernoulli(key1, 0.5)
# Use jax.lax.cond for JAX-compatible conditional execution
def flip_image(img):
return jnp.flip(img, axis=1)
def no_flip(img):
return img
new_image = jax.lax.cond(
flip_horizontal,
flip_image,
no_flip,
element.data["image"]
)
return element.update_data({"image": new_image})
# Create some dummy data (28x28 images)
num_samples = 1000
image_shape = (28, 28, 1)
data = {
"image": np.random.randint(0, 255, (num_samples, *image_shape)).astype(np.float32),
"label": np.random.randint(0, 10, (num_samples,)).astype(np.int32),
}
# Create the data source using config-based API
# MemorySource is a standard DataSource implementation for in-memory data
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=data, rngs=nnx.Rngs(0))
# Create operators using the unified ElementOperator API
# Normalizer is deterministic (no random key needed)
normalizer_config = ElementOperatorConfig(stochastic=False)
normalizer = ElementOperator(normalizer_config, fn=normalize, rngs=nnx.Rngs(0))
# Augmenter is stochastic: requires stream_name for proper RNG management
augmenter_config = ElementOperatorConfig(stochastic=True, stream_name="augmentations")
augmenter = ElementOperator(augmenter_config, fn=apply_augmentation, rngs=nnx.Rngs(42))
# Create the data pipeline using the DAG-based API
# from_source() initializes the pipeline with:
# 1. The DataSourceNode (data loading)
# 2. A BatchNode (automatic batching)
# The >> operator chains transformation operators
pipeline = (
from_source(source, batch_size=32)
>> OperatorNode(normalizer)
>> OperatorNode(augmenter)
)
# Alternative: Method Chaining
# You can also build the pipeline using the fluent .add() method:
# pipeline = (
# from_source(source, batch_size=32)
# .add(OperatorNode(normalizer))
# .add(OperatorNode(augmenter))
# )
# Create an iterator and process batches
# The pipeline handles data streaming, batching, state management, and execution
for i, batch in enumerate(pipeline):
if i >= 3:
break
# Get the shape and stats for each component in the batch
# batch['key'] provides direct access to data arrays
image_batch = batch["image"]
label_batch = batch["label"]
print(f"Batch {i}:")
print(f" Image shape: {image_batch.shape}")
print(f" Label batch size: {label_batch.shape[0]}")
print(f" Image min/max: {image_batch.min():.3f}/{image_batch.max():.3f}")
print("Pipeline processing completed!")For more complex workflows, Datarax supports branching and parallel execution:
from datarax.dag.nodes import Parallel, Merge, Branch
# Define additional operators
def invert(element, key=None):
    return element.update_data({"image": 1.0 - element.data["image"]})

# Wrap the raw function in an ElementOperator, matching the pattern used above
inverter = ElementOperator(ElementOperatorConfig(stochastic=False), fn=invert, rngs=nnx.Rngs(0))
def is_high_contrast(element):
# Condition: check if image variance is high
return jnp.var(element.data["image"]) > 0.1
# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel: Normal version AND Inverted version
# 3. Merge: Average them (Simple Ensemble)
# 4. Branch: Apply extra noise ONLY if high contrast, otherwise normalize again
complex_pipeline = (
from_source(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(inverter))
>> Merge("mean")
>> Branch(
condition=is_high_contrast,
true_path=OperatorNode(augmenter),
false_path=OperatorNode(normalizer)
)
)

Datarax includes a comprehensive test suite with support for both CPU and GPU testing:
- Tests are organized to mirror the `src/datarax` package structure for easier navigation
- All GitHub CI workflows run tests exclusively on CPU
- Local test runs automatically use GPU when available
- The testing infrastructure handles environment configuration to ensure consistency
To run tests:
# Run all tests using CPU only (most stable)
JAX_PLATFORMS=cpu python -m pytest
# Run specific test module
JAX_PLATFORMS=cpu python -m pytest tests/sources/test_memory_source.py

For more information on the test organization and how to run tests, see the Testing Guide.
Datarax is licensed under the MIT License.