From 2f8919f66048837e5e3340287926b71628622152 Mon Sep 17 00:00:00 2001
From: YU Qing <2961548487@qq.com>
Date: Thu, 11 Dec 2025 11:07:34 +0800
Subject: [PATCH] Add FP16 CUDA softshrink op

---
 __init__.py          | 10 ++++++++++
 softshrink.cpp       |  8 ++++++++
 softshrink_kernel.cu | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 softshrink.cpp
 create mode 100644 softshrink_kernel.cu

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 000000000..be6ad7d88
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,10 @@
+# python/infinicore/ops/softshrink/__init__.py
+import torch
+from .._C import softshrink as _softshrink
+
+
+def softshrink(input, lambd=0.5):
+    # Named "lambd" (not the reserved word "lambda"), matching torch.nn.functional.softshrink.
+    if input.is_cuda and input.dtype == torch.float16:
+        return _softshrink(input, lambd)
+    return torch.nn.functional.softshrink(input.float(), lambd).to(input.dtype)
diff --git a/softshrink.cpp b/softshrink.cpp
new file mode 100644
index 000000000..a80a76bef
--- /dev/null
+++ b/softshrink.cpp
@@ -0,0 +1,8 @@
+// python/infinicore/ops/softshrink/softshrink.cpp
+#include <torch/extension.h>
+
+torch::Tensor softshrink(const torch::Tensor& input, float lambda);  // defined in softshrink_kernel.cu
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("softshrink", &softshrink, "Softshrink for FP16 CUDA tensors", py::arg("input"), py::arg("lambd") = 0.5f);
+}
diff --git a/softshrink_kernel.cu b/softshrink_kernel.cu
new file mode 100644
index 000000000..09cf5083a
--- /dev/null
+++ b/softshrink_kernel.cu
@@ -0,0 +1,36 @@
+// python/infinicore/ops/softshrink/softshrink_kernel.cu
+#include <torch/extension.h>
+#include <cuda_fp16.h>
+
+// softshrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, else 0.
+__global__ void softshrink_152x_kernel(const half* x, half* y, half lambda, int64_t n) {
+    int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
+    if (idx >= n) return;
+
+    half val = x[idx];
+    half zero = __float2half(0.0f);
+    half pos = __hsub(val, lambda);
+    half neg = __hadd(val, lambda);
+
+    y[idx] = __hgt(val, lambda) ? pos : (__hlt(val, __hneg(lambda)) ? neg : zero);
+}
+
+torch::Tensor softshrink(const torch::Tensor& input, float lambda) {
+    TORCH_CHECK(input.is_cuda(), "softshrink: input must be a CUDA tensor");
+    TORCH_CHECK(input.scalar_type() == torch::kFloat16, "softshrink: input must be float16");
+
+    auto x = input.contiguous();
+    auto out = torch::empty_like(x);
+    half h_lambda = __float2half(lambda);
+
+    int64_t n = x.numel();
+    int threads = 512;
+    int blocks = (int)((n + threads - 1) / threads);
+
+    softshrink_152x_kernel<<<blocks, threads>>>(
+        reinterpret_cast<const half*>(x.data_ptr<at::Half>()),
+        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
+        h_lambda,
+        n);
+    return out;
+}
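
Reviewer note: a smoke test along these lines can compare the new op against the PyTorch
reference. This is a sketch, not part of the patch: the import path
infinicore.ops.softshrink is assumed from the file header comments, and the tolerance is
loose because the kernel adds/subtracts lambda in FP16 while the fallback computes in FP32.

    import torch
    from infinicore.ops.softshrink import softshrink  # path assumed from the header comments

    x = torch.randn(1 << 20, device="cuda", dtype=torch.float16)
    out = softshrink(x, 0.5)
    ref = torch.nn.functional.softshrink(x.float(), 0.5).to(torch.float16)
    # Loose tolerance: FP16 arithmetic in the kernel vs. FP32 in the reference path.
    assert torch.allclose(out.float(), ref.float(), atol=1e-3)
    print("max abs diff:", (out.float() - ref.float()).abs().max().item())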