Skip to content

Commit 6dccaed

Browse files
wooway777ma-hang
authored andcommitted
issue/573 - support kwargs to be tensors
1 parent 36081e5 commit 6dccaed

File tree

5 files changed

+161
-19
lines changed

5 files changed

+161
-19
lines changed

python/infinicore/ops/mul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ def mul(input, other, *, out=None):
77
return Tensor(_infinicore.mul(input._underlying, other._underlying))
88

99
_infinicore.mul_(out._underlying, input._underlying, other._underlying)
10+
11+
return out

test/infinicore/framework/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _create_tensor_from_spec(self, spec, device):
389389

390390
def prepare_inputs_and_kwargs(self, test_case, device):
391391
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
392-
Supports tuple inputs for operators like torch.cat
392+
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
393393
"""
394394
inputs = []
395395
kwargs = test_case.kwargs.copy()
@@ -443,6 +443,11 @@ def prepare_inputs_and_kwargs(self, test_case, device):
443443
f"Invalid input index for in-place operation: {input_idx}"
444444
)
445445

446+
for key, value in list(kwargs.items()):
447+
if isinstance(value, TensorSpec):
448+
# Replace TensorSpec with actual tensor
449+
kwargs[key] = self._create_tensor_from_spec(value, device)
450+
446451
return inputs, kwargs
447452

448453
def run_test(self, device, test_case, config):
@@ -488,6 +493,17 @@ def run_test(self, device, test_case, config):
488493
else:
489494
infini_inputs.append(inp)
490495

496+
infini_kwargs = {}
497+
for key, value in kwargs.items():
498+
if isinstance(value, torch.Tensor):
499+
# Clone tensor and convert to infinicore
500+
cloned_value = value.clone().detach()
501+
torch_input_clones.append(cloned_value)
502+
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
503+
else:
504+
# Pass through non-tensor values (scalars, strings, etc.)
505+
infini_kwargs[key] = value
506+
491507
# Determine comparison target
492508
comparison_target = test_case.comparison_target
493509

test/infinicore/framework/tensor.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -310,23 +310,16 @@ def from_scalar(cls, value, dtype=None):
310310

311311
@classmethod
312312
def from_strided_tensor(
313-
cls, shape, strides, dtype=None, init_mode=TensorInitializer.RANDOM, **kwargs
313+
cls,
314+
shape,
315+
strides,
316+
dtype=None,
317+
init_mode=TensorInitializer.RANDOM,
318+
**kwargs,
314319
):
315320
"""Alias for from_tensor with explicit strides (for backward compatibility)"""
316321
return cls.from_tensor(shape, strides, dtype, init_mode, **kwargs)
317322

318-
def with_dtype(self, dtype):
319-
"""Create a new TensorSpec with the specified dtype"""
320-
return TensorSpec(
321-
shape=self.shape,
322-
dtype=dtype,
323-
strides=self.strides,
324-
value=self.value,
325-
is_scalar=self.is_scalar,
326-
init_mode=self.init_mode,
327-
**self.kwargs,
328-
)
329-
330323
def create_torch_tensor(self, device):
331324
"""Create a torch tensor based on this specification"""
332325
if self.is_scalar:
@@ -335,7 +328,7 @@ def create_torch_tensor(self, device):
335328
# Create tensor using unified interface
336329
return TensorInitializer.create_tensor(
337330
shape=self.shape,
338-
dtype=self.dtype, # Use the dtype from the spec
331+
dtype=self.dtype,
339332
device=device,
340333
mode=self.init_mode,
341334
strides=self.strides,

test/infinicore/ops/mul.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def build_test_cases():
6464
output_spec=None,
6565
comparison_target=None,
6666
tolerance=tolerance,
67-
description=f"Mul - OUT_OF_PLACE (dtype={dtype})",
67+
description=f"Mul - OUT_OF_PLACE",
6868
)
6969
)
7070

@@ -77,7 +77,7 @@ def build_test_cases():
7777
output_spec=c_spec,
7878
comparison_target="out",
7979
tolerance=tolerance,
80-
description=f"Mul - INPLACE(out) (dtype={dtype})",
80+
description=f"Mul - INPLACE(out)",
8181
)
8282
)
8383

@@ -90,7 +90,7 @@ def build_test_cases():
9090
output_spec=None,
9191
comparison_target=0,
9292
tolerance=tolerance,
93-
description=f"Mul - INPLACE(a) (dtype={dtype})",
93+
description=f"Mul - INPLACE(a)",
9494
)
9595
)
9696

@@ -103,7 +103,7 @@ def build_test_cases():
103103
output_spec=None,
104104
comparison_target=1,
105105
tolerance=tolerance,
106-
description=f"Mul - INPLACE(b) (dtype={dtype})",
106+
description=f"Mul - INPLACE(b)",
107107
)
108108
)
109109

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import torch.nn.functional as F
8+
import infinicore
9+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
10+
from framework.runner import GenericTestRunner
11+
from framework.tensor import TensorInitializer
12+
13+
# ==============================================================================
14+
# Operator-specific configuration
15+
# ==============================================================================
16+
17+
# Test cases format: (input_shape, num_classes, has_weight, p, margin, reduction)
18+
_TEST_CASES_DATA = [
19+
# Basic cases without weight - 2D inputs only
20+
((10, 5), 5, False, 1, 1.0, "mean"),
21+
((10, 5), 5, False, 1, 1.0, "sum"),
22+
((10, 5), 5, False, 1, 1.0, "none"),
23+
((8, 3), 3, False, 2, 1.0, "mean"),
24+
((8, 3), 3, False, 2, 0.5, "sum"),
25+
# Cases with weight tensor
26+
((10, 5), 5, True, 1, 1.0, "mean"),
27+
((10, 5), 5, True, 1, 1.0, "sum"),
28+
((8, 3), 3, True, 2, 1.0, "mean"),
29+
((8, 3), 3, True, 2, 0.5, "sum"),
30+
# Edge cases - only 2D inputs
31+
((1, 3), 3, False, 1, 1.0, "mean"), # Single sample
32+
((5, 1), 1, False, 1, 1.0, "mean"), # Single class
33+
((100, 10), 10, True, 1, 2.0, "mean"), # Larger tensors
34+
]
35+
36+
# Tolerance configuration
37+
_TOLERANCE_MAP = {
38+
infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
39+
infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
40+
infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2},
41+
}
42+
43+
# Data types to test
44+
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
45+
46+
47+
def parse_test_cases():
48+
"""
49+
Parse test case data for multi_margin_loss operation.
50+
All tensors will be created on the same device.
51+
"""
52+
test_cases = []
53+
54+
for data in _TEST_CASES_DATA:
55+
input_shape = data[0]
56+
num_classes = data[1]
57+
has_weight = data[2]
58+
p_value = data[3]
59+
margin_value = data[4]
60+
reduction = data[5]
61+
62+
# Generate test cases for all data types
63+
for dtype in _TENSOR_DTYPES:
64+
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})
65+
66+
# Create input tensor spec
67+
input_spec = TensorSpec.from_tensor(input_shape, dtype=dtype)
68+
69+
# FIX: Create target as a tensor, not a scalar
70+
# For 2D input (batch_size, num_classes), target should be (batch_size,) tensor
71+
target_shape = (input_shape[0],)
72+
target_spec = TensorSpec.from_tensor(
73+
target_shape,
74+
dtype=infinicore.int64, # target must be int64 for classification
75+
init_mode=TensorInitializer.RANDINT,
76+
low=0,
77+
high=num_classes, # class indices from 0 to num_classes-1
78+
)
79+
80+
base_description = "MultiMarginLoss"
81+
82+
# Build kwargs
83+
kwargs = {"p": p_value, "margin": margin_value, "reduction": reduction}
84+
85+
# Add weight tensor if specified
86+
if has_weight:
87+
weight_spec = TensorSpec.from_tensor(
88+
(num_classes,), dtype=dtype, init_mode=TensorInitializer.RANDOM
89+
)
90+
kwargs["weight"] = weight_spec
91+
92+
test_cases.append(
93+
TestCase(
94+
inputs=[input_spec, target_spec],
95+
kwargs=kwargs,
96+
output_spec=None,
97+
comparison_target=None,
98+
tolerance=tolerance,
99+
description=base_description,
100+
)
101+
)
102+
103+
return test_cases
104+
105+
106+
class MultiMarginLossOpTest(BaseOperatorTest):
107+
"""MultiMarginLoss operator test with device handling"""
108+
109+
def __init__(self):
110+
super().__init__("MultiMarginLoss")
111+
112+
def get_test_cases(self):
113+
return parse_test_cases()
114+
115+
def torch_operator(self, *args, **kwargs):
116+
"""PyTorch multi_margin_loss implementation with device handling"""
117+
return F.multi_margin_loss(*args, **kwargs)
118+
119+
def infinicore_operator(self, *args, **kwargs):
120+
"""InfiniCore multi_margin_loss implementation"""
121+
return None
122+
123+
124+
def main():
125+
"""Main entry point"""
126+
runner = GenericTestRunner(MultiMarginLossOpTest)
127+
runner.run_and_exit()
128+
129+
130+
if __name__ == "__main__":
131+
main()

0 commit comments

Comments
 (0)