diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py
index 0ad81385e..ea4b8dd56 100755
--- a/diopi_test/python/configs/diopi_configs.py
+++ b/diopi_test/python/configs/diopi_configs.py
@@ -3,6 +3,197 @@
 diopi_configs = {
+    'has_inf': dict(
+        name=["isinf"],
+        interface=["torch"],
+        atol=1e-3,
+        rtol=1e-4,
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ['input'],
+                    "shape": ((), (1024,), (2, 4096), (64, 28, 28), (32, 64, 112, 112), (64, 3, 7, 28, 28), (0,), (256, 0), (8, 0, 128)),
+                    "dtype": [np.float16, np.float32, np.float64, np.int16, np.int32, np.int64, np.uint8, np.int8],
+                },
+            ],
+        ),
+    ),
+
+    'trunc': dict(
+        name=["trunc"],
+        interface=["torch"],
+        atol=1e-3,
+        rtol=1e-4,
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ['input'],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                    "dtype": [np.float32, np.float16, np.float64],
+                },
+            ],
+        ),
+    ),
+
+    'round': dict(
+        name=["round"],
+        interface=["torch"],
+        atol=1e-3,
+        rtol=1e-4,
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ['input'],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                    "dtype": [np.float32, np.float16, np.float64],
+                },
+            ],
+        ),
+    ),
+
+    'hardsigmoid': dict(
+        name=["hardsigmoid"],
+        atol=1e-3,
+        rtol=1e-4,
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ['input'],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                    "dtype": [np.float32, np.float16, np.float64],
+                },
+            ],
+        ),
+    ),
+
+    'elu': dict(
+        name=["elu"],
+        atol=1e-3,
+        rtol=1e-4,
+        para=dict(
+            alpha=[0.234, 4.8, -10, 1.0],
+        ),
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                    "dtype": [np.float32, np.float16, np.float64],
+                },
+            ],
+        ),
+    ),
+
+    'prelu': dict(
+        name=["prelu"],
+        atol=1e-3,
+        rtol=1e-4,
+        dtype=[np.float32, np.float16, np.float64],
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                },
+                {
+                    "ins": ["weight"],
+                    "shape": ((16,), (64,), (96,), (1,)),
+                },
+            ],
+        ),
+    ),
+
+    'selu': dict(
+        name=["selu"],
+        dtype=[np.float32, np.float16, np.float64],
+        atol=1e-3,
+        rtol=1e-4,
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                },
+            ],
+        ),
+    ),
+
+    'softplus': dict(
+        name=["softplus"],
+        atol=1e-3,
+        rtol=1e-4,
+        para=dict(
+            beta=[0.234, 4.8, -10, 1.0],
+            threshold=[0.234, 4.8, -10, 1.0]
+        ),
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                    "dtype": [np.float32, np.float16, np.float64],
+                },
+            ],
+        ),
+    ),
+
+    'softsign': dict(
+        name=["softsign"],
+        atol=1e-3,
+        rtol=1e-4,
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
+                    "dtype": [np.float32, np.float16, np.float64],
+                },
+            ],
+        ),
+    ),
+
+    "batch_norm_GB": dict(
+        name=["batch_norm_GB"],
+        interface=['CustomizedTest'],
+        dtype=[np.float32, np.float16, np.float64],
+        atol=1e-2,
+        rtol=1e-2,
+        atol_half=1e-1,
+        rtol_half=1e-2,
+        para=dict(
+            training=[True, True, True],
+            momentum=[0.01, 0.01, 0.01],
+            axis=[0, 1, 2],
+            eps=[1e-4, 1e-4, 1e-4],
+        ),
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "requires_grad": [True],
+                    "shape": ((2, 64, 32, 32), (2, 64, 32, 32), (2, 64, 32, 32)),
+                    "gen_fn": "Genfunc.randn",
+                },
+                {
+                    "ins": ["running_mean"],
+                    "shape": ((2,), (64,), (32,)),
+                    "gen_fn": "Genfunc.zeros",
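+                    # stats length matches the size of the normalized axis of the
+                    # (2, 64, 32, 32) inputs above: axis=[0, 1, 2] -> 2, 64, 32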
+                },
+                {
+                    "ins": ["running_var"],
+                    "shape": ((2,), (64,), (32,)),
+                    "gen_fn": "Genfunc.ones",
+                },
+                {
+                    "ins": ["weight", "bias"],
+                    "requires_grad": [True],
+                    "shape": ((2,), (64,), (32,)),
+                    "gen_fn": "Genfunc.randn",
+                },
+            ]
+        ),
+    ),
+
     # FIXME batch_norm errors out on 0-size input tensors
     'batch_norm': dict(
         name=["batch_norm"],
@@ -507,17 +698,34 @@
             args=[
                 {
                     "ins": ['input'],
+                    "requires_grad": [True],
                     "shape": ((), (1024,), (2, 4096), (64, 28, 28), (32, 64, 112, 112), (64, 3, 7, 28, 28), (0,), (256, 0), (8, 0, 128)),
-                    "dtype": [np.float16, np.float32, np.float64,
-                              np.int16, np.int32, np.int64,
-                              np.uint8, np.int8],
+                    "dtype": [np.float16, np.float32, np.float64],
                     "gen_fn": 'Genfunc.randn',
                 },
             ],
         ),
     ),
+
+    'erf': dict(
+        name=['erf'],
+        interface=['torch'],
+        dtype=[np.float16, np.float32, np.float64],
+        tensor_para=dict(
+            gen_fn='Genfunc.randn',
+            args=[
+                {
+                    "ins": ['input'],
+                    "requires_grad": [True],
+                    "shape": ((), (1,), (1024,), (364800, 4), (2, 128, 3072),
+                              (256, 128, 3, 3),
+                              (2, 31, 512, 6, 40), (0,), (16, 0)),
+                },
+            ],
+        ),
+    ),
 
     'relu_no_contiguous': dict(
         name=["relu"],
@@ -4902,6 +5110,8 @@
         name=["dropout"],
         no_output_ref=True,
         is_inplace=True,
+        atol=1e-3,
+        rtol=1e-3,
         para=dict(
             p=[0.5, 0, 0.1, 0.4],
             training=[True, True, True, False]
@@ -4911,7 +5121,7 @@
                 {
                     "ins": ['input'],
                     "shape": ((2, 4096), (32, 49, 256), (2, 16, 64, 64), (1, 2304, 1, 1, 1)),
-                    "dtype": [np.float16, np.float32, np.float64],
+                    "dtype": [np.float32, np.float64],
                     "gen_fn": 'Genfunc.positive',
                 },
             ],
@@ -6996,6 +7206,37 @@
             ]
         ),
     ),
+
+    'group_norm_GB': dict(
+        name=['group_norm_GB'],
+        interface=['CustomizedTest'],
+        atol=1e-4,
+        rtol=1e-5,
+        para=dict(
+            num_groups=[32, 4, 5, 1],
+            eps=[1e-05, 1e-05, 1e-05, 1e-05],
+            reduced_axes=[[2, 3], [1, 3], [0, 3], [2, 3]],
+            channel_axis=[1, 2, 1, 0]
+        ),
+        tensor_para=dict(
+            args=[
+                {
+                    "ins": ["input"],
+                    "requires_grad": [True],
+                    "shape": ((2, 256, 7, 10), (2, 256, 12, 12),
+                              (12, 15, 8, 9), (3, 6, 9, 0)),
+                    "dtype": [np.float32, np.float64, np.float16],
+                },
+                {
+                    "ins": ["weight", "bias"],
+                    "requires_grad": [True],
+                    "shape": ((256,), (12,),
+                              (15,), (3,)),
+                    "dtype": [np.float32, np.float64, np.float16],
+                },
+            ]
+        ),
+    ),
 
     'unique': dict(
         name=['unique'],
diff --git a/diopi_test/python/conformance/customized_test.py b/diopi_test/python/conformance/customized_test.py
index 3f351e27c..c95a3996d 100644
--- a/diopi_test/python/conformance/customized_test.py
+++ b/diopi_test/python/conformance/customized_test.py
@@ -891,3 +891,59 @@ def pool3d(input, kernel_size, stride, padding, dilation, ceil_mode, count_inclu
     def layer_normGB(input, weight, bias, eps, normalized_shape):
         return torch.nn.functional.layer_norm(input=input, weight=weight, bias=bias, eps=eps, normalized_shape=normalized_shape)
+
+    def batch_norm_GB(input, running_mean, running_var, weight, bias, training=False, momentum=0.1, eps=1e-05, axis=1):
+        dim = input.dim()
+        dims = list(range(dim))
+        dims.remove(axis)
+        dims.insert(1, axis)
+        permuted_input = input.permute(dims)
+        out = torch.nn.functional.batch_norm(
+            permuted_input,
+            running_mean,
+            running_var,
+            weight=weight,
+            bias=bias,
+            training=training,
+            momentum=momentum,
+            eps=eps,
+        )
+        out = out.permute(dims)
+        return out
+
+    def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_axes=[2, 3], channel_axis=1):
+
+        input_dims = list(input.size())
+        reduced_axes_set = set(reduced_axes)
+        dims = []
+        non_reduced_dims = []
+
+        for i, size in enumerate(input_dims):
+            if i == channel_axis:
+                continue
+            elif i in reduced_axes_set:
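+                # reduced axes are normalized over, so they are skipped when counting batch dims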
+                continue
+            else:
+                non_reduced_dims.append(i)
+        N = 1
+        for i in non_reduced_dims:
+            N = N * input.size(i)
+        HxW = 1
+        for i in reduced_axes:
+            HxW = HxW * input.size(i)
+        C = input.size(channel_axis)
+        dims = non_reduced_dims + [channel_axis] + reduced_axes
+        permuted_input = input.permute(dims)
+        reshaped_input = permuted_input.reshape([N, C, HxW, 1]).contiguous()
+        out = torch.nn.functional.group_norm(
+            reshaped_input,
+            num_groups,
+            weight=weight,
+            bias=bias,
+            eps=eps
+        )
+
+        reversed_order = [0] * len(dims)
+        for i in range(1, len(dims)):
+            reversed_order[dims[i]] = i
+        return out.reshape(permuted_input.shape).permute(reversed_order)
+
\ No newline at end of file
diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py
index 3b35e6fc9..e9052394c 100644
--- a/diopi_test/python/conformance/diopi_functions.py
+++ b/diopi_test/python/conformance/diopi_functions.py
@@ -224,6 +224,73 @@ def promote_type(input: Tensor, promoted_dtype: Dtype) -> Dtype:
     ]
     return dtype1 if dtype1 not in need_promote_types else promoted_dtype
 
 
+def isinf(input) -> Tensor:
+    func = check_function("diopiHasInf")
+    out = Tensor(size=input.size(), dtype=Dtype.bool)
+    ret = func(input.context(), out, input)
+    check_returncode(ret)
+    return out
+
+def trunc(input) -> Tensor:
+    func = check_function("diopiTrunc")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input)
+    check_returncode(ret)
+    return out
+
+def round(input) -> Tensor:
+    func = check_function("diopiRound")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input)
+    check_returncode(ret)
+    return out
+
+def hardsigmoid(input) -> Tensor:
+    func = check_function("diopiHardSigmoid")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input)
+    check_returncode(ret)
+    return out
+
+def elu(input, alpha) -> Tensor:
+    func = check_function("diopiElu")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    value = Scalar(alpha)
+    ret = func(input.context(), out, input, value)
+    check_returncode(ret)
+    return out
+
+
+def prelu(input, weight) -> Tensor:
+    func = check_function("diopiPrelu")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input, weight)
+    check_returncode(ret)
+    return out
+
+
+def selu(input):
+    func = check_function("diopiSelu")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input)
+    check_returncode(ret)
+    return out
+
+def softplus(input, beta, threshold):
+    func = check_function("diopiSoftplus")
+    beta = Scalar(beta)
+    threshold = Scalar(threshold)
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input, beta, threshold)
+    check_returncode(ret)
+    return out
+
+def softsign(input):
+    func = check_function("diopiSoftsign")
+    out = Tensor(size=input.size(), dtype=input.get_dtype())
+    ret = func(input.context(), out, input)
+    check_returncode(ret)
+    return out
 
 
 def fill_(input, value):
     func = check_function("diopiFill")
@@ -356,6 +423,15 @@ def relu(input, inplace=False) -> Tensor:
     return unary_op(input, inplace, "diopiRelu")
 
 
+def relu_backward(input, grad_outputs, **kwargs) -> Tensor:
+    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
+    grad_input = raw_like(input)
+    func = check_function("diopiReluBackward")
+    ret = func(input.context(), grad_input, grad_outputs[0], input)
+    check_returncode(ret)
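+    # report the gradient only when the test marks the input as requiring grad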
+    return {"input": grad_input} if grad_input.requires_grad else {}
+
+
 def abs(input, inplace=False) -> Tensor:
     return unary_op(input, inplace, "diopiAbs")
 
@@ -464,6 +540,15 @@ def log1p(input, inplace=False) -> Tensor:
     return unary_op(input, inplace, "diopiLog1p", promote_type(input, Dtype.float32))
 
 
+def erf_backward(input, grad_outputs, **kwargs) -> Tensor:
+    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
+    grad_input = raw_like(input)
+    func = check_function("diopiErfBackward")
+    ret = func(input.context(), grad_input, grad_outputs[0], input)
+    check_returncode(ret)
+    return {"input": grad_input} if grad_input.requires_grad else {}
+
+
 def erf(input, inplace=False) -> Tensor:
     return unary_op(input, inplace, "diopiErf", promote_type(input, Dtype.float32))
 
@@ -2738,6 +2823,102 @@
     return out
 
+
+def batch_norm_GB(
+    input,
+    running_mean,
+    running_var,
+    weight,
+    bias,
+    training=False,
+    momentum=0.1,
+    eps=1e-05,
+    axis=1
+) -> Tensor:
+    dim = input.size().len
+    dim = [i for i in range(dim) if i != axis]
+    dtype = Dtype.float32 if input.get_dtype() == Dtype.float16 else None
+    _, save_mean = reduce_op_process(input, dim, dtype=dtype)
+    save_invstd = raw_like(save_mean)
+
+    if not training:
+        assert (
+            running_mean is not None and running_var is not None
+        ), "if not training, running_mean and running_var must be defined"
+
+    out = raw_like(input)
+    func = check_function("diopiBatchNormGB")
+    ret = func(
+        input.context(),
+        out,
+        save_mean,
+        save_invstd,
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        training,
+        momentum,
+        eps,
+        axis
+    )
+
+    check_returncode(ret)
+    GLOBAL_STATE["batch_norm_GB_save_mean"] = save_mean
+    GLOBAL_STATE["batch_norm_GB_save_invstd"] = save_invstd
+    return out
+
+
+def batch_norm_GB_backward(
+    input,
+    grad_outputs,
+    running_mean,
+    running_var,
+    weight,
+    bias,
+    training=False,
+    eps=1e-05,
+    axis=1,
+    **kwargs,
+) -> Tensor:
+    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
+    save_mean = GLOBAL_STATE.pop("batch_norm_GB_save_mean")
+    save_invstd = GLOBAL_STATE.pop("batch_norm_GB_save_invstd")
+
+    grad_input = raw_like(input)
+    grad_weight = raw_like(weight)
+    grad_bias = raw_like(bias)
+
+    if not training:
+        assert (
+            running_mean is not None and running_var is not None
+        ), "if not training, running_mean and running_var must be defined"
+    # running_mean = running_mean if running_mean is None else running_mean
+    # running_var = running_var if running_var is None else running_var
+    keys = ["input", "weight", "bias"]
+    grads = [grad_input, grad_weight, grad_bias]
+    out = {k: v for k, v in zip(keys, grads) if v.requires_grad}
+
+    func = check_function("diopiBatchNormGBBackward")
+    grad_output = grad_outputs[0]
+    ret = func(
+        input.context(),
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_output,
+        input,
+        weight,
+        running_mean,
+        running_var,
+        save_mean,
+        save_invstd,
+        training,
+        eps,
+        axis
+    )
+    check_returncode(ret)
+    return out
+
 def batch_norm_stats(input, eps):
     func = check_function("diopiBatchNormStats")
     # cuda accumulate dtype mapping
@@ -5111,6 +5292,80 @@ def norm_backward(grad_outputs, input, p, dim, keepdim=False, dtype=None):
     return {k: v for k, v in out.items() if v.requires_grad}
 
+
+def group_norm_GB(input, num_groups, weight=None, bias=None, eps=1e-05, reduced_axes=[2, 3], channel_axis=1):
+    dim = list(input.size().data)
+    N = 1
+    for i in range(len(dim)):
+        if i not in reduced_axes and i != channel_axis:
+            N = N * dim[i]
+    save_mean = Tensor((N, num_groups), input.get_dtype())
+    save_invstd = raw_like(save_mean)
+
+    weight = None if weight is None else weight
+    bias = None if bias is None else bias
+
+    reduced_axes = Sizes(reduced_axes)
+    out = raw_like(input)
+    func = check_function("diopiGroupNormGB")
+    ret = func(
+        input.context(),
+        out,
+        save_mean,
+        save_invstd,
+        input,
+        weight,
+        bias,
+        num_groups,
+        eps,
+        reduced_axes,
+        channel_axis
+    )
+    check_returncode(ret)
+    GLOBAL_STATE["group_norm_GB_save_mean"] = save_mean
+    GLOBAL_STATE["group_norm_GB_save_invstd"] = save_invstd
+    return out
+
+
+def group_norm_GB_backward(
+    input,
+    grad_outputs,
+    num_groups,
+    weight=None,
+    bias=None,
+    eps=1e-05,
+    reduced_axes=[2, 3],
+    channel_axis=1,
+    **kwargs,
+) -> Tensor:
+    assert len(grad_outputs) == 1, "only accept 1 gradient to do backward"
+    save_mean = GLOBAL_STATE.pop("group_norm_GB_save_mean")
+    save_invstd = GLOBAL_STATE.pop("group_norm_GB_save_invstd")
+    grad_input = raw_like(input)
+    grad_weight = raw_like(weight)
+    grad_bias = raw_like(bias)
+    weight = None if weight is None else weight
+    bias = None if bias is None else bias
+
+    out = {"input": grad_input, "weight": grad_weight, "bias": grad_bias}
+    func = check_function("diopiGroupNormGBBackward")
+    reduced_axes = Sizes(reduced_axes)
+    ret = func(
+        input.context(),
+        grad_input,
+        grad_weight,
+        grad_bias,
+        grad_outputs[0],
+        input,
+        weight,
+        save_mean,
+        save_invstd,
+        num_groups,
+        reduced_axes,
+        channel_axis,
+    )
+    check_returncode(ret)
+    return {k: v for k, v in out.items() if v.requires_grad}
+
 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
     dim = list(input.size().data)
diff --git a/diopi_test/python/conformance/diopi_manual_test.py b/diopi_test/python/conformance/diopi_manual_test.py
index 7cdf82ee7..ebca10766 100644
--- a/diopi_test/python/conformance/diopi_manual_test.py
+++ b/diopi_test/python/conformance/diopi_manual_test.py
@@ -1,12 +1,33 @@
 # Copyright (c) 2023, DeepLink.
 # -*- coding: UTF-8 -*-
 import numpy as np
+import diopilib
 from diopilib import build_generator_state
 from .diopi_runtime import Tensor, Generator, default_context
 from . import diopi_functions as F
 
 
 class ManualTest(object):
+
+    def test_dropout_backward(input, p, atol, rtol):
+        import torch
+        import pytest
+        grad_in = Tensor(input.size().data, input.get_dtype())
+        torch_input = torch.from_numpy(input.numpy()).requires_grad_(False)
+        torch_input[torch_input == 0] = 0.5
+        torch_input = torch_input.requires_grad_()
+        torch_ones = torch.ones_like(torch_input)
+        grad_outputs = Tensor.from_numpy(torch_ones.numpy())
+        out = torch.nn.functional.dropout(torch_input, p=p, training=True)
+        out.backward(torch_ones)
+        mask = Tensor.from_numpy(out.ne(0).numpy())
+        if hasattr(diopilib, "diopiDropoutBackward"):
+            diopilib.diopiDropoutBackward(input.context(), grad_in, grad_outputs, mask, p)
+            assert np.allclose(grad_in.numpy(), torch_input.grad.numpy(), rtol=rtol, atol=atol)
+        else:
+            pytest.xfail("diopiDropoutBackward not supported")
+
+
     def test_dropout_(func, input, p=0.5, training=True, inplace=False):
         input_numpy = input.numpy()
         state = build_generator_state(input.context())
@@ -16,8 +37,8 @@ def test_dropout_(func, input, p=0.5, training=True, inplace=False):
         out_numpy = out.numpy()
         mask_numpy = mask.numpy()
 
-        rtol = 1e-2 if input_numpy.dtype == np.float16 else 1e-4
-        atol = 5e-2 if input_numpy.dtype == np.float16 else 1e-5
+        rtol = 1e-2 if input_numpy.dtype == np.float16 else 1e-3
+        atol = 5e-2 if input_numpy.dtype == np.float16 else 1e-3
 
         if training and input.numel() > 0:
             # compute ratio
@@ -30,6 +51,10 @@ def test_dropout_(func, input, p=0.5, training=True, inplace=False):
             ref = input_numpy[mask_numpy == 1]
             assert np.allclose(remains, ref / (1 - p), rtol=rtol, atol=atol), \
                 f"failed to execute {name}, dropout value doesn't matches."
+
+            if name == 'dropout':
+                ManualTest.test_dropout_backward(input, p, atol, rtol)
+
             if mask.numel() > 100:
                 # 0.05 is from pytorch
                 assert np.abs(real_ratio - (1 - p)) < 0.05, \
@@ -43,7 +68,7 @@
     def test_dropout(input, p=0.5, training=True, inplace=False):
         ManualTest.test_dropout_(F.dropout, input, p, training, inplace)
 
     def test_dropout2d(input, p=0.5, training=True, inplace=False):
         ManualTest.test_dropout_(F.dropout2d, input, p, training, inplace)
-    
+
     def test_randperm(n):
         state = build_generator_state(default_context)
         generator = Generator(state)
diff --git a/diopi_test/python/conformance/global_op_list.py b/diopi_test/python/conformance/global_op_list.py
index aab78faa8..af6185fa0 100644
--- a/diopi_test/python/conformance/global_op_list.py
+++ b/diopi_test/python/conformance/global_op_list.py
@@ -11,6 +11,7 @@
     "conv2d": ["2d", "input", "weight"],
     "conv3d": ["3d", "input", "weight"],
     "batch_norm": ["input"],
+    "batch_norm_GB": ["input", "running_mean", "running_var"],
     "adaptive_avg_pool2d": ["2d", "input"],
     "adaptive_max_pool2d": ["2d", "input"],
    "adaptive_avg_pool3d": ["3d", "input"],
@@ -64,6 +65,7 @@
 ops_with_states = {
     "batch_norm": {"running_mean", "running_var"},
+    "batch_norm_GB": {"running_mean", "running_var"},
     "sgd": {"buf", "param"},
     "fill_": {"input"},
     "zero_": {"input"},
diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp
index 6ee8e104e..3bd99f27d 100644
--- a/impl/torch/functions/functions.cpp
+++ b/impl/torch/functions/functions.cpp
@@ -65,6 +65,89 @@ const char* diopiGetImplVersion() {
     return version;
 }
 
+diopiError_t diopiHasInf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = impl::aten::buildATen(out);
+    CALL_ATEN_FUNC(isinf_out, atOut, atInput);
+
+    return diopiSuccess;
+}
+
+diopiError_t diopiTrunc(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = impl::aten::buildATen(out);
+    CALL_ATEN_FUNC(trunc_out, atOut, atInput);
+    return diopiSuccess;
+}
+
+diopiError_t diopiRound(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = impl::aten::buildATen(out);
+    CALL_ATEN_FUNC(round_out, atOut, atInput);
+
+    return diopiSuccess;
+}
+
+diopiError_t diopiHardSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = impl::aten::buildATen(out);
+    CALL_ATEN_FUNC(hardsigmoid_out, atOut, atInput);
+
+    return diopiSuccess;
+}
+
+diopiError_t diopiElu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* alpha) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = impl::aten::buildATen(out);
+    auto atAlpha = impl::aten::buildAtScalar(alpha);
+    CALL_ATEN_FUNC(elu_out, atOut, atInput, atAlpha);
+    return diopiSuccess;
+}
+
+diopiError_t diopiPrelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiTensorHandle_t weight) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atOut = CALL_ATEN_FUNC(prelu, atInput, atWeight);
+    impl::aten::updateATen2Tensor(ctx, atOut, out);
+    return diopiSuccess;
+}
+
+diopiError_t diopiSelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = CALL_ATEN_FUNC(selu, atInput);
+    impl::aten::updateATen2Tensor(ctx, atOut, out);
+    return diopiSuccess;
+}
+
+diopiError_t diopiSoftplus(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* beta,
+                           const diopiScalar_t* threshold) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atOut = impl::aten::buildATen(out);
+    auto atBeta = impl::aten::buildAtScalar(beta);
+    auto atThreshold = impl::aten::buildAtScalar(threshold);
+    CALL_ATEN_FUNC(softplus_out, atOut, atInput, atBeta, atThreshold);
+    return diopiSuccess;
+}
+
+diopiError_t diopiSoftsign(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atAbsInput = CALL_ATEN_FUNC(abs, atInput);
+    auto atDenominator = CALL_ATEN_FUNC(add, atAbsInput, 1.0);
+    auto atOut = CALL_ATEN_FUNC(div, atInput, atDenominator);
+    impl::aten::updateATen2Tensor(ctx, atOut, out);
+    return diopiSuccess;
+}
+
 diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
     impl::aten::setCurStream(ctx);
     auto atOut = impl::aten::buildATen(out);
@@ -75,6 +158,18 @@ diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiC
     return diopiSuccess;
 }
 
+diopiError_t diopiReluBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+
+    auto atGradOut = impl::aten::buildATen(grad_out);
+    auto atInput = impl::aten::buildATen(input);
+    auto atGradIn = impl::aten::buildATen(grad_in);
+    auto mask = (atInput > 0).to(atGradOut.dtype());
+    atGradIn.copy_(atGradOut * mask);
+
+    return diopiSuccess;
+}
+
 diopiError_t diopiReluInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
     impl::aten::setCurStream(ctx);
     auto atInput = impl::aten::buildATen(input);
@@ -1505,6 +1600,17 @@ diopiError_t diopiErf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiCo
     return diopiSuccess;
 }
 
+diopiError_t diopiErfBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out, diopiConstTensorHandle_t input) {
+    impl::aten::setCurStream(ctx);
+    auto atGradIn = impl::aten::buildATen(grad_in);
+    auto atGradOut = impl::aten::buildATen(grad_out);
+    auto atInput = impl::aten::buildATen(input);
+    auto local_grad = (2.0 / std::sqrt(M_PI)) * at::exp(-atInput * atInput);
+    atGradIn.copy_(atGradOut * local_grad);
+
+    return diopiSuccess;
+}
+
 diopiError_t diopiErfInp(diopiContextHandle_t ctx, diopiTensorHandle_t input) {
     impl::aten::setCurStream(ctx);
     auto atInput = impl::aten::buildATen(input);
@@ -2388,6 +2494,17 @@ diopiError_t diopiDropoutInp(diopiContextHandle_t ctx, diopiTensorHandle_t input
     return diopiSuccess;
 }
 
+diopiError_t diopiDropoutBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output, diopiTensorHandle_t mask,
+                                  double p) {
+    impl::aten::setCurStream(ctx);
+    auto atGradInput = impl::aten::buildATen(grad_input);
+    auto atGradOutput = impl::aten::buildATen(grad_output);
+    auto atMask = impl::aten::buildATen(mask);
+    CALL_ATEN_FUNC(native_dropout_backward_out, atGradInput, atGradOutput, atMask, 1.0 / (1 - p));
+
+    return diopiSuccess;
+}
+
 diopiError_t diopiMSELoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target,
                           diopiReduction_t reduction) {
     impl::aten::setCurStream(ctx);
@@ -2449,6 +2566,84 @@ diopiError_t diopiBatchNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, d
     return diopiSuccess;
 }
 
+diopiError_t diopiBatchNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
+                              diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, diopiTensorHandle_t running_mean,
+                              diopiTensorHandle_t running_var, bool training, double momentum, double eps, int64_t axis) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atBias = impl::aten::buildATen(bias);
+    auto atRunningMean = impl::aten::buildATen(running_mean);
+    auto atRunningVar = impl::aten::buildATen(running_var);
+    auto atOut = impl::aten::buildATen(out);
+    auto atSaveMean = impl::aten::buildATen(save_mean);
+    auto atSaveInvstd = impl::aten::buildATen(save_invstd);
+
+    std::vector<int64_t> dims(atInput.dim());
+    std::iota(dims.begin(), dims.end(), 0);
+    std::swap(dims[1], dims[axis]);
+    auto permutedInput = atInput.permute(dims);
+    CALL_ATEN_CUDA_FUNC(
+        native_batch_norm_out, atOut, atSaveMean, atSaveInvstd, permutedInput, atWeight, atBias, atRunningMean, atRunningVar, training, momentum, eps);
+    atOut = atOut.permute(dims);
+    return diopiSuccess;
+}
+
+diopiError_t diopiBatchNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
+                                      diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
+                                      diopiConstTensorHandle_t running_mean, diopiConstTensorHandle_t running_var, diopiConstTensorHandle_t save_mean,
+                                      diopiConstTensorHandle_t save_invstd, bool training, double eps, int64_t axis) {
+    impl::aten::setCurStream(ctx);
+
+    auto atGradOutput = impl::aten::buildATen(grad_output);
+    auto atInput = impl::aten::buildATen(input);
+    auto atWeight = impl::aten::buildATen(weight);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(atRunningMean, running_mean);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(atRunningVar, running_var);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(atSaveMean, save_mean);
+    DIOPI_IMPL_BUILD_ATEN_OPTIONAL(atSaveVar, save_invstd);
+    std::vector<int64_t> dims(atInput.dim());
+    std::iota(dims.begin(), dims.end(), 0);
+    std::swap(dims[1], dims[axis]);
+    auto permutedAtInput = atInput.permute(dims);
+    if (grad_input && grad_weight && grad_bias) {
+        auto grad_input_mask = std::array{true, true, true};
+        auto atGradInput = impl::aten::buildATen(grad_input).permute(dims);
+        auto atGradWeight = impl::aten::buildATen(grad_weight);
+        auto atGradBias = impl::aten::buildATen(grad_bias);
+        at::native_batch_norm_backward_out(atGradInput,
+                                           atGradWeight,
+                                           atGradBias,
+                                           atGradOutput.permute(dims),
+                                           atInput.permute(dims),
+                                           atWeight,
+                                           atRunningMean,
+                                           atRunningVar,
+                                           atSaveMean,
+                                           atSaveVar,
+                                           training,
+                                           eps,
+                                           grad_input_mask);
+        atGradInput = atGradInput.permute(dims);
+        // impl::aten::updateATen2Tensor(ctx, std::get<0>(atOut), grad_input);
+    } else {
+        auto grad_input_mask = std::array{grad_input != nullptr, grad_weight != nullptr, grad_bias != nullptr};
+        auto atOut = at::native_batch_norm_backward(
+            atGradOutput.permute(dims), permutedAtInput, atWeight, atRunningMean, atRunningVar, atSaveMean, atSaveVar, training, eps, grad_input_mask);
+        if (grad_input) {
+            impl::aten::updateATen2Tensor(ctx, std::get<0>(atOut), grad_input);
+        }
+        if (grad_weight) {
+            impl::aten::updateATen2Tensor(ctx, std::get<1>(atOut), grad_weight);
+        }
+        if (grad_bias) {
+            impl::aten::updateATen2Tensor(ctx, std::get<2>(atOut), grad_bias);
+        }
+    }
+
+    return diopiSuccess;
+}
+
 diopiError_t diopiSlice(diopiContextHandle_t ctx, diopiTensorHandle_t null_out, diopiConstTensorHandle_t input, int64_t dim, int64_t start, int64_t end,
                         int64_t step) {
     impl::aten::setCurStream(ctx);
@@ -4051,6 +4246,146 @@ diopiError_t diopiForeachnormScalar(diopiContextHandle_t ctx, diopiTensorHandle_
     return diopiSuccess;
 }
 
+diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
+                              diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps,
+                              diopiSize_t reduced_axes, const int64_t channel_axis) {
+    impl::aten::setCurStream(ctx);
+    auto atInput = impl::aten::buildATen(input);
+    std::vector<int64_t> dims;
+    int64_t N = 1;
+    for (int i = 0; i < atInput.dim(); i++) {
+        if (i == channel_axis) {
+            continue;
+        } else {
+            bool is_reduced_axis = false;
+            for (int m = 0; m < reduced_axes.len; m++) {
+                if (i == reduced_axes.data[m]) {
+                    is_reduced_axis = true;
+                    break;
+                }
+            }
+            if (is_reduced_axis) {
+                continue;
+            } else {
+                dims.push_back(i);
+                N *= atInput.size(i);
+            }
+        }
+    }
+    dims.push_back(channel_axis);
+    int64_t HxW = 1;
+    for (auto i = 0; i < reduced_axes.len; i++) {
+        dims.push_back(reduced_axes.data[i]);
+        HxW *= atInput.size(reduced_axes.data[i]);
+    }
+    auto C = atInput.size(channel_axis);
+    auto permutedInput = atInput.permute(dims);
+    auto permutedShape = permutedInput.sizes();
+    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();
+
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atBias = impl::aten::buildATen(bias);
+    auto atOut = impl::aten::buildATen(out);
+    auto atSaveMean = impl::aten::buildATen(save_mean);
+    auto atSaveInvstd = impl::aten::buildATen(save_invstd);
+
+    std::vector<int64_t> reverse_order(dims.size());
+    for (auto i = 0; i < atInput.dim(); i++) {
+        reverse_order[dims[i]] = i;
+    }
+    auto tempOut = CALL_ATEN_CUDA_FUNC(native_group_norm, reshapedInput, atWeight, atBias, N, C, HxW, num_groups, eps);
+    at::native::copy_(atOut, std::get<0>(tempOut).reshape(permutedShape).permute(reverse_order), true);
+    at::native::copy_(atSaveMean, std::get<1>(tempOut), true);
+    at::native::copy_(atSaveInvstd, std::get<2>(tempOut), true);
+    return diopiSuccess;
+}
+
+diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
+                                      diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
+                                      diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd, int64_t num_groups, diopiSize_t reduced_axes,
+                                      const int64_t channel_axis) {
+    impl::aten::setCurStream(ctx);
+    auto atGradOutput = impl::aten::buildATen(grad_output);
+    auto atInput = impl::aten::buildATen(input);
+    auto atWeight = impl::aten::buildATen(weight);
+    auto atSaveMean = impl::aten::buildATen(mean);
+    auto atSaveVar = impl::aten::buildATen(rstd);
+    auto atGradWeight = impl::aten::buildATen(grad_weight);
+    auto atGradBias = impl::aten::buildATen(grad_bias);
+    std::vector<int64_t> dims;
+    int64_t N = 1;
+    for (int i = 0; i < atInput.dim(); i++) {
+        if (i == channel_axis) {
+            continue;
+        } else {
+            bool is_reduced_axis = false;
+            for (int m = 0; m < reduced_axes.len; m++) {
+                if (i == reduced_axes.data[m]) {
+                    is_reduced_axis = true;
+                    break;
+                }
+            }
+            if (is_reduced_axis) {
+                continue;
+            } else {
+                dims.push_back(i);
+                N *= atInput.size(i);
+            }
+        }
+    }
+    dims.push_back(channel_axis);
+    int64_t HxW = 1;
+    for (auto i = 0; i < reduced_axes.len; i++) {
+        dims.push_back(reduced_axes.data[i]);
+        HxW *= atInput.size(reduced_axes.data[i]);
+    }
+    auto C = atInput.size(channel_axis);
+    auto permutedInput = atInput.permute(dims);
+    auto permutedShape = permutedInput.sizes();
+    auto reshapedInput = permutedInput.reshape({N, C, HxW, 1}).contiguous();
+
+    std::vector<int64_t> reverse_order(dims.size());
+    for (auto i = 0; i < atInput.dim(); i++) {
+        reverse_order[dims[i]] = i;
+    }
+
+    if (grad_weight && grad_bias) {
+        auto atGradInput = impl::aten::buildATen(grad_input).permute(dims).reshape({N, C, HxW, 1});
+
+        at::native_group_norm_backward_out(atGradInput,
+                                           atGradWeight,
+                                           atGradBias,
+                                           atGradOutput.permute(dims).reshape({N, C, HxW, 1}),
+                                           reshapedInput,
+                                           atSaveMean,
+                                           atSaveVar,
+                                           atWeight,
+                                           N,
+                                           C,
+                                           HxW,
+                                           num_groups,
+                                           {true, true, true});
+        atGradInput = atGradInput.reshape(permutedShape).permute(reverse_order);
+        impl::aten::updateATen2Tensor(ctx, atGradInput, grad_input);
+    } else {
+        auto atOuts = at::native_group_norm_backward(atGradOutput.permute(dims).reshape({N, C, HxW, 1}),
+                                                     reshapedInput,
+                                                     atSaveMean,
+                                                     atSaveVar,
+                                                     atWeight,
+                                                     N,
+                                                     C,
+                                                     HxW,
+                                                     num_groups,
+                                                     {true, grad_weight != nullptr, grad_bias != nullptr});
+        impl::aten::updateATen2Tensor(ctx, std::get<0>(atOuts).reshape(permutedShape).permute(reverse_order), grad_input);
+        impl::aten::updateATen2Tensor(ctx, std::get<1>(atOuts), grad_weight);
+        impl::aten::updateATen2Tensor(ctx, std::get<2>(atOuts), grad_bias);
+    }
+
+    return diopiSuccess;
+}
+
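+// The *_GB variants above differ from diopiBatchNorm/diopiGroupNorm only in that the channel and
+// reduced axes are configurable: the input is permuted (and, for group norm, reshaped) into the
+// canonical NCHW-like layout and the standard ATen normalization kernels are reused.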
 diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
                             diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups, double eps) {
     impl::aten::setCurStream(ctx);
diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h
index 4f7dfcecb..ae54de2ee 100644
--- a/proto/include/diopi/functions.h
+++ b/proto/include/diopi/functions.h
@@ -19,6 +19,51 @@ extern "C" {
 DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetVendorName();
 DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetImplVersion();
 DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetLastErrorString();
+/**
+ * @brief Checks whether each element of the input tensor is infinite.
+ */
+DIOPI_API diopiError_t diopiHasInf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
+
+/**
+ * @brief Truncates each element of the input tensor toward zero.
+ */
+DIOPI_API diopiError_t diopiTrunc(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
+
+/**
+ * @brief Rounds each element of the input tensor to the nearest integer value.
+ */
+DIOPI_API diopiError_t diopiRound(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
+
+/**
+ * @brief Applies the hard sigmoid activation function to an input tensor.
+ */
+DIOPI_API diopiError_t diopiHardSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
+
+/**
+ * @brief Applies the exponential linear unit (ELU) activation function to an input tensor.
+ */
+DIOPI_API diopiError_t diopiElu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* alpha);
+
+/**
+ * @brief Applies the parametric rectified linear unit (PReLU) activation function to an input tensor.
+ */
+DIOPI_API diopiError_t diopiPrelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiTensorHandle_t weight);
+
+/**
+ * @brief Applies the SELU activation function to an input tensor.
+ */
+DIOPI_API diopiError_t diopiSelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
+
+/**
+ * @brief Applies the softplus activation function to an input tensor.
+ */
+DIOPI_API diopiError_t diopiSoftplus(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* beta,
+                                     const diopiScalar_t* threshold);
+
+/**
+ * @brief Applies the softsign activation function to an input tensor.
+ */
+DIOPI_API diopiError_t diopiSoftsign(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
 
 /**
  * @brief Applies a 2D convolution over an input image composed of several input planes.
@@ -75,6 +120,23 @@ DIOPI_API diopiError_t diopiBatchNorm(diopiContextHandle_t ctx, diopiTensorHandl
                                       diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias,
                                       diopiTensorHandle_t running_mean, diopiTensorHandle_t running_var, bool training, double momentum, double eps);
 
+/**
+ * @brief Applies Batch Normalization.
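+ * Unlike diopiBatchNorm, the normalized (channel) dimension is selected by the extra axis argument
+ * instead of being fixed to dimension 1.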
+ */
+DIOPI_API diopiError_t diopiBatchNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
+                                        diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias,
+                                        diopiTensorHandle_t running_mean, diopiTensorHandle_t running_var, bool training, double momentum, double eps,
+                                        int64_t axis);
+
+/**
+ * @brief Backward pass for Batch Normalization.
+ */
+DIOPI_API diopiError_t diopiBatchNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight,
+                                                diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input,
+                                                diopiConstTensorHandle_t weight, diopiConstTensorHandle_t running_mean, diopiConstTensorHandle_t running_var,
+                                                diopiConstTensorHandle_t save_mean, diopiConstTensorHandle_t save_invstd, bool training, double eps,
+                                                int64_t axis);
+
 /**
  * @brief Computes the mean and inverse standard deviation across a batch of data for Synchronized Batch Normalization (SyncBN).
  * @param[in] ctx Context environment.
@@ -191,6 +253,18 @@ DIOPI_API diopiError_t diopiBatchNormBackward(diopiContextHandle_t ctx, diopiTen
  */
 DIOPI_API diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
 
+/**
+ * @brief Computes the gradient of the rectified linear unit function.
+ */
+DIOPI_API diopiError_t diopiReluBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out,
+                                         diopiConstTensorHandle_t input);
+
+/**
+ * @brief Computes the gradient of the error function.
+ */
+DIOPI_API diopiError_t diopiErfBackward(diopiContextHandle_t ctx, diopiConstTensorHandle_t grad_in, diopiTensorHandle_t grad_out,
+                                        diopiConstTensorHandle_t input);
+
 /**
  * @brief The in-place version of diopiRelu().
  * @param[in] ctx Context environment.
@@ -656,6 +730,13 @@ DIOPI_API diopiError_t diopiAdaptiveMaxPool2dBackward(diopiContextHandle_t ctx,
  */
 DIOPI_API diopiError_t diopiDropout(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t mask, diopiConstTensorHandle_t input, double p,
                                     bool train, diopiGeneratorHandle_t generator);
+
+/**
+ * @brief Computes the backward pass of diopiDropout().
+ */
+DIOPI_API diopiError_t diopiDropoutBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
+                                            diopiTensorHandle_t mask, double p);
+
 /**
  * @brief The in-place version of diopiDropout().
  * @param[in] ctx Context environment.
@@ -3530,6 +3611,20 @@ DIOPI_API diopiError_t diopiGroupNorm(diopiContextHandle_t ctx, diopiTensorHandl
                                       diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
                                       double eps);
 
+/**
+ * @brief Applies Group Normalization over a mini-batch of inputs.
+ */
+DIOPI_API diopiError_t diopiGroupNormGB(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t save_mean, diopiTensorHandle_t save_invstd,
+                                        diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, int64_t num_groups,
+                                        double eps, diopiSize_t reduced_axes, const int64_t channel_axis);
+
+/**
+ * @brief Computes the backward pass of diopiGroupNormGB().
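+ * reduced_axes and channel_axis are expected to match the values passed to diopiGroupNormGB() in the
+ * forward pass, along with the mean/rstd tensors it produced.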
+ */
+DIOPI_API diopiError_t diopiGroupNormGBBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight,
+                                                diopiTensorHandle_t grad_bias, diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input,
+                                                diopiConstTensorHandle_t weight, diopiConstTensorHandle_t mean, diopiConstTensorHandle_t rstd,
+                                                int64_t num_groups, diopiSize_t reduced_axes, const int64_t channel_axis);
 /**
  * @brief Compute the backward pass of diopiGroupNorm().
  * @param[in] ctx Context environment.