From ba68a9f8cc1fca635fb59bed405aa8ca751a9810 Mon Sep 17 00:00:00 2001
From: Gururaj Kosuru
Date: Mon, 13 Oct 2025 19:36:55 -0700
Subject: [PATCH 1/2] [Testing] Add unit tests for Gemma3 model

This PR adds comprehensive unit tests for the Gemma3 model architecture,
which was added in #3172, #3178, #3224 but lacked test coverage.

Changes:
- Add tests/python/model/test_gemma3.py with the following test cases:
  * test_gemma3_model_registered: Verify model is in registry
  * test_gemma3_creation: Test model creation and TVM export
  * test_gemma3_config_validation: Validate configuration parameters
- Tests cover gemma3_2b and gemma3_9b variants
- Follows pattern from test_llama.py

Testing: All tests pass locally (instructions in PR description for reproduction)

Part of effort to add test coverage for recently added models.
---
 tests/python/model/test_gemma3.py | 69 +++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)
 create mode 100644 tests/python/model/test_gemma3.py

diff --git a/tests/python/model/test_gemma3.py b/tests/python/model/test_gemma3.py
new file mode 100644
index 0000000000..43a00d4dbb
--- /dev/null
+++ b/tests/python/model/test_gemma3.py
@@ -0,0 +1,69 @@
+# pylint: disable=invalid-name,missing-docstring
+"""Unit tests for Gemma3 model architecture."""
+
+import pytest
+
+from mlc_llm.model import MODEL_PRESETS, MODELS
+
+
+def test_gemma3_model_registered():
+    """Verify Gemma3 model is in the registry."""
+    assert "gemma3" in MODELS, "gemma3 should be registered in MODELS"
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "gemma3_2b",
+        "gemma3_9b",
+    ],
+)
+def test_gemma3_creation(model_name: str):
+    """Test Gemma3 model creation and export to TVM IR.
+
+    Verifies:
+    - Config can be loaded from preset
+    - Model instance can be created
+    - Model exports to TVM IR successfully
+    - Named parameters are extracted
+    """
+    model_info = MODELS["gemma3"]
+    config = model_info.config.from_dict(MODEL_PRESETS[model_name])
+    model = model_info.model(config)
+    mod, named_params = model.export_tvm(
+        spec=model.get_default_spec(),  # type: ignore
+    )
+
+    # Verify export succeeded
+    assert mod is not None
+    assert len(named_params) > 0
+
+    # Optional: show module structure
+    mod.show(black_format=False)
+
+    # Print parameters for debugging
+    for name, param in named_params:
+        print(name, param.shape, param.dtype)
+
+
+def test_gemma3_config_validation():
+    """Test Gemma3 configuration has required fields."""
+    model_info = MODELS["gemma3"]
+    config = model_info.config.from_dict(MODEL_PRESETS["gemma3_2b"])
+
+    # Check required config parameters
+    assert hasattr(config, "hidden_size") and config.hidden_size > 0
+    assert hasattr(config, "num_hidden_layers") and config.num_hidden_layers > 0
+    assert hasattr(config, "num_attention_heads") and config.num_attention_heads > 0
+    assert hasattr(config, "vocab_size") and config.vocab_size > 0
+
+    print(f"Gemma3 Config: hidden_size={config.hidden_size}, "
+          f"layers={config.num_hidden_layers}, "
+          f"heads={config.num_attention_heads}, "
+          f"vocab={config.vocab_size}")
+
+
+if __name__ == "__main__":
+    # Allow running tests directly
+    test_gemma3_creation("gemma3_2b")
+    test_gemma3_creation("gemma3_9b")

From ec32851cb63fbb2297b38d17bcd3d27be62a99c5 Mon Sep 17 00:00:00 2001
From: Gururaj Kosuru
Date: Mon, 13 Oct 2025 19:42:32 -0700
Subject: [PATCH 2/2] Fix Black formatting for print statement

---
 tests/python/model/test_gemma3.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/python/model/test_gemma3.py b/tests/python/model/test_gemma3.py
index 43a00d4dbb..37e7768092 100644
--- a/tests/python/model/test_gemma3.py
+++ b/tests/python/model/test_gemma3.py
@@ -57,10 +57,12 @@ def test_gemma3_config_validation():
     assert hasattr(config, "num_attention_heads") and config.num_attention_heads > 0
     assert hasattr(config, "vocab_size") and config.vocab_size > 0
 
-    print(f"Gemma3 Config: hidden_size={config.hidden_size}, "
-          f"layers={config.num_hidden_layers}, "
-          f"heads={config.num_attention_heads}, "
-          f"vocab={config.vocab_size}")
+    print(
+        f"Gemma3 Config: hidden_size={config.hidden_size}, "
+        f"layers={config.num_hidden_layers}, "
+        f"heads={config.num_attention_heads}, "
+        f"vocab={config.vocab_size}"
+    )
 
 
 if __name__ == "__main__":
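
For reference, a minimal sketch of how these tests might be run locally,
assuming pytest is installed in the same environment as mlc_llm; the PR
description remains the canonical set of reproduction steps, and this snippet
is not part of the patch itself:

    # Run only this test module via pytest and propagate its exit code;
    # equivalent to invoking "pytest tests/python/model/test_gemma3.py -v".
    import sys

    import pytest

    if __name__ == "__main__":
        sys.exit(pytest.main(["tests/python/model/test_gemma3.py", "-v"]))

Running the file directly with python also works, but the __main__ block only
exercises test_gemma3_creation for the two presets.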