From ee81615a356eec81e2eade15752760e9b8f96b40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EC=9D=B4=EC=9E=AC=EA=B7=A0?=
Date: Thu, 19 Feb 2026 15:58:07 +0900
Subject: [PATCH] [Test] Add DeepSeek V3 base test file and related fixes (WIP)

---
 Dockerfile.base                               |   3 +
 .../torch_openreg/openreg/__init__.py         |  17 +-
 .../mlir/mlir_codegen_backend.py              |  32 +--
 PyTorchSimFrontend/mlir/mlir_common.py        |  11 +-
 PyTorchSimFrontend/mlir/mlir_template.py      |   4 +
 tests/DeepSeek/test_deepseek_v3_base.py       | 220 ++++++++++++++++++
 6 files changed, 271 insertions(+), 16 deletions(-)
 create mode 100644 tests/DeepSeek/test_deepseek_v3_base.py

diff --git a/Dockerfile.base b/Dockerfile.base
index 0fd950d2..e8504bcf 100644
--- a/Dockerfile.base
+++ b/Dockerfile.base
@@ -45,6 +45,9 @@ RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2
 # Install torchsim dependency
 RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 && pip install "transformers<4.44" && pip install diffusers==0.34.0
 
+# FlashAttention
+RUN python -m pip install --no-build-isolation flash-attn
+
 # Extra Python deps for YOLO/vision tests
 RUN python -m pip install -U pip setuptools wheel && \
     python -m pip install --no-cache-dir --no-deps ultralytics && \
diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py
index 8d62cee3..f5aabc18 100644
--- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py
+++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py
@@ -80,8 +80,21 @@ def __init__(self, flags=0):
         self._stream = torch_openreg._C._stream_create()
 
     def __del__(self):
-        if hasattr(self, '_stream'):
-            torch_openreg._C._stream_destroy(self._stream)
+        # Interpreter shutdown can clear module globals before __del__ runs.
+        # Only destroy when both runtime handle and stream are still valid.
+        stream = getattr(self, "_stream", None)
+        backend = globals().get("torch_openreg", None)
+        c_api = getattr(backend, "_C", None) if backend is not None else None
+        if stream is None or c_api is None:
+            return
+        destroy = getattr(c_api, "_stream_destroy", None)
+        if destroy is None:
+            return
+        try:
+            destroy(stream)
+        except (AttributeError, TypeError):
+            # Ignore cleanup-time teardown ordering issues.
+            pass
 
     def launch_kernel(self, task):
         """Add a Python callable kernel to this stream.
diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
index a60c706e..e6c355f6 100644
--- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
+++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
@@ -110,6 +110,7 @@ def write_header(self):
         aten = torch.ops.aten
         inductor_ops = torch.ops.inductor
         assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+        assert_alignment = torch._C._dynamo.guards.assert_alignment
         alloc_from_pool = torch.ops.inductor._alloc_from_pool
         reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
         custom_async_compile = CustomAsyncCompile()
@@ -375,6 +376,10 @@ def _convert_sympy_to_mlir_expr(self, expr, sorted_args):
            indices.append(str(new_arg))

        expr_str = str(expr)
+        if "ModularIndexing" in expr_str:
+            def _replace_mod(m):
+                return f"({m.group(1)} floordiv {m.group(2)}) mod {m.group(3)}"
+            expr_str = re.sub(r"ModularIndexing\(([^,]+), ([^,]+), ([^)]+)\)", _replace_mod, expr_str)
        if "//" in expr_str:
            expr_str = expr_str.replace("//", " floordiv ")
        return expr_str, indices
@@ -1159,30 +1164,28 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
        for constraint in sorted_constraints[1:]:
            index = index.replace(constraint.original_expr, 0)

-        # Calculate dram stride
+        # Calculate dram stride in local tile-dim order.
+        # This keeps dram/sram stride rank aligned with tile rank.
+        local_dim_to_axis = {dim: axis for axis, dim in enumerate(local_dims)}
        dram_stride = [0] * local_tile_desc.get_nr_dim()
        if index.is_Symbol:
            dim_idx = int(str(index)[5:])
-            dram_stride[dim_idx] = 1
+            if dim_idx in local_dim_to_axis:
+                dram_stride[local_dim_to_axis[dim_idx]] = 1
        elif index.is_Number:
            pass
        else:
-            dram_dict = defaultdict(list)
+            dram_dict = defaultdict(lambda: 0)
            # Assume that div will have high priority than mod
            for arg in index.as_ordered_terms():
                coeff, dim = arg.as_coeff_mul()
                if len(dim) == 0:
                    continue
                real_dim = list(dim[0].free_symbols)[0]
-                dram_dict[str(real_dim)].append(coeff)
-            # Add missing dims if not added
-            max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1
-            for i in range(max_dim):
-                target_dim = f"index{i}"
-                if sympy.Symbol(target_dim) not in index.free_symbols:
-                    dram_dict[target_dim] = [0]
-            sorted_keys = sorted(dram_dict.keys())
-            dram_stride = sum((dram_dict[key] for key in sorted_keys), [])
+                real_dim_name = str(real_dim)
+                if real_dim_name.startswith("index"):
+                    dram_dict[int(real_dim_name[5:])] += int(coeff)
+            dram_stride = [dram_dict[dim] for dim in local_dims]

        # Support floordiv pattern
        # FIXME. How to integrate implicit dims and floordiv?
@@ -1194,6 +1197,9 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
            if not str(sub.args[0]).startswith("index"):
                continue
            dim_idx = int((str(sub.args[0])[5:]))
+            if dim_idx not in local_dim_to_axis:
+                continue
+            local_dim_idx = local_dim_to_axis[dim_idx]
            if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0:
                # In this case, need to recompile
                original_tile = self.kernel_group.tile_desc.get_tile_size()
@@ -1212,7 +1218,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe
                # Send recompile signal
                self.reset("recompile")
                raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}")
-            dim_divisor[dim_idx] = sub.args[1]
+            dim_divisor[local_dim_idx] = sub.args[1]

        # Update dram_stride, just insert 0 next to target dim
        offset = 0
diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py
index f101b7cb..7eb8f7f1 100644
--- a/PyTorchSimFrontend/mlir/mlir_common.py
+++ b/PyTorchSimFrontend/mlir/mlir_common.py
@@ -504,7 +504,7 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N
            vlane_stride=vlane_stride
        )

-        self.implicit_dim_size = None
+        self.implicit_dim_size = {}
        self.nr_rdim = 0
        self.offset = sympy.Integer(0)  # Dram offset

@@ -654,6 +654,11 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
    def indirect_indexing(self, index_var, size, check, wrap_neg):
        raise NotImplementedError()

+    def check_bounds(self, expr, size, lower, upper):
+        # MLIR backend currently relies on masked paths for out-of-bounds handling.
+        # Keep this hook as a no-op to satisfy Inductor's check_bounds callback.
+        return
+
    def codegen_global_init(self):
        raise NotImplementedError()

@@ -964,6 +969,10 @@ def store_reduction(name, index, value):
        def reduction(dtype, src_dtype, reduction_type, value):
            return self.reduction(dtype, src_dtype, reduction_type, value)

+        @staticmethod
+        def check_bounds(index, size, lower, upper):
+            return self.check_bounds(index, size, lower, upper)
+
        @staticmethod
        def _index_expr(tile_size, buffer, renamed_expression, index):
            return self._index_expr(tile_size, buffer, renamed_expression, index)
diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py
index b864e5f2..1159cf3b 100644
--- a/PyTorchSimFrontend/mlir/mlir_template.py
+++ b/PyTorchSimFrontend/mlir/mlir_template.py
@@ -861,6 +861,8 @@ def load_epilogue(self, name: str, index: sympy.Expr):
        vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
        tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype)
        tile_stride = self.kernel_group.tile_desc.get_tile_stride()
+        tile_rank = self.kernel_group.tile_desc.get_nr_dim()
+        dram_stride = dram_stride[:tile_rank] + [0] * max(tile_rank - len(dram_stride), 0)

        # Compute vector unit size
        vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype)
@@ -913,6 +915,8 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
        vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride
        tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype)
        tile_stride = self.kernel_group.tile_desc.get_tile_stride()
+        tile_rank = self.kernel_group.tile_desc.get_nr_dim()
+        dram_stride = dram_stride[:tile_rank] + [0] * max(tile_rank - len(dram_stride), 0)

        if name not in self.buffer_names:
            sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index)
diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py
new file mode 100644
index 00000000..b8402c8b
--- /dev/null
+++ b/tests/DeepSeek/test_deepseek_v3_base.py
@@ -0,0 +1,220 @@
+import os
+import sys
+import argparse
+import torch
+
+
+def _dtype_from_str(name: str) -> torch.dtype:
+    return {
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+    }.get(name, torch.float32)
+
+
+def _build_random_inputs(batch, seq_len, vocab_size, device):
+    g = torch.Generator().manual_seed(0)
+    input_ids = torch.randint(0, vocab_size, (batch, seq_len), generator=g, dtype=torch.int64)
+    return input_ids.to(device)
+
+
+def _safe_scaled_int(value, scale, min_value=1):
+    return max(min_value, int(round(float(value) * float(scale))))
+
+
+def _round_to_multiple(value, multiple, min_value=1):
+    if multiple is None or multiple <= 0:
+        return max(min_value, int(value))
+    v = max(min_value, int(value))
+    return max(min_value, ((v + multiple - 1) // multiple) * multiple)
+
+
+def _maybe_scale_config(config, scale=1.0, max_layers=None):
+    if scale == 1.0 and max_layers is None:
+        return config
+
+    if hasattr(config, "hidden_size"):
+        config.hidden_size = _safe_scaled_int(config.hidden_size, scale)
+    if hasattr(config, "intermediate_size"):
+        config.intermediate_size = _safe_scaled_int(config.intermediate_size, scale)
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = _safe_scaled_int(config.num_hidden_layers, scale)
+    if hasattr(config, "num_attention_heads"):
+        config.num_attention_heads = _safe_scaled_int(config.num_attention_heads, scale)
+    if hasattr(config, "num_key_value_heads"):
+        config.num_key_value_heads = min(
+            _safe_scaled_int(config.num_key_value_heads, scale),
+            config.num_attention_heads,
+        )
+
+    for name in [
+        "n_routed_experts",
+        "n_shared_experts",
+        "num_local_experts",
+        "num_experts",
+        "num_experts_per_tok",
+        "moe_intermediate_size",
+        "shared_expert_intermediate_size",
+    ]:
+        if hasattr(config, name):
+            setattr(config, name, _safe_scaled_int(getattr(config, name), scale))
+
+    # DeepSeek MoE gate expects n_routed_experts to be divisible by n_group.
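+    # Illustrative arithmetic (assuming the stock DeepSeek-V3 config, which has
+    # 256 routed experts and n_group=8): with scale=0.03, round(256 * 0.03) = 8,
+    # already a multiple of 8; a value such as 7 would be rounded up to 8 below.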
+    if hasattr(config, "n_routed_experts") and hasattr(config, "n_group"):
+        config.n_routed_experts = _round_to_multiple(
+            config.n_routed_experts,
+            config.n_group,
+            min_value=max(1, int(config.n_group)),
+        )
+
+    if max_layers is not None and hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = max(1, min(int(max_layers), int(config.num_hidden_layers)))
+
+    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+        config.hidden_size = max(
+            config.num_attention_heads,
+            (config.hidden_size // config.num_attention_heads) * config.num_attention_heads,
+        )
+
+    return config
+
+
+def _apply_preset(scale, max_layers, batch, seq_len, preset):
+    if preset == "tiny":
+        return 0.03, 4, 1, min(seq_len, 16)
+    if preset == "small":
+        return 0.07, 8, 1, min(seq_len, 32)
+    if preset == "medium":
+        return 0.10, 12, 1, min(seq_len, 48)
+    return scale, max_layers, batch, seq_len
+
+
+@torch.no_grad()
+def run_deep_seek_v3_base_test(
+    model_id,
+    device,
+    init_mode="config-random",
+    scale=1.0,
+    max_layers=None,
+    dtype="float16",
+    batch=1,
+    seq_len=32,
+    use_tokenizer=False,
+    prompt="Hello, DeepSeek V3",
+    trust_remote_code=False,
+    revision=None,
+    compile_model=False,
+):
+    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+    torch_dtype = _dtype_from_str(dtype)
+
+    # Load model config
+    config = AutoConfig.from_pretrained(
+        model_id,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+    )
+
+    # Some remote model code expects quantization_config to stay object-like
+    # (it calls .to_dict()), so only disable it for the pretrained loading path.
+    if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None:
+        config.quantization_config = None
+
+    config = _maybe_scale_config(config, scale=scale, max_layers=max_layers)
+
+    if init_mode == "config-random":
+        model = AutoModelForCausalLM.from_config(
+            config=config,
+            trust_remote_code=trust_remote_code,
+        ).eval()
+        model = model.to(dtype=torch_dtype)
+    elif init_mode == "pretrained":
+        # Load model (weights)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            config=config,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+        ).eval()
+    else:
+        raise ValueError(f"Unsupported init mode: {init_mode}")
+
+    model = model.to(device)
+    model_params = sum(p.numel() for p in model.parameters())
+    print("init mode:", init_mode)
+    print("scaled hidden_size:", getattr(config, "hidden_size", "n/a"))
+    print("scaled num_hidden_layers:", getattr(config, "num_hidden_layers", "n/a"))
+    print("scaled num_attention_heads:", getattr(config, "num_attention_heads", "n/a"))
+    print("model params:", model_params)
+
+    # Load tokenizer
+    if use_tokenizer:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+        )
+        encoded = tokenizer(prompt, return_tensors="pt")
+        input_ids = encoded["input_ids"].to(device)
+    else:
+        vocab_size = getattr(config, "vocab_size", None)
+        if vocab_size is None:
+            raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.")
+        input_ids = _build_random_inputs(batch, seq_len, vocab_size, device)
+
+    if compile_model:
+        model = torch.compile(model, dynamic=False)
+
+    out = model(input_ids)
+    logits = out.logits
+
+    print("logits shape:", tuple(logits.shape))
+    print("logits dtype:", logits.dtype)
+    print("logits max:", logits.max().item())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="DeepSeek V3 base model test (config-random or pretrained)")
+    parser.add_argument("--model-id", type=str, default=os.environ.get("DEEPSEEK_V3_MODEL_ID", "deepseek-ai/DeepSeek-V3-Base"))
+    parser.add_argument("--revision", type=str, default=None)
+    # BooleanOptionalAction keeps the True default but also provides --no-* switches;
+    # a plain store_true flag with default=True could never be turned off.
+    parser.add_argument("--trust-remote-code", action=argparse.BooleanOptionalAction, default=True)
+    parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"])
+    parser.add_argument("--preset", type=str, default="tiny", choices=["none", "tiny", "small", "medium"])
+    parser.add_argument("--scale", type=float, default=1.0)
+    parser.add_argument("--max-layers", type=int, default=None)
+    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])
+    parser.add_argument("--batch", type=int, default=1)
+    parser.add_argument("--seq-len", type=int, default=32)
+    parser.add_argument("--use-tokenizer", action="store_true")
+    parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3")
+    parser.add_argument("--compile", action=argparse.BooleanOptionalAction, default=True)
+
+    args = parser.parse_args()
+
+    if not args.model_id:
+        print("Error: --model-id is required (or set DEEPSEEK_V3_MODEL_ID).", file=sys.stderr)
+        sys.exit(2)
+
+    args.scale, args.max_layers, args.batch, args.seq_len = _apply_preset(
+        args.scale, args.max_layers, args.batch, args.seq_len, args.preset
+    )
+
+    device = torch.device("npu:0")
+
+    run_deep_seek_v3_base_test(
+        model_id=args.model_id,
+        device=device,
+        init_mode=args.init_mode,
+        scale=args.scale,
+        max_layers=args.max_layers,
+        dtype=args.dtype,
+        batch=args.batch,
+        seq_len=args.seq_len,
+        use_tokenizer=args.use_tokenizer,
+        prompt=args.prompt,
+        trust_remote_code=args.trust_remote_code,
+        revision=args.revision,
+        compile_model=args.compile,
+    )
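+
+# Example invocations (illustrative; every flag used here is defined in the
+# argument parser above):
+#   python tests/DeepSeek/test_deepseek_v3_base.py --preset tiny --dtype float16
+#   python tests/DeepSeek/test_deepseek_v3_base.py --preset none --scale 0.05 --max-layers 2 --use-tokenizer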