12 changes: 12 additions & 0 deletions __init__.py
@@ -0,0 +1,12 @@
# python/infinicore/ops/softshrink/__init__.py
import torch
from .._C import softshrink as _softshrink

# "lambda" is a reserved word in Python, so the parameter is named "lambd",
# matching torch.nn.functional.softshrink.
def softshrink(input, lambd=0.5):
    # fp16 CUDA tensors use the custom kernel; everything else falls back to
    # the reference implementation, computed in fp32 and cast back.
    if input.is_cuda and input.dtype == torch.float16:
        return _softshrink(input, lambd)
    return torch.nn.functional.softshrink(input.float(), lambd).to(input.dtype)
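
A quick way to sanity-check the wrapper against the PyTorch reference (a minimal sketch, assuming the extension is built and the package is importable as infinicore.ops.softshrink, and that a CUDA device is available; the tolerances are illustrative):

import torch
from infinicore.ops.softshrink import softshrink  # assumed install path

x = torch.randn(1 << 20, device="cuda", dtype=torch.float16)
expected = torch.nn.functional.softshrink(x.float(), 0.5).half()
torch.testing.assert_close(softshrink(x, lambd=0.5), expected, rtol=1e-3, atol=1e-3)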
12 changes: 12 additions & 0 deletions softshrink.cpp
@@ -0,0 +1,12 @@
// python/infinicore/ops/softshrink/softshrink.cpp
#include <torch/extension.h>

// Defined in softshrink_kernel.cu.
torch::Tensor softshrink(const torch::Tensor& input, float lambd);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Expose the argument as "lambd": "lambda" is a Python keyword and
    // could not be passed by name from Python callers.
    m.def("softshrink", &softshrink, "Softshrink (fp16, CUDA)",
          py::arg("input"), py::arg("lambd") = 0.5f);
}
38 changes: 38 additions & 0 deletions softshrink_kernel.cu
@@ -0,0 +1,38 @@
// python/infinicore/ops/softshrink/softshrink_kernel.cu
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

// softshrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, else 0.
__global__ void softshrink_152x_kernel(const half* x, half* y, half lambda, int64_t n) {
    int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
    if (idx >= n) return;

    half val = x[idx];
    half zero = __float2half(0.0f);
    half pos = __hsub(val, lambda); // x - lambda, taken when x > lambda
    half neg = __hadd(val, lambda); // x + lambda, taken when x < -lambda

    y[idx] = __hgt(val, lambda) ? pos : (__hlt(val, __hneg(lambda)) ? neg : zero);
}

torch::Tensor softshrink(const torch::Tensor& input, float lambd = 0.5f) {
    TORCH_CHECK(input.is_cuda(), "softshrink: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat16, "softshrink: input must be float16");

    auto x = input.contiguous();
    auto out = torch::empty_like(x);
    half h_lambda = __float2half(lambd);

    int threads = 512;
    int blocks = static_cast<int>((x.numel() + threads - 1) / threads);

    // at::Half and half share the same bit layout, so the pointers can be
    // reinterpreted; the kernel runs on PyTorch's current CUDA stream.
    softshrink_152x_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const half*>(x.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        h_lambda,
        x.numel());
    return out;
}
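
The PR does not show the build wiring that produces the _C module imported above. For ad-hoc testing, the two sources could be JIT-compiled with torch.utils.cpp_extension.load (a sketch under the assumption that both files sit in the current directory; the module name and tensor size are illustrative):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the binding and the kernel into an importable module.
ext = load(name="softshrink_ext", sources=["softshrink.cpp", "softshrink_kernel.cu"], verbose=True)

x = torch.randn(4096, device="cuda", dtype=torch.float16)
y = ext.softshrink(x, lambd=0.25)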