From c8285363fdd0ddf3e3c45cb25c4a01d2b215779e Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Mon, 30 Jun 2025 19:41:42 +0800 Subject: [PATCH 1/4] issue/207 operator: sigmoid op on cpu and cuda --- include/infiniop.h | 1 + include/infiniop/ops/sigmoid.h | 24 ++ scripts/python_test.py | 1 + src/infiniop-test/include/ops.hpp | 2 + src/infiniop-test/src/ops/sigmoid.cpp | 103 ++++++++ src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc | 49 ++++ src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h | 19 ++ src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu | 55 ++++ .../ops/sigmoid/cuda/sigmoid_cuda.cuh | 8 + .../sigmoid/cuda/sigmoid_cuda_internal.cuh | 30 +++ src/infiniop/ops/sigmoid/operator.cc | 115 +++++++++ .../test_generate/testcases/sigmoid.py | 136 ++++++++++ test/infiniop/sigmoid.py | 239 ++++++++++++++++++ 13 files changed, 782 insertions(+) create mode 100644 include/infiniop/ops/sigmoid.h create mode 100644 src/infiniop-test/src/ops/sigmoid.cpp create mode 100644 src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc create mode 100644 src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h create mode 100644 src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu create mode 100644 src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh create mode 100644 src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh create mode 100644 src/infiniop/ops/sigmoid/operator.cc create mode 100644 test/infiniop-test/test_generate/testcases/sigmoid.py create mode 100644 test/infiniop/sigmoid.py diff --git a/include/infiniop.h b/include/infiniop.h index d51b8d92e..ce7239729 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -14,6 +14,7 @@ #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" +#include "infiniop/ops/sigmoid.h" #include "infiniop/ops/sub.h" #include "infiniop/ops/swiglu.h" #include "infiniop/tensor_descriptor.h" diff --git a/include/infiniop/ops/sigmoid.h b/include/infiniop/ops/sigmoid.h new file mode 100644 index 000000000..4fa0f6604 --- /dev/null +++ b/include/infiniop/ops/sigmoid.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_SIGMOID_API_H__ +#define __INFINIOP_SIGMOID_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSigmoidDescriptor_t; + +__C __export infiniStatus_t infiniopCreateSigmoidDescriptor(infiniopHandle_t handle, + infiniopSigmoidDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopSigmoid(infiniopSigmoidDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc); + +#endif diff --git a/scripts/python_test.py b/scripts/python_test.py index eb2d4319e..9086b5708 100644 --- a/scripts/python_test.py +++ b/scripts/python_test.py @@ -22,6 +22,7 @@ def run_tests(args): "rearrange.py", "rms_norm.py", "rope.py", + "sigmoid.py", "sub.py", "swiglu.py", ]: diff --git a/src/infiniop-test/include/ops.hpp b/src/infiniop-test/include/ops.hpp index 3820f7cfd..d543b4cb3 100644 --- a/src/infiniop-test/include/ops.hpp +++ b/src/infiniop-test/include/ops.hpp @@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add) DECLARE_INFINIOP_TEST(causal_softmax) DECLARE_INFINIOP_TEST(rearrange) DECLARE_INFINIOP_TEST(sub) +DECLARE_INFINIOP_TEST(sigmoid) #define REGISTER_INFINIOP_TEST(name) \ { \ @@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub) REGISTER_INFINIOP_TEST(causal_softmax) \ REGISTER_INFINIOP_TEST(rearrange) \ REGISTER_INFINIOP_TEST(sub) \ + REGISTER_INFINIOP_TEST(sigmoid) \ } namespace infiniop_test { diff --git a/src/infiniop-test/src/ops/sigmoid.cpp b/src/infiniop-test/src/ops/sigmoid.cpp new file mode 100644 index 000000000..bb3a0f70a --- /dev/null +++ b/src/infiniop-test/src/ops/sigmoid.cpp @@ -0,0 +1,103 @@ +#include "ops.hpp" +#include "utils.hpp" +#include +#include +#include + +namespace infiniop_test::sigmoid { +struct Test::Attributes { + std::shared_ptr x; + std::shared_ptr y; + std::shared_ptr ans; +}; + +std::shared_ptr Test::build( + std::unordered_map> attributes, + std::unordered_map> tensors, + double rtol, double atol) { + auto test = std::shared_ptr(new Test(rtol, atol)); + test->_attributes = new Attributes(); + if (tensors.find("x") == tensors.end() + || tensors.find("y") == tensors.end() + || tensors.find("ans") == tensors.end()) { + throw std::runtime_error("Invalid Test"); + } + + test->_attributes->x = tensors["x"]; + test->_attributes->y = tensors["y"]; + test->_attributes->ans = tensors["ans"]; + + return test; +} + +std::shared_ptr Test::run( + infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) { + infiniopSigmoidDescriptor_t op_desc; + auto x = _attributes->x->to(device, device_id); + auto y = _attributes->y->to(device, device_id); + CHECK_OR(infiniopCreateSigmoidDescriptor(handle, &op_desc, + y->desc(), + x->desc()), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor.")); + size_t workspace_size; + CHECK_OR(infiniopGetSigmoidWorkspaceSize(op_desc, &workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size.")); + void *workspace; + CHECK_OR(infinirtMalloc(&workspace, workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace.")); + CHECK_OR(infiniopSigmoid(op_desc, workspace, workspace_size, + y->data(), + x->data(), + nullptr), + return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution.")); + + try { + allClose(y, _attributes->ans, _rtol, _atol); + } catch (const std::exception &e) { + return TEST_FAILED(RESULT_INCORRECT, e.what()); + } + + double elapsed_time = 0.; + + elapsed_time = benchmark( + [=]() { + infiniopSigmoid( + op_desc, workspace, workspace_size, + y->data(), + x->data(), + nullptr); + }, + warm_ups, iterations); + + infiniopDestroySigmoidDescriptor(op_desc); + infinirtFree(workspace); + return TEST_PASSED(elapsed_time); +} + +std::vector Test::attribute_names() { + return {}; +} + +std::vector Test::tensor_names() { + return {"x", "y", "ans"}; +} + +std::vector Test::output_names() { + return {"y"}; +} + +std::string Test::toString() const { + std::ostringstream oss; + oss << op_name() << std::endl; + oss << "- x: " << _attributes->x->info() << std::endl; + oss << "- y: " << _attributes->y->info() << std::endl; + oss << std::scientific << std::setprecision(2); + oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; + return oss.str(); +} + +Test::~Test() { + delete _attributes; +} + +} // namespace infiniop_test::sigmoid diff --git a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc new file mode 100644 index 000000000..16fb7ffeb --- /dev/null +++ b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc @@ -0,0 +1,49 @@ +#include "sigmoid_cpu.h" + +namespace op::sigmoid::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::sigmoid::cpu diff --git a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h new file mode 100644 index 000000000..49c963f44 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h @@ -0,0 +1,19 @@ +#ifndef __SIGMOID_CPU_H__ +#define __SIGMOID_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(sigmoid, cpu) + +namespace op::sigmoid::cpu { +typedef struct SigmoidOp { +public: + static constexpr size_t num_inputs = 1; + template + T operator()(const T &x) const { + return T(1) / (T(1) + std::exp(-x)); + } +} SigmoidOp; +} // namespace op::sigmoid::cpu + +#endif // __SIGMOID_CPU_H__ diff --git a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu b/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu new file mode 100644 index 000000000..0c59dc361 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu @@ -0,0 +1,55 @@ +#include "sigmoid_cuda.cuh" +#include "sigmoid_cuda_internal.cuh" + +namespace op::sigmoid::cuda { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, SigmoidOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, SigmoidOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, SigmoidOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::sigmoid::cuda diff --git a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh b/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh new file mode 100644 index 000000000..f3eec5cf4 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh @@ -0,0 +1,8 @@ +#ifndef __SIGMOID_CUDA_API_H__ +#define __SIGMOID_CUDA_API_H__ + +#include "../../../elementwise/cuda/elementwise_cuda_api.cuh" + +ELEMENTWISE_DESCRIPTOR(sigmoid, cuda) + +#endif // __SIGMOID_CUDA_API_H__ diff --git a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh b/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh new file mode 100644 index 000000000..fca168114 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh @@ -0,0 +1,30 @@ +#ifndef __SIDMOID_CUDA_H__ +#define __SIDMOID_CUDA_H__ + +#include "../../../elementwise/cuda/elementwise_cuda.cuh" +#include + +namespace op::sigmoid::cuda { +typedef struct SigmoidOp { +public: + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + // sigmoid(x) = 1 / (1 + exp(-x)) + if constexpr (std::is_same_v) { + half2 denominator = __hadd2(make_half2(1, 1), h2exp(__hneg2(x))); + return h2rcp(denominator); + } else if constexpr (std::is_same_v) { + half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x))); + return hrcp(denominator); + } else if constexpr (std::is_same_v) { + float denominator = __fadd_rn(1.0f, __expf(-x)); + return __frcp_rn(denominator); + } else { // double + return 1.0 / (1.0 + exp(-x)); + } + } +} SigmoidOp; +} // namespace op::sigmoid::cuda + +#endif // __SIDMOID_CUDA_H__ diff --git a/src/infiniop/ops/sigmoid/operator.cc b/src/infiniop/ops/sigmoid/operator.cc new file mode 100644 index 000000000..e4c22fa21 --- /dev/null +++ b/src/infiniop/ops/sigmoid/operator.cc @@ -0,0 +1,115 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/sigmoid.h" + +#ifdef ENABLE_CPU_API +#include "cpu/sigmoid_cpu.h" +#endif +#ifdef ENABLE_CUDA_API +#include "cuda/sigmoid_cuda.cuh" +#endif + +__C infiniStatus_t infiniopCreateSigmoidDescriptor( + infiniopHandle_t handle, + infiniopSigmoidDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::sigmoid::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_CUDA_API + CREATE(INFINI_DEVICE_NVIDIA, cuda); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_CUDA_API + GET(INFINI_DEVICE_NVIDIA, cuda) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopSigmoid( + infiniopSigmoidDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, {x}, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_CUDA_API + CALCULATE(INFINI_DEVICE_NVIDIA, cuda); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_CUDA_API + DELETE(INFINI_DEVICE_NVIDIA, cuda); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/test/infiniop-test/test_generate/testcases/sigmoid.py b/test/infiniop-test/test_generate/testcases/sigmoid.py new file mode 100644 index 000000000..f622a4d6e --- /dev/null +++ b/test/infiniop-test/test_generate/testcases/sigmoid.py @@ -0,0 +1,136 @@ +import numpy as np +from numpy.lib.stride_tricks import as_strided +import gguf +from typing import List + +from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor + + +def sigmoid( + x: np.ndarray, +): + return 1 / (1 + np.exp(-x)) + + +def random_tensor(shape, dtype): + rate = 1e-3 + var = 0.5 * rate + return rate * np.random.rand(*shape).astype(dtype) - var + + +def process_tensors(a, b, stride_a=None, stride_b=None): + def normalize_stride(tensor, stride): + if stride: + slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in stride) + return tensor[slices] + else: + return tensor + + a_unique = normalize_stride(a, stride_a) + b_unique = normalize_stride(b, stride_b) + return a_unique, b_unique + + +def process_tensor(a, stride_a=None): + def normalize_stride(tensor, stride): + if stride: + slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in stride) + return tensor[slices] + else: + return tensor + + a_unique = normalize_stride(a, stride_a) + return a_unique + + +class SigmoidTestCase(InfiniopTestCase): + def __init__( + self, + x: np.ndarray, + shape_x: List[int] | None, + stride_x: List[int] | None, + y: np.ndarray, + shape_y: List[int] | None, + stride_y: List[int] | None, + ): + super().__init__("sigmoid") + self.x = x + self.shape_x = shape_x + self.stride_x = stride_x + + self.y = y + self.shape_y = shape_y + self.stride_y = stride_y + + def write_test(self, test_writer: "InfiniopTestWriter"): + super().write_test(test_writer) + + if self.shape_x is not None: + test_writer.add_array(test_writer.gguf_key("x.shape"), self.shape_x) + if self.shape_y is not None: + test_writer.add_array(test_writer.gguf_key("y.shape"), self.shape_y) + + if self.stride_x is not None: + test_writer.add_array(test_writer.gguf_key("x.strides"), gguf_strides(*self.stride_x)) + + test_writer.add_array( + test_writer.gguf_key("y.strides"), + gguf_strides(*self.stride_y if self.stride_y is not None else contiguous_gguf_strides(self.shape_y)) + ) + + test_writer.add_tensor( + test_writer.gguf_key("x"), self.x, raw_dtype=np_dtype_to_ggml(self.x.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype) + ) + + input_x = self.x.astype(np.float64) + if (self.stride_x is not None) and (0 in self.stride_x): + typesize = np.dtype(input_x.dtype).itemsize + new_strides_bytes = tuple(x * typesize for x in self.stride_x) + input_x = as_strided(x=input_x, shape=self.shape_x, strides=new_strides_bytes) + + ans = sigmoid(input_x) + + test_writer.add_tensor( + test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64 + ) + + +if __name__ == '__main__': + test_writer = InfiniopTestWriter("sigmoid.gguf") + + test_cases = [] + _TEST_CASES_ = [ + # shape, x_stride, y_stride + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4), (0, 1), None), + ((13, 4, 4), None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), None), + ((16, 5632), None, None), + ((16, 5632), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)), + ] + _TENSOR_DTYPES_ = [np.float16, np.float32] + + for dtype in _TENSOR_DTYPES_: + for shape, stride_x, stride_y in _TEST_CASES_: + x = np.random.rand(*shape).astype(dtype) + y = np.empty(tuple(0 for _ in shape), dtype=dtype) + + x = process_zero_stride_tensor(x, stride_x) + test_case = SigmoidTestCase(x=x, + shape_x=shape, + stride_x=stride_x, + y=y, + shape_y=shape, + stride_y=stride_y) + + test_cases.append(test_case) + + test_writer.add_tests(test_cases) + test_writer.save() diff --git a/test/infiniop/sigmoid.py b/test/infiniop/sigmoid.py new file mode 100644 index 000000000..65a8bb515 --- /dev/null +++ b/test/infiniop/sigmoid.py @@ -0,0 +1,239 @@ +import os + +import torch +import ctypes +from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64 +from libinfiniop import (infiniopHandle_t, + infiniopTensorDescriptor_t, + open_lib, + to_tensor, + get_test_devices, + check_error, + rearrange_if_needed, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + create_workspace, + ) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # shape, x_stride, y_stride + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4), (0, 1), None,), + ((13, 4, 4), None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), None), + ((16, 5632), None, None), + ((16, 5632), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)), + ((4, 4, 56320), None, None), +] + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + + +# Inplace options applied for each test case in _TEST_CASES_ +_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_X, +] + +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_ +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + +# Data types used for testing +_TENSOR_DTYPES = [torch.float16, torch.float32, torch.float64] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + torch.float16: {"atol": 1e-3, "rtol": 1e-3}, + torch.float32: {"atol": 1e-7, "rtol": 1e-7}, + torch.float64: {"atol": 1e-7, "rtol": 1e-7}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +class SigmoidDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopSigmoidDescriptor_t = POINTER(SigmoidDescriptor) + + +def sigmoid_torch(x): + return torch.sigmoid(x) + + +def process_tensors(y, y_strides, x, x_stride, inplace): + """ + rearrange the tensors if needed and apply the inplace config. + if inplace is true and the output (i.e., c) is placed to the broadcasted input, + the inplace config is ignored and out-of-place is used + """ + original_y_strides = y_strides if y_strides else y.stride() + + + def _rearrange(tensor, strides): + if strides and 0 in strides: + tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides) + return tensor + else: + return rearrange_if_needed(tensor, strides) + + x, y = [ + _rearrange(tensor, stride) + for tensor, stride in zip([x, y], [x_stride, y_strides]) + ] + y = ( + y + if inplace == Inplace.OUT_OF_PLACE + else (x) + ) + # if inplace is true and c has broadcasted config, reset it to the original unbroadcasted strides + if 0 in y.stride(): + y.set_(y.untyped_storage(), 0, y.shape, original_y_strides) + + return x, y + + +def test(lib, + handle, + torch_device, + shape, + x_stride=None, + y_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=torch.float16, + sync=None, + ): + print(f"Testing Sigmoid on {torch_device} with shape:{shape} x_stride:{x_stride} c_stride:{y_stride} " + f"dtype:{dtype} inplace:{inplace}") + + x = torch.rand(shape, dtype=dtype).to(torch_device) + y = torch.rand(shape, dtype=dtype).to(torch_device) + + x, y = process_tensors(y, y_stride, x, x_stride, inplace) + + ans = sigmoid_torch(x) + + x_tensor, = [to_tensor(tensor, lib) for tensor in [x, ]] + y_tensor = ( + to_tensor(y, lib) + if inplace == Inplace.OUT_OF_PLACE + else x_tensor + ) + if sync is not None: + sync() + + descriptor = infiniopSigmoidDescriptor_t() + check_error( + lib.infiniopCreateSigmoidDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [x_tensor, y_tensor]: + tensor.destroyDesc(lib) + + workspace_size = c_uint64(0) + check_error( + lib.infiniopGetSigmoidWorkspaceSize(descriptor, ctypes.byref(workspace_size)) + ) + workspace = create_workspace(workspace_size.value, y.device) + + def lib_sigmoid(): + check_error( + lib.infiniopSigmoid( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + y_tensor.data, + x_tensor.data, + None, + ) + ) + + lib_sigmoid() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + + assert torch.allclose(y, ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: sigmoid_torch(x), torch_device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_sigmoid(), torch_device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(lib.infiniopDestroySigmoidDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + lib = open_lib() + + lib.infiniopCreateSigmoidDescriptor.restype = c_int32 + lib.infiniopCreateSigmoidDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopSigmoidDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetSigmoidWorkspaceSize.restype = c_int32 + lib.infiniopGetSigmoidWorkspaceSize.argtypes = [ + infiniopSigmoidDescriptor_t, + POINTER(c_uint64), + ] + + lib.infiniopSigmoid.restype = c_int32 + lib.infiniopSigmoid.argtypes = [ + infiniopSigmoidDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroySigmoidDescriptor.restype = c_int32 + lib.infiniopDestroySigmoidDescriptor.argtypes = [ + infiniopSigmoidDescriptor_t, + ] + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From 0ab5aff1b382dbde799d70c77a77d463734c683e Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Sat, 12 Jul 2025 13:16:21 +0800 Subject: [PATCH 2/4] =?UTF-8?q?issue/207=20-=20=E6=B7=BB=E5=8A=A0bf16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc | 4 +- src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h | 2 +- .../{sigmoid_cuda_internal.cuh => kernel.cuh} | 4 + .../sigmoid_nvidia.cu} | 19 +- .../sigmoid_nvidia.cuh} | 2 +- src/infiniop/ops/sigmoid/operator.cc | 20 +- test/infiniop-test/README.md | 2 +- test/infiniop/libinfiniop/op_register.py | 32 +++ test/infiniop/sigmoid.py | 226 ++++++------------ 9 files changed, 142 insertions(+), 169 deletions(-) rename src/infiniop/ops/sigmoid/cuda/{sigmoid_cuda_internal.cuh => kernel.cuh} (79%) rename src/infiniop/ops/sigmoid/{cuda/sigmoid_cuda.cu => nvidia/sigmoid_nvidia.cu} (65%) rename src/infiniop/ops/sigmoid/{cuda/sigmoid_cuda.cuh => nvidia/sigmoid_nvidia.cuh} (77%) diff --git a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc index 16fb7ffeb..c335bba60 100644 --- a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc +++ b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc @@ -17,7 +17,7 @@ infiniStatus_t Descriptor::create( const auto &y_shape = out_desc->shape(); const auto &x_shape = x_desc->shape(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); CHECK_SAME_SHAPE(y_shape, x_shape); // create CPU elementwise descriptor @@ -40,6 +40,8 @@ infiniStatus_t Descriptor::calculate( return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F64: return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h index 49c963f44..6ab7eaeb9 100644 --- a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h +++ b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h @@ -3,7 +3,7 @@ #include "../../../elementwise/cpu/elementwise_cpu.h" -ELEMENTWISE_DESCRIPTOR(sigmoid, cpu) +ELEMENTWISE_DESCRIPTOR(sigmoid, cpu, cpu) namespace op::sigmoid::cpu { typedef struct SigmoidOp { diff --git a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh b/src/infiniop/ops/sigmoid/cuda/kernel.cuh similarity index 79% rename from src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh rename to src/infiniop/ops/sigmoid/cuda/kernel.cuh index fca168114..9094f0248 100644 --- a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda_internal.cuh +++ b/src/infiniop/ops/sigmoid/cuda/kernel.cuh @@ -2,6 +2,7 @@ #define __SIDMOID_CUDA_H__ #include "../../../elementwise/cuda/elementwise_cuda.cuh" +#include #include namespace op::sigmoid::cuda { @@ -17,6 +18,9 @@ public: } else if constexpr (std::is_same_v) { half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x))); return hrcp(denominator); + } else if constexpr (std::is_same_v) { + __nv_bfloat16 denominator = __hadd(__float2bfloat16(1.0f), __float2bfloat16(__expf(__bfloat162float(-x)))); + return __float2bfloat16(1.0f) / denominator; } else if constexpr (std::is_same_v) { float denominator = __fadd_rn(1.0f, __expf(-x)); return __frcp_rn(denominator); diff --git a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu similarity index 65% rename from src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu rename to src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu index 0c59dc361..f5dfe9fdd 100644 --- a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cu +++ b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu @@ -1,7 +1,7 @@ -#include "sigmoid_cuda.cuh" -#include "sigmoid_cuda_internal.cuh" +#include "../cuda/kernel.cuh" +#include "sigmoid_nvidia.cuh" -namespace op::sigmoid::cuda { +namespace op::sigmoid::nvidia { Descriptor::~Descriptor() = default; @@ -18,7 +18,7 @@ infiniStatus_t Descriptor::create( const auto &y_shape = out_desc->shape(); const auto &x_shape = x_desc->shape(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); CHECK_SAME_SHAPE(y_shape, x_shape); @@ -41,15 +41,18 @@ infiniStatus_t Descriptor::calculate( switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, SigmoidOp, half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::SigmoidOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::SigmoidOp, __nv_bfloat16>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, SigmoidOp, float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::SigmoidOp, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, SigmoidOp, double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::SigmoidOp, double>(_info, workspace, output, inputs, stream); + default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } return INFINI_STATUS_SUCCESS; } -} // namespace op::sigmoid::cuda +} // namespace op::sigmoid::nvidia diff --git a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh similarity index 77% rename from src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh rename to src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh index f3eec5cf4..53dd0a1fa 100644 --- a/src/infiniop/ops/sigmoid/cuda/sigmoid_cuda.cuh +++ b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh @@ -3,6 +3,6 @@ #include "../../../elementwise/cuda/elementwise_cuda_api.cuh" -ELEMENTWISE_DESCRIPTOR(sigmoid, cuda) +ELEMENTWISE_DESCRIPTOR(sigmoid, nvidia, cuda) #endif // __SIGMOID_CUDA_API_H__ diff --git a/src/infiniop/ops/sigmoid/operator.cc b/src/infiniop/ops/sigmoid/operator.cc index e4c22fa21..3f2f95067 100644 --- a/src/infiniop/ops/sigmoid/operator.cc +++ b/src/infiniop/ops/sigmoid/operator.cc @@ -5,8 +5,8 @@ #ifdef ENABLE_CPU_API #include "cpu/sigmoid_cpu.h" #endif -#ifdef ENABLE_CUDA_API -#include "cuda/sigmoid_cuda.cuh" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/sigmoid_nvidia.cuh" #endif __C infiniStatus_t infiniopCreateSigmoidDescriptor( @@ -28,8 +28,8 @@ __C infiniStatus_t infiniopCreateSigmoidDescriptor( #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_CUDA_API - CREATE(INFINI_DEVICE_NVIDIA, cuda); +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); #endif default: @@ -50,8 +50,8 @@ __C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t d #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu) #endif -#ifdef ENABLE_CUDA_API - GET(INFINI_DEVICE_NVIDIA, cuda) +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -79,8 +79,8 @@ __C infiniStatus_t infiniopSigmoid( #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_CUDA_API - CALCULATE(INFINI_DEVICE_NVIDIA, cuda); +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif default: @@ -103,8 +103,8 @@ infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) { #ifdef ENABLE_CPU_API DELETE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_CUDA_API - DELETE(INFINI_DEVICE_NVIDIA, cuda); +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); #endif default: diff --git a/test/infiniop-test/README.md b/test/infiniop-test/README.md index 85e889e42..40dc7e36d 100644 --- a/test/infiniop-test/README.md +++ b/test/infiniop-test/README.md @@ -17,7 +17,7 @@ xmake build infiniop-test 在`/test/infiniop-test/`目录执行矩阵乘测例生成脚本,执行结束以后会在`/test/infiniop-test/`目录生成`gemm.gguf`测例文件。 ```bash -cd /test/infiniop-test/ +cd ./test/infiniop-test/ python -m test_generate.testcases.gemm ``` diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e92e77105..62febf5c2 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -489,3 +489,35 @@ def conv_(lib): lib.infiniopDestroyConvDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def sigmoid_(lib): + lib.infiniopCreateSigmoidDescriptor.restype = c_int32 + lib.infiniopCreateSigmoidDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetSigmoidWorkspaceSize.restype = c_int32 + lib.infiniopGetSigmoidWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopSigmoid.restype = c_int32 + lib.infiniopSigmoid.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroySigmoidDescriptor.restype = c_int32 + lib.infiniopDestroySigmoidDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/sigmoid.py b/test/infiniop/sigmoid.py index 65a8bb515..efd8fdb26 100644 --- a/test/infiniop/sigmoid.py +++ b/test/infiniop/sigmoid.py @@ -1,22 +1,22 @@ -import os - import torch import ctypes -from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64 -from libinfiniop import (infiniopHandle_t, - infiniopTensorDescriptor_t, - open_lib, - to_tensor, - get_test_devices, - check_error, - rearrange_if_needed, - test_operator, - get_args, - debug, - get_tolerance, - profile_operation, - create_workspace, - ) +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) from enum import Enum, auto # ============================================================================== @@ -24,7 +24,7 @@ # ============================================================================== # These are not meant to be imported from other modules _TEST_CASES_ = [ - # shape, x_stride, y_stride + # shape, a_stride, b_stride, c_stride ((13, 4), None, None), ((13, 4), (10, 1), (10, 1)), ((13, 4), (0, 1), None,), @@ -40,9 +40,8 @@ class Inplace(Enum): - OUT_OF_PLACE = auto() - INPLACE_X = auto() - + OUT_OF_PLACE = auto() + INPLACE_X = auto() # Inplace options applied for each test case in _TEST_CASES_ _INPLACE = [ @@ -58,13 +57,13 @@ class Inplace(Enum): ] # Data types used for testing -_TENSOR_DTYPES = [torch.float16, torch.float32, torch.float64] +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16] # Tolerance map for different data types _TOLERANCE_MAP = { - torch.float16: {"atol": 1e-3, "rtol": 1e-3}, - torch.float32: {"atol": 1e-7, "rtol": 1e-7}, - torch.float64: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, } DEBUG = False @@ -72,107 +71,70 @@ class Inplace(Enum): NUM_PRERUN = 10 NUM_ITERATIONS = 1000 - -class SigmoidDescriptor(Structure): - _fields_ = [("device", c_int32)] - - -infiniopSigmoidDescriptor_t = POINTER(SigmoidDescriptor) - - -def sigmoid_torch(x): - return torch.sigmoid(x) - - -def process_tensors(y, y_strides, x, x_stride, inplace): - """ - rearrange the tensors if needed and apply the inplace config. - if inplace is true and the output (i.e., c) is placed to the broadcasted input, - the inplace config is ignored and out-of-place is used - """ - original_y_strides = y_strides if y_strides else y.stride() - - - def _rearrange(tensor, strides): - if strides and 0 in strides: - tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides) - return tensor - else: - return rearrange_if_needed(tensor, strides) - - x, y = [ - _rearrange(tensor, stride) - for tensor, stride in zip([x, y], [x_stride, y_strides]) - ] - y = ( - y - if inplace == Inplace.OUT_OF_PLACE - else (x) - ) - # if inplace is true and c has broadcasted config, reset it to the original unbroadcasted strides - if 0 in y.stride(): - y.set_(y.untyped_storage(), 0, y.shape, original_y_strides) - - return x, y - - -def test(lib, - handle, - torch_device, - shape, - x_stride=None, - y_stride=None, - inplace=Inplace.OUT_OF_PLACE, - dtype=torch.float16, - sync=None, - ): - print(f"Testing Sigmoid on {torch_device} with shape:{shape} x_stride:{x_stride} c_stride:{y_stride} " - f"dtype:{dtype} inplace:{inplace}") - - x = torch.rand(shape, dtype=dtype).to(torch_device) - y = torch.rand(shape, dtype=dtype).to(torch_device) - - x, y = process_tensors(y, y_stride, x, x_stride, inplace) - - ans = sigmoid_torch(x) - - x_tensor, = [to_tensor(tensor, lib) for tensor in [x, ]] - y_tensor = ( - to_tensor(y, lib) - if inplace == Inplace.OUT_OF_PLACE - else x_tensor +def torch_sigmoid(y, x): + torch.sigmoid(x, out=y) + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=torch.float16, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + if inplace == Inplace.INPLACE_X: + if x_stride != y_stride: + return + y = x + else: + y = TestTensor(shape, y_stride, dtype, device, mode="ones") + + if y.is_broadcast(): + return + + print( + f"Testing Sigmoid on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} " + f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" ) + + torch_sigmoid(y.torch_tensor(), x.torch_tensor()) + if sync is not None: sync() - - descriptor = infiniopSigmoidDescriptor_t() + + descriptor = infiniopOperatorDescriptor_t() check_error( - lib.infiniopCreateSigmoidDescriptor( + LIBINFINIOP.infiniopCreateSigmoidDescriptor( handle, ctypes.byref(descriptor), - y_tensor.descriptor, - x_tensor.descriptor, + y.descriptor, + x.descriptor, ) ) # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel - for tensor in [x_tensor, y_tensor]: - tensor.destroyDesc(lib) + for tensor in [x, y]: + tensor.destroy_desc() workspace_size = c_uint64(0) check_error( - lib.infiniopGetSigmoidWorkspaceSize(descriptor, ctypes.byref(workspace_size)) + LIBINFINIOP.infiniopGetSigmoidWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) ) - workspace = create_workspace(workspace_size.value, y.device) + workspace = TestWorkspace(workspace_size.value, y.device) def lib_sigmoid(): check_error( - lib.infiniopSigmoid( + LIBINFINIOP.infiniopSigmoid( descriptor, - workspace.data_ptr() if workspace is not None else None, - workspace_size.value, - y_tensor.data, - x_tensor.data, + workspace.data(), + workspace.size(), + y.data(), + x.data(), None, ) ) @@ -181,52 +143,22 @@ def lib_sigmoid(): atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: - debug(y, ans, atol=atol, rtol=rtol) - - assert torch.allclose(y, ans, atol=atol, rtol=rtol) + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) # Profiling workflow if PROFILE: # fmt: off - profile_operation("PyTorch", lambda: sigmoid_torch(x), torch_device, NUM_PRERUN, NUM_ITERATIONS) - profile_operation(" lib", lambda: lib_sigmoid(), torch_device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation("PyTorch", lambda: torch_sigmoid(y.torch_tensor(), x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_sigmoid(), device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on - check_error(lib.infiniopDestroySigmoidDescriptor(descriptor)) + check_error(LIBINFINIOP.infiniopDestroySigmoidDescriptor(descriptor)) if __name__ == "__main__": args = get_args() - lib = open_lib() - - lib.infiniopCreateSigmoidDescriptor.restype = c_int32 - lib.infiniopCreateSigmoidDescriptor.argtypes = [ - infiniopHandle_t, - POINTER(infiniopSigmoidDescriptor_t), - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - ] - - lib.infiniopGetSigmoidWorkspaceSize.restype = c_int32 - lib.infiniopGetSigmoidWorkspaceSize.argtypes = [ - infiniopSigmoidDescriptor_t, - POINTER(c_uint64), - ] - - lib.infiniopSigmoid.restype = c_int32 - lib.infiniopSigmoid.argtypes = [ - infiniopSigmoidDescriptor_t, - c_void_p, - c_uint64, - c_void_p, - c_void_p, - c_void_p, - ] - - lib.infiniopDestroySigmoidDescriptor.restype = c_int32 - lib.infiniopDestroySigmoidDescriptor.argtypes = [ - infiniopSigmoidDescriptor_t, - ] - + # Configure testing options DEBUG = args.debug PROFILE = args.profile @@ -234,6 +166,6 @@ def lib_sigmoid(): NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): - test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) - print("\033[92mTest passed!\033[0m") + print("\033[92m Test passed! \033[0m") From ed5e63ab531875d8b3566040efe3272c646899db Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Sat, 12 Jul 2025 13:39:33 +0800 Subject: [PATCH 3/4] =?UTF-8?q?issue/207=20-=20=E6=B3=A8=E9=87=8A=E4=B8=80?= =?UTF-8?q?=E4=BA=9Bcase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/infiniop/sigmoid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/infiniop/sigmoid.py b/test/infiniop/sigmoid.py index efd8fdb26..197c1b9fb 100644 --- a/test/infiniop/sigmoid.py +++ b/test/infiniop/sigmoid.py @@ -27,10 +27,10 @@ # shape, a_stride, b_stride, c_stride ((13, 4), None, None), ((13, 4), (10, 1), (10, 1)), - ((13, 4), (0, 1), None,), + #((13, 4), (0, 1), None,), ((13, 4, 4), None, None), ((13, 4, 4), (20, 4, 1), (20, 4, 1)), - ((13, 4, 4), (4, 0, 1), None), + #((13, 4, 4), (4, 0, 1), None), ((16, 5632), None, None), ((16, 5632), (13312, 1), (13312, 1)), ((4, 4, 5632), None, None), From 903728e9bd7a99fd77f1054c64d7d2d75cab051a Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Sat, 12 Jul 2025 15:07:23 +0800 Subject: [PATCH 4/4] =?UTF-8?q?issue/207=20-=20=E6=B7=BB=E5=8A=A0stride?= =?UTF-8?q?=E4=B8=AD=E6=9C=890=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BE=8B?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/sigmoid/cuda/kernel.cuh | 2 +- test/infiniop/sigmoid.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/infiniop/ops/sigmoid/cuda/kernel.cuh b/src/infiniop/ops/sigmoid/cuda/kernel.cuh index 9094f0248..1ea7c2a02 100644 --- a/src/infiniop/ops/sigmoid/cuda/kernel.cuh +++ b/src/infiniop/ops/sigmoid/cuda/kernel.cuh @@ -19,7 +19,7 @@ public: half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x))); return hrcp(denominator); } else if constexpr (std::is_same_v) { - __nv_bfloat16 denominator = __hadd(__float2bfloat16(1.0f), __float2bfloat16(__expf(__bfloat162float(-x)))); + __nv_bfloat16 denominator = __float2bfloat16(__fadd_rn(1.0f, __expf(__bfloat162float(-x)))); return __float2bfloat16(1.0f) / denominator; } else if constexpr (std::is_same_v) { float denominator = __fadd_rn(1.0f, __expf(-x)); diff --git a/test/infiniop/sigmoid.py b/test/infiniop/sigmoid.py index 197c1b9fb..b8f896fa1 100644 --- a/test/infiniop/sigmoid.py +++ b/test/infiniop/sigmoid.py @@ -24,13 +24,13 @@ # ============================================================================== # These are not meant to be imported from other modules _TEST_CASES_ = [ - # shape, a_stride, b_stride, c_stride + # shape, x_stride, y_stride ((13, 4), None, None), ((13, 4), (10, 1), (10, 1)), - #((13, 4), (0, 1), None,), + ((13, 4), (0, 1), (0, 1)), ((13, 4, 4), None, None), ((13, 4, 4), (20, 4, 1), (20, 4, 1)), - #((13, 4, 4), (4, 0, 1), None), + ((13, 4, 4), (4, 0, 1), (4, 0, 1)), ((16, 5632), None, None), ((16, 5632), (13312, 1), (13312, 1)), ((4, 4, 5632), None, None),