Skip to content

maraxen/PrxteinMPNN

Repository files navigation

PrxteinMPNN: A functional interface to ProteinMPNN in JAX

Test Coverage Run on Colab Documentation

PrxteinMPNN provides a functional interface for ProteinMPNN, leveraging the JAX ecosystem for accelerated computation and transparent protein design workflows.

🎯 Key Features

  • 🔍 Increased Transparency: Clear and functional interface for ProteinMPNN, enabling users to understand all the operations defining the models flow
  • ⚡ JAX Compatibility: Efficient computation with JAX's functional programming paradigm, including JIT compilation and vectorization
  • 🧩 Modular Design: Maintain a modular structure to facilitate easy updates and extensions to the model
  • 🚀 Performance Optimization: Utilize JAX's capabilities for large-scale protein design tasks
  • 🔄 JAX Transformations: Compatible with jit, vmap, and scan for batch processing and optimization

📚 Documentation

Complete Documentation →

✅ Validation

PrxteinMPNN has been rigorously validated against the original ColabDesign ProteinMPNN implementation:

Decoding Path Correlation Status
Unconditional 0.984 Validated
Conditional 0.958-0.984 Validated
Autoregressive 0.953-0.970 Validated

All three decoding paths achieve >0.95 Pearson correlation with ColabDesign outputs, ensuring faithful reproduction of the original model's behavior.

View Full Validation Report →

Running Equivalence Tests

# Install ColabDesign for validation tests
pip install git+https://github.com/sokrypton/ColabDesign.git@e31a56f

# Run equivalence tests
uv run pytest tests/model/test_colabdesign_equivalence.py -v

🚀 Quick Start

Installation

uv sync --extra cuda  # For GPU
uv sync --extra tpu   # For TPU
uv sync --extra cpu   # For CPU-only (default)

Basic Usage

import jax
import jax.numpy as jnp
from prxteinmpnn.mpnn import get_mpnn_model
from prxteinmpnn.io import from_structure_file, protein_structure_to_model_inputs
from prxteinmpnn.scoring.score import make_score_sequence
from prxteinmpnn.sampling import make_sample_sequences, SamplingConfig
from prxteinmpnn.utils.decoding_order import random_decoding_order

# Load a protein structure
protein_structure = from_structure_file(filename="path/to/structure.pdb")

# Get the MPNN model
model = get_mpnn_model(
    model_version="v_48_020",
    model_weights="original"
)

# Get model inputs for the structure
model_inputs = protein_structure_to_model_inputs(protein_structure)

# Score sequences
key = jax.random.PRNGKey(0)
score_sequence = make_score_sequence(
    model, 
    random_decoding_order, 
    model_inputs=model_inputs
)

# Score original sequence
original_score, original_logits, decoding_order = score_sequence(
    key, 
    model_inputs.sequence
)

# Sample new sequences
config = SamplingConfig(sampling_strategy="temperature", temperature=0.1)
sample_sequence = make_sample_sequences(
    model,
    random_decoding_order,
    config=config,
    model_inputs=model_inputs
)

sampled_sequence, logits, decoding_order = sample_sequence(key)

🛠️ Requirements

  • Python >= 3.11
  • JAX ecosystem: jax, jaxlib, flax
  • Core dependencies: NumPy, joblib, jaxtyping
  • Protein handling: foldcomp, biotite
  • Testing: chex, pytest, pytest-cov

🏗️ Development

Code Quality & Standards

This project follows strict coding standards:

  • JAX-idiomatic code: Functional programming paradigm with immutable data structures
  • Linting: Ruff with strict configuration (line length: 100, all rules enabled)
  • Type checking: Pyright in strict mode
  • Testing: Comprehensive unit and integration tests with pytest
  • Documentation: Google-style docstrings with examples

Running Tests

python -m pytest tests/

Linting

ruff check src/ --fix

📖 Core Concepts

  • Functional Design: All operations follow JAX's functional programming paradigm
  • Immutable Data: Protein structures and model states are immutable
  • JAX Transformations: Compatible with jit, vmap, and scan
  • Modular Architecture: Clean separation of concerns across sampling, scoring, and utilities

🎯 Project Goals

PrxteinMPNN aims to provide:

  1. Transparency: A clear, understandable interface to ProteinMPNN's capabilities
  2. Performance: Leverage JAX for high-performance protein design workflows
  3. Modularity: Easy-to-extend components for custom protein design tasks
  4. Compatibility: Seamless integration with the broader JAX ecosystem

📄 License

MIT License

🤝 Contributing

Contributions are welcome! Please see the contributing guidelines currently under development for details.

📞 Support


Built with ❤️ using JAX for the protein design community

About

functional implementation of ProteinMPNN

Resources

License

Contributing

Stars

Watchers

Forks

Packages

No packages published

Contributors 5