# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Shared test fixtures and mock infrastructure for LLM tests."""
from __future__ import annotations

import pytest
import torch


class MockTransformerConfig:
    """Mock config to mimic transformers model config."""

    def __init__(self, vocab_size: int, max_position_embeddings: int = 2048):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = vocab_size  # For simplicity


class MockTransformerOutput:
    """Mock output that mimics transformers model output with dict-like access."""

    def __init__(self, logits):
        self.logits = logits

    def __getitem__(self, key):
        """Allow dict-like access for compatibility."""
        if key == "logits":
            return self.logits
        raise KeyError(f"Key {key} not found in model output")


class MockTransformerModel(torch.nn.Module):
    """Mock transformer model that mimics the structure of HuggingFace models."""

    def __init__(self, vocab_size: int, device: torch.device | str | int = "cpu"):
        super().__init__()
        device = torch.device(device)
        self.config = MockTransformerConfig(vocab_size)
        # Simple embedding layer that maps tokens to logits
        self.embedding = torch.nn.Embedding(vocab_size, vocab_size, device=device)
        self.device = device

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Forward pass that returns logits in the expected format."""
        # Get embeddings (which we'll use as logits for simplicity)
        logits = self.embedding(input_ids % self.config.vocab_size)
        # Return output object similar to transformers models
        return MockTransformerOutput(logits)

    def get_tokenizer(self):
        """Return a real tokenizer (Qwen/Qwen2.5-0.5B) to pair with the mock model."""
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")


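# Illustrative sketch (not used by the fixtures below): the mock model maps each
# token id directly to a row of the embedding table, so a [batch, seq] input
# yields [batch, seq, vocab_size] logits that can be read by attribute or by
# key, e.g. (batch/sequence sizes are arbitrary):
#
#     model = MockTransformerModel(vocab_size=1024)
#     out = model(torch.randint(0, 1024, (2, 8)))
#     assert out.logits.shape == (2, 8, 1024)
#     assert torch.equal(out["logits"], out.logits)

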
@pytest.fixture
def mock_transformer_model():
    """Fixture that provides a mock transformer model factory."""

    def _make_model(
        vocab_size: int = 1024, device: torch.device | str | int = "cpu"
    ) -> MockTransformerModel:
        """Make a mock transformer model."""
        device = torch.device(device)
        return MockTransformerModel(vocab_size, device)

    return _make_model


@pytest.fixture
def mock_tokenizer():
    """Fixture that provides a real tokenizer (Qwen/Qwen2.5-0.5B) rather than a mock."""
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
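

# Example usage in a test module (a sketch; ``test_mock_forward`` is a
# hypothetical name, not a test defined in this file):
#
#     def test_mock_forward(mock_transformer_model, mock_tokenizer):
#         model = mock_transformer_model(vocab_size=1024)
#         ids = mock_tokenizer("hello world", return_tensors="pt")["input_ids"]
#         out = model(ids)
#         assert out.logits.shape[-1] == model.config.vocab_size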