7 changes: 7 additions & 0 deletions __init__.py
@@ -0,0 +1,7 @@
# python/infinicore/ops/softshrink/__init__.py
import torch
from .._C import softshrink as _softshrink
def softshrink(input, lambd=0.5):
    # "lambda" is a reserved word in Python, so follow torch's "lambd" naming.
    if input.is_cuda and input.dtype == torch.float16:
        return _softshrink(input, lambd)
    # Fallback: compute with PyTorch in float32, then cast back to the input dtype.
    return torch.nn.functional.softshrink(input.float(), lambd).to(input.dtype)
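A minimal usage sketch of the wrapper above, assuming the extension is built and the package is importable under the path shown in the file header (an assumption, not something this diff sets up); non-CUDA or non-fp16 inputs take the PyTorch fallback branch.

# Hypothetical usage; the import path is inferred from the file header above.
import torch
from infinicore.ops.softshrink import softshrink

x = torch.randn(1024, device="cuda", dtype=torch.float16)
y = softshrink(x, lambd=0.5)            # custom CUDA kernel path

x_cpu = torch.randn(1024)
y_cpu = softshrink(x_cpu, lambd=0.5)    # torch.nn.functional fallback path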
6 changes: 6 additions & 0 deletions softmin.cpp
@@ -0,0 +1,6 @@
// python/infinicore/ops/softmin/softmin.cpp
#include <torch/extension.h>
torch::Tensor softmin(const torch::Tensor& input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("softmin", &softmin);
}
48 changes: 48 additions & 0 deletions softmin_kernel.cu
@@ -0,0 +1,48 @@
// python/infinicore/ops/softmin/softmin_kernel.cu
#include <torch/extension.h>
#include <cuda_fp16.h>

// Row-wise softmin for a contiguous (rows, cols) fp16 matrix:
// softmin(x)_i = exp(-(x_i - min(x))) / sum_j exp(-(x_j - min(x))).
// One block of 256 threads handles one row; sdata holds per-thread partials.
__global__ void softmin_128x_kernel(const half* x, half* y, int64_t rows, int64_t cols) {
    extern __shared__ half sdata[];
    int tid = threadIdx.x;
    int row = blockIdx.x;
    if (row >= rows) return;
    x += row * cols;
    y += row * cols;

    // Pass 1: row minimum (stabilizes the exponentials).
    half thread_min = __float2half(65504.0f);  // largest finite half
    for (int i = tid; i < cols; i += 256) {
        thread_min = __hlt(x[i], thread_min) ? x[i] : thread_min;
    }
    sdata[tid] = thread_min;
    __syncthreads();
    for (int s = 128; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = __hlt(sdata[tid + s], sdata[tid]) ? sdata[tid + s] : sdata[tid];
        __syncthreads();
    }
    half row_min = sdata[0];
    __syncthreads();

    // Pass 2: sum of exp(row_min - x_i).
    half thread_sum = __float2half(0.0f);
    for (int i = tid; i < cols; i += 256) {
        thread_sum = __hadd(thread_sum, hexp(__hsub(row_min, x[i])));
    }
    sdata[tid] = thread_sum;
    __syncthreads();
    for (int s = 128; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = __hadd(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    half row_sum = sdata[0];

    // Pass 3: normalize.
    for (int i = tid; i < cols; i += 256) {
        y[i] = __hdiv(hexp(__hsub(row_min, x[i])), row_sum);
    }
}

torch::Tensor softmin(const torch::Tensor& input) {
    TORCH_CHECK(input.scalar_type() == torch::kFloat16);
    TORCH_CHECK(input.is_cuda() && input.is_contiguous() && input.dim() == 2);
    auto out = torch::empty_like(input);
    int blocks = input.size(0);
    softmin_128x_kernel<<<blocks, 256, 256 * sizeof(half)>>>(
        reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        input.size(0),
        input.size(1));
    return out;
}
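A quick numerical check of the kernel against PyTorch's reference softmin; softmin_ext is a placeholder for whatever module name the build produces.

# Compare the custom fp16 kernel with torch.nn.functional.softmin (float32 reference).
import torch
import torch.nn.functional as F
import softmin_ext  # placeholder module name

x = torch.randn(32, 1024, device="cuda", dtype=torch.float16).contiguous()
out = softmin_ext.softmin(x)
ref = F.softmin(x.float(), dim=1).half()
print((out - ref).abs().max())  # expect an fp16-sized error, not an exact match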
6 changes: 6 additions & 0 deletions softshrink.cpp
@@ -0,0 +1,6 @@
// python/infinicore/ops/softshrink/softshrink.cpp
#include <torch/extension.h>
torch::Tensor softshrink(const torch::Tensor& input, float lambd);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("softshrink", &softshrink, "softshrink", py::arg("input"), py::arg("lambd")=0.5f);
}
31 changes: 31 additions & 0 deletions softshrink_kernel.cu
@@ -0,0 +1,31 @@
// python/infinicore/ops/softshrink/softshrink_kernel.cu
#include <torch/extension.h>
#include <cuda_fp16.h>

// Elementwise softshrink: y = x - lambd if x > lambd, y = x + lambd if x < -lambd, else 0.
__global__ void softshrink_152x_kernel(const half* x, half* y, half lambd, int64_t n) {
    int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;

    half val = x[idx];
    half zero = __float2half(0.0f);
    half pos = __hsub(val, lambd);
    half neg = __hadd(val, lambd);

    // Use intrinsics instead of overloaded half operators so the kernel also
    // builds when __CUDA_NO_HALF_OPERATORS__ is defined.
    y[idx] = __hgt(val, lambd) ? pos : (__hlt(val, __hneg(lambd)) ? neg : zero);
}

torch::Tensor softshrink(const torch::Tensor& input, float lambd) {
    TORCH_CHECK(input.scalar_type() == torch::kFloat16);
    TORCH_CHECK(input.is_cuda() && input.is_contiguous());
    auto out = torch::empty_like(input);
    half h_lambd = __float2half(lambd);

    int threads = 512;
    int blocks = (int)((input.numel() + threads - 1) / threads);

    softshrink_152x_kernel<<<blocks, threads>>>(
        reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        h_lambd,
        input.numel()
    );
    return out;
}
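One way to build and smoke-test the softshrink extension locally is torch.utils.cpp_extension.load; this is only a sketch, and the repository's actual build integration may wire these sources up differently.

# JIT-build the extension from the two source files added in this PR and
# compare against PyTorch's reference softshrink.
import torch
from torch.utils.cpp_extension import load

softshrink_ext = load(
    name="softshrink_ext",  # placeholder module name
    sources=["softshrink.cpp", "softshrink_kernel.cu"],
    verbose=True,
)

x = torch.randn(4096, device="cuda", dtype=torch.float16)
y = softshrink_ext.softshrink(x, lambd=0.25)
ref = torch.nn.functional.softshrink(x.float(), lambd=0.25).half()
print((y - ref).abs().max())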
7 changes: 7 additions & 0 deletions split.cpp
@@ -0,0 +1,7 @@
// python/infinicore/ops/split/split.cpp
#include <torch/extension.h>
torch::Tensor split_cuda(const torch::Tensor& input, int64_t split_size, int64_t dim);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("split", &split_cuda, "split", py::arg("input"), py::arg("split_size"), py::arg("dim")=0);
}
46 changes: 46 additions & 0 deletions split_kernel.cu
@@ -0,0 +1,46 @@
// python/infinicore/ops/split/split_kernel.cu
#include <torch/extension.h>
#include <cuda_fp16.h>

// Copies a contiguous fp16 tensor into a {num_splits, split_size, inner} output.
// Both tensors are contiguous with identical element order, so the copy is an
// identity mapping; positions past the end of the input (the trailing, partial
// split) are zero-filled.
extern "C" __global__ void split_168x_kernel(
    const half* __restrict__ input,
    half* __restrict__ output,
    int64_t input_elements,
    int64_t total_elements
) {
    int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_elements) return;

    output[idx] = (idx < input_elements) ? input[idx] : __float2half(0.0f);
}

torch::Tensor split_cuda(const torch::Tensor& input, int64_t split_size, int64_t dim) {
    TORCH_CHECK(dim == 0, "split only supports dim=0 for max speed");
    TORCH_CHECK(input.scalar_type() == torch::kFloat16);
    TORCH_CHECK(input.is_cuda() && input.is_contiguous());

    int64_t outer = input.size(0);
    int64_t inner = input.numel() / outer;
    int64_t num_splits = (outer + split_size - 1) / split_size;

    // Unlike torch.split, this returns a single stacked tensor; the last split is
    // zero-padded when outer is not divisible by split_size.
    auto output = torch::empty({num_splits, split_size, inner}, input.options());

    int64_t total_elements = num_splits * split_size * inner;
    int threads = 512;
    int blocks = (int)((total_elements + threads - 1) / threads);

    split_168x_kernel<<<blocks, threads>>>(
        reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
        reinterpret_cast<half*>(output.data_ptr<at::Half>()),
        input.numel(),
        total_elements
    );

    return output;
}
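Note that this op returns a single stacked {num_splits, split_size, inner} tensor rather than the list of views torch.split produces; the sketch below shows how the two line up when size(0) divides evenly (split_ext is a placeholder module name).

import torch
import split_ext  # placeholder module name

x = torch.randn(12, 64, device="cuda", dtype=torch.float16).contiguous()
stacked = split_ext.split(x, split_size=4, dim=0)   # shape (3, 4, 64)
chunks = torch.split(x, 4, dim=0)                   # tuple of 3 views
for i, chunk in enumerate(chunks):
    assert torch.equal(stacked[i], chunk)
# When size(0) is not divisible by split_size, the last slice is zero-padded
# here, whereas torch.split returns a shorter final chunk.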
6 changes: 6 additions & 0 deletions std.cpp
@@ -0,0 +1,6 @@
// python/infinicore/ops/std/std.cpp
#include <torch/extension.h>
// Named std_cuda to avoid clashing with the std namespace at global scope.
torch::Tensor std_cuda(const torch::Tensor& input, int64_t dim, bool unbiased, bool keepdim);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("std", &std_cuda, "std", py::arg("input"), py::arg("dim")=-1, py::arg("unbiased")=true, py::arg("keepdim")=false);
}
52 changes: 52 additions & 0 deletions std_kernel.cu
@@ -0,0 +1,52 @@
// python/infinicore/ops/std/std_kernel.cu
#include <torch/extension.h>
#include <cuda_fp16.h>

// Row-wise standard deviation over dim=1 of a contiguous (rows, cols) fp16 matrix.
// One block of 256 threads per row; sdata holds per-thread partial sums.
__global__ void std_142x_kernel(const half* __restrict__ x, half* __restrict__ out,
                                int64_t rows, int64_t cols, int64_t denom) {
    extern __shared__ half sdata[];
    int tid = threadIdx.x;
    int row = blockIdx.x;
    if (row >= rows) return;

    // Pass 1: mean.
    half sum = __float2half(0.0f);
    for (int i = tid; i < cols; i += 256) {
        sum = __hadd(sum, x[row * cols + i]);
    }
    sdata[tid] = sum;
    __syncthreads();
    for (int s = 128; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = __hadd(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    half mean = __hdiv(sdata[0], __int2half_rn((int)cols));
    __syncthreads();

    // Pass 2: variance, divided by denom (cols - 1 for the unbiased estimate).
    half var = __float2half(0.0f);
    for (int i = tid; i < cols; i += 256) {
        half diff = __hsub(x[row * cols + i], mean);
        var = __hadd(var, __hmul(diff, diff));
    }
    sdata[tid] = var;
    __syncthreads();
    for (int s = 128; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = __hadd(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    half std_val = hsqrt(__hdiv(sdata[0], __int2half_rn((int)denom)));

    if (tid == 0) out[row] = std_val;
}

torch::Tensor std_cuda(const torch::Tensor& input, int64_t dim, bool unbiased, bool keepdim) {
    TORCH_CHECK(dim == 1 || dim == -1, "only dim=1 supported for max speed");
    TORCH_CHECK(input.scalar_type() == torch::kFloat16);
    TORCH_CHECK(input.is_cuda() && input.is_contiguous() && input.dim() == 2);
    auto out = torch::empty({input.size(0)}, input.options());

    int64_t cols = input.size(1);
    int64_t denom = unbiased ? cols - 1 : cols;  // Bessel's correction when unbiased

    std_142x_kernel<<<input.size(0), 256, 256 * sizeof(half)>>>(
        reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
        reinterpret_cast<half*>(out.data_ptr<at::Half>()),
        input.size(0),
        cols,
        denom
    );
    return keepdim ? out.unsqueeze(1) : out;
}
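A sanity check against torch.std along dim=1; std_ext is a placeholder module name, and the tolerance has to be loose because the kernel accumulates in fp16.

import torch
import std_ext  # placeholder module name

x = torch.randn(16, 256, device="cuda", dtype=torch.float16).contiguous()
out = std_ext.std(x, dim=1, unbiased=True)
ref = x.float().std(dim=1, unbiased=True).half()
print((out - ref).abs().max())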
6 changes: 6 additions & 0 deletions std_mean.cpp
@@ -0,0 +1,6 @@
// python/infinicore/ops/std_mean/std_mean.cpp
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor> std_mean(const torch::Tensor& input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("std_mean", &std_mean);
}
61 changes: 61 additions & 0 deletions std_mean_kernel.cu
@@ -0,0 +1,61 @@
// python/infinicore/ops/std_mean/std_mean_kernel.cu
#include <torch/extension.h>
#include <cuda_fp16.h>

// Row-wise (mean, std) over dim=1 of a contiguous (rows, cols) fp16 matrix.
// One block of 256 threads per row; sdata holds per-thread partial sums.
// (The shared array is named sdata so it does not collide with the reduction
// loop counter s.)
__global__ void std_mean_158x_kernel(
    const half* __restrict__ x,
    half* __restrict__ mean_out,
    half* __restrict__ std_out,
    int64_t rows,
    int64_t cols
) {
    extern __shared__ half sdata[];
    int tid = threadIdx.x;
    int row = blockIdx.x;
    if (row >= rows) return;

    // Pass 1: sum (for the mean).
    half sum = __float2half(0.0f);
    for (int i = tid; i < cols; i += 256) {
        sum = __hadd(sum, x[row * cols + i]);
    }
    sdata[tid] = sum;
    __syncthreads();
    for (int s = 128; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = __hadd(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    half mean = __hdiv(sdata[0], __int2half_rn((int)cols));
    if (tid == 0) mean_out[row] = mean;
    __syncthreads();

    // Pass 2: variance (reusing the mean), divided by cols (biased estimate).
    half var = __float2half(0.0f);
    for (int i = tid; i < cols; i += 256) {
        half diff = __hsub(x[row * cols + i], mean);
        var = __hadd(var, __hmul(diff, diff));
    }
    sdata[tid] = var;
    __syncthreads();
    for (int s = 128; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = __hadd(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    half std_val = hsqrt(__hdiv(sdata[0], __int2half_rn((int)cols)));

    if (tid == 0) std_out[row] = std_val;
}

std::tuple<torch::Tensor, torch::Tensor> std_mean(const torch::Tensor& input) {
    TORCH_CHECK(input.scalar_type() == torch::kFloat16);
    TORCH_CHECK(input.is_cuda() && input.is_contiguous() && input.dim() == 2);
    auto mean = torch::empty({input.size(0)}, input.options());
    auto stdv = torch::empty({input.size(0)}, input.options());

    std_mean_158x_kernel<<<input.size(0), 256, 256 * sizeof(half)>>>(
        reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
        reinterpret_cast<half*>(mean.data_ptr<at::Half>()),
        reinterpret_cast<half*>(stdv.data_ptr<at::Half>()),
        input.size(0),
        input.size(1)
    );

    return {mean, stdv};
}
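Because std_mean accumulates in fp16 and divides the variance by N (the biased estimate), results drift from a float32 reference as rows grow; a quick way to gauge the error, with std_mean_ext as a placeholder module name:

import torch
import std_mean_ext  # placeholder module name

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16).contiguous()
mean_c, std_c = std_mean_ext.std_mean(x)   # the binding returns (mean, std)
ref_std, ref_mean = torch.std_mean(x.float(), dim=1, unbiased=False)
print((mean_c - ref_mean.half()).abs().max(), (std_c - ref_std.half()).abs().max())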