PrxteinMPNN provides a functional interface for ProteinMPNN, leveraging the JAX ecosystem for accelerated computation and transparent protein design workflows.
- 🔍 Increased Transparency: A clear, functional interface to ProteinMPNN that lets users follow every operation in the model's data flow
- ⚡ JAX Compatibility: Efficient computation with JAX's functional programming paradigm, including JIT compilation and vectorization
- 🧩 Modular Design: A modular structure that makes the model easy to update and extend
- 🚀 Performance Optimization: JAX-based acceleration for large-scale protein design tasks
- 🔄 JAX Transformations: Compatible with `jit`, `vmap`, and `scan` for batch processing and optimization
PrxteinMPNN has been rigorously validated against the original ColabDesign ProteinMPNN implementation:
| Decoding Path | Pearson Correlation | Status |
|---|---|---|
| Unconditional | 0.984 | ✅ Validated |
| Conditional | 0.958-0.984 | ✅ Validated |
| Autoregressive | 0.953-0.970 | ✅ Validated |
All three decoding paths achieve >0.95 Pearson correlation with ColabDesign outputs, indicating that PrxteinMPNN faithfully reproduces the original model's behavior.
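For reference, the Pearson correlation reported in the table can be computed as in the minimal stdlib sketch below. It assumes that the outputs of the two implementations (for example, per-position logits) have been flattened into equal-length lists; `pearson_r` is a hypothetical helper, not a PrxteinMPNN API, and the exact quantities compared in the equivalence tests may differ.

```python
import math
from collections.abc import Sequence


def pearson_r(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation r = cov(x, y) / (std(x) * std(y)) (hypothetical helper)."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    dx = [a - mean_x for a in x]
    dy = [b - mean_y for b in y]
    # The 1/n factors in the covariance and the standard deviations cancel out.
    cov = sum(a * b for a, b in zip(dx, dy))
    norm = math.sqrt(sum(a * a for a in dx) * sum(b * b for b in dy))
    if norm == 0.0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / norm


# Illustrative (made-up) flattened logits from two implementations of the same model.
prxtein_logits = [1.2, -0.3, 0.8, 2.1, -1.0]
colab_logits = [1.1, -0.2, 0.9, 2.0, -1.1]
r = pearson_r(prxtein_logits, colab_logits)
print(f"r = {r:.4f}")
print("passes 0.95 threshold:", r > 0.95)
```

Because r is invariant to shifts and positive rescaling, a high correlation shows that two implementations rank and space their outputs alike, but it does not by itself prove the outputs are numerically identical.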
To reproduce the equivalence tests:

```bash
# Install ColabDesign for validation tests
pip install git+https://github.com/sokrypton/ColabDesign.git@e31a56f
# Run equivalence tests
uv run pytest tests/model/test_colabdesign_equivalence.py -v
```

Install the dependencies with `uv`, choosing the extra that matches your hardware:

```bash
uv sync --extra cuda  # For GPU
uv sync --extra tpu   # For TPU
uv sync --extra cpu   # For CPU-only (default)
```

Quick start:

```python
import jax

from prxteinmpnn.io import from_structure_file, protein_structure_to_model_inputs
from prxteinmpnn.mpnn import get_mpnn_model
from prxteinmpnn.sampling import SamplingConfig, make_sample_sequences
from prxteinmpnn.scoring.score import make_score_sequence
from prxteinmpnn.utils.decoding_order import random_decoding_order

# Load a protein structure
protein_structure = from_structure_file(filename="path/to/structure.pdb")

# Get the MPNN model
model = get_mpnn_model(
    model_version="v_48_020",
    model_weights="original",
)

# Get model inputs for the structure
model_inputs = protein_structure_to_model_inputs(protein_structure)

# Split the PRNG key so scoring and sampling draw independent randomness
key = jax.random.PRNGKey(0)
score_key, sample_key = jax.random.split(key)

# Build a scoring function that uses a random decoding order
score_sequence = make_score_sequence(
    model,
    random_decoding_order,
    model_inputs=model_inputs,
)

# Score the original (native) sequence
original_score, original_logits, decoding_order = score_sequence(
    score_key,
    model_inputs.sequence,
)

# Sample new sequences at low temperature
config = SamplingConfig(sampling_strategy="temperature", temperature=0.1)
sample_sequence = make_sample_sequences(
    model,
    random_decoding_order,
    config=config,
    model_inputs=model_inputs,
)
sampled_sequence, logits, decoding_order = sample_sequence(sample_key)
```

Requirements:

- Python >= 3.11
- JAX ecosystem: jax, jaxlib, flax
- Core dependencies: NumPy, joblib, jaxtyping
- Protein handling: foldcomp, biotite
- Testing: chex, pytest, pytest-cov
This project follows strict coding standards:
- JAX-idiomatic code: Functional programming paradigm with immutable data structures
- Linting: Ruff with strict configuration (line length: 100, all rules enabled)
- Type checking: Pyright in strict mode
- Testing: Comprehensive unit and integration tests with pytest
- Documentation: Google-style docstrings with examples
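As a minimal illustration of the "functional programming paradigm with immutable data structures" standard, the stdlib sketch below uses a frozen dataclass: an update returns a new object rather than mutating the original. `ToyStructure` and `mutate_residue` are hypothetical and do not mirror PrxteinMPNN's actual data types, which are built on JAX arrays (themselves immutable).

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyStructure:
    """Hypothetical immutable record; field names are illustrative only."""

    sequence: tuple[int, ...]
    chain_ids: tuple[str, ...]


def mutate_residue(structure: ToyStructure, position: int, residue: int) -> ToyStructure:
    """Return a new structure with one residue changed; the input is left untouched."""
    new_sequence = structure.sequence[:position] + (residue,) + structure.sequence[position + 1 :]
    return dataclasses.replace(structure, sequence=new_sequence)


original = ToyStructure(sequence=(0, 1, 3), chain_ids=("A", "A", "A"))
mutant = mutate_residue(original, 1, 5)
print(original.sequence, mutant.sequence)

try:
    original.sequence = (9, 9, 9)  # frozen dataclasses reject attribute assignment
except dataclasses.FrozenInstanceError:
    print("in-place mutation is rejected")
```

Pure functions over immutable inputs are what make JAX transformations such as `jit` and `vmap` safe to apply, since a function's result depends only on its arguments.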
Run the test suite and the linter:

```bash
python -m pytest tests/
ruff check src/ --fix
```

Design principles:

- Functional Design: All operations follow JAX's functional programming paradigm
- Immutable Data: Protein structures and model states are immutable
- JAX Transformations: Compatible with `jit`, `vmap`, and `scan`
- Modular Architecture: Clean separation of concerns across sampling, scoring, and utilities
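The usage example passes `random_decoding_order` to both the scorer and the sampler. In ProteinMPNN a decoding order is a permutation of residue positions, and during autoregressive decoding each position may only see the sequence identities of positions decoded before it. The stdlib sketch below illustrates that idea; `random_order` and `autoregressive_mask` are hypothetical helpers and do not reproduce the signature of PrxteinMPNN's `random_decoding_order`, which works with JAX PRNG keys and arrays.

```python
import random


def random_order(num_residues: int, rng: random.Random) -> list[int]:
    """Return a random permutation of residue positions (hypothetical helper)."""
    order = list(range(num_residues))
    rng.shuffle(order)
    return order


def autoregressive_mask(order: list[int]) -> list[list[bool]]:
    """mask[i][j] is True when position j is decoded strictly before position i."""
    step = [0] * len(order)
    for t, position in enumerate(order):
        step[position] = t  # step at which each position is decoded
    n = len(order)
    return [[step[j] < step[i] for j in range(n)] for i in range(n)]


order = random_order(4, random.Random(42))
mask = autoregressive_mask(order)
print("decoding order:", order)
for i, row in enumerate(mask):
    visible = [j for j, ok in enumerate(row) if ok]
    print(f"position {i} sees sequence at positions {visible}")
```

The first position in the order sees no sequence context and the last sees all others; changing the order changes the context available at each step, which is why scores can vary with the decoding order.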
PrxteinMPNN aims to provide:
- Transparency: A clear, understandable interface to ProteinMPNN's capabilities
- Performance: Leverage JAX for high-performance protein design workflows
- Modularity: Easy-to-extend components for custom protein design tasks
- Compatibility: Seamless integration with the broader JAX ecosystem
MIT License
Contributions are welcome! Contributing guidelines are currently under development.
- Documentation: http://maraxen.github.io/PrxteinMPNN
- Issues: GitHub Issues
- Discussions: GitHub Discussions