diff --git a/csrc/ops/block_ops.cc b/csrc/ops/block_ops.cc
index 9388f72..9870402 100644
--- a/csrc/ops/block_ops.cc
+++ b/csrc/ops/block_ops.cc
@@ -1,8 +1,37 @@
 #include "block_ops.h"
+#include "compare_ops.h"
 
 namespace recis {
 namespace functional {
 
+template <CompareOp op, typename T>
+struct CompareFactory;
+
+template <typename T>
+struct CompareFactory<CompareOp::LT, T> {
+  bool operator()(const T a, const T b) const { return a < b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::LE, T> {
+  bool operator()(const T a, const T b) const { return a <= b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::GT, T> {
+  bool operator()(const T a, const T b) const { return a > b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::GE, T> {
+  bool operator()(const T a, const T b) const { return a >= b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::EQ, T> {
+  bool operator()(const T a, const T b) const { return a == b; }
+};
+
 template <typename TEmb>
 struct GatherFunctor {
   GatherFunctor(int64_t embedding_dim, const int64_t *index_vec, TEmb *emb_vec,
@@ -225,16 +254,18 @@ torch::Tensor block_gather_cpu_kernel(const torch::Tensor ids,
   return output;
 }
 
-template <typename TEmb>
+template <typename TEmb, typename Factory>
 struct BlocksFilterFunctor {
   BlocksFilterFunctor(TEmb threshold, int64_t block_size,
                       const int64_t *index_vec,
-                      std::vector<torch::Tensor> &emb_blocks, bool *mask_out)
+                      std::vector<torch::Tensor> &emb_blocks, bool *mask_out,
+                      Factory factory)
       : threshold_(threshold),
         block_size_(block_size),
         index_vec_(index_vec),
         emb_blocks_(emb_blocks),
-        mask_out_(mask_out) {}
+        mask_out_(mask_out),
+        factory_(factory) {}
   void operator()(const int64_t beg, const int64_t end) const {
     for (auto src_index : c10::irange(beg, end)) {
       auto dst_index = index_vec_[src_index];
@@ -242,7 +273,7 @@ struct BlocksFilterFunctor {
       auto dst_row_index = dst_index % block_size_;
       TEmb filter_val =
           emb_blocks_[dst_block_index].data_ptr<TEmb>()[dst_row_index];
-      bool to_filter = (filter_val < threshold_);
+      bool to_filter = factory_(filter_val, threshold_);
       mask_out_[src_index] = to_filter;
     }
   }
@@ -253,11 +284,13 @@ struct BlocksFilterFunctor {
   const int64_t *index_vec_;
   std::vector<torch::Tensor> &emb_blocks_;
   bool *mask_out_;
+  Factory factory_;
 };
 
 torch::Tensor block_filter_cpu_kernel(const torch::Tensor ids,
                                       std::vector<torch::Tensor> &emb_blocks,
-                                      int64_t threshold, int64_t block_size) {
+                                      int64_t threshold, int64_t block_size,
+                                      int64_t compare_op) {
   TORCH_CHECK(ids.device().is_cpu(), "CPU version requires CPU input");
   TORCH_CHECK(emb_blocks.size() == 0 || emb_blocks[0].device().is_cpu(),
               "CPU version requires CPU emb_blocks");
@@ -268,11 +301,16 @@ torch::Tensor block_filter_cpu_kernel(const torch::Tensor ids,
   if (num_ids == 0) return output;
 
   AT_DISPATCH_INDEX_TYPES(
-      emb_blocks[0].scalar_type(), "block_filter_cuda_impl", ([&] {
-        BlocksFilterFunctor<index_t> filter_functor(
-            static_cast<index_t>(threshold), block_size,
-            ids.data_ptr<int64_t>(), emb_blocks, output.data_ptr<bool>());
-        at::parallel_for(0, ids.numel(), 0, filter_functor);
+      emb_blocks[0].scalar_type(), "block_filter_cpu_impl", ([&] {
+        BF_DISPATCH_COMPARE_TYPES(
+            compare_op, ([&] {
+              using Factory = CompareFactory<compare_t, index_t>;
+              BlocksFilterFunctor<index_t, Factory> filter_functor(
+                  static_cast<index_t>(threshold), block_size,
+                  ids.data_ptr<int64_t>(), emb_blocks, output.data_ptr<bool>(),
+                  Factory());
+              at::parallel_for(0, ids.numel(), 0, filter_functor);
+            }));
       }));
   return output;
 }
@@ -300,14 +338,16 @@ torch::Tensor block_gather(const torch::Tensor ids,
 
 torch::Tensor block_filter(const torch::Tensor ids,
                            std::vector<torch::Tensor> emb_blocks,
-                           int64_t block_size, int64_t threshold) {
+                           int64_t block_size, int64_t threshold,
+                           int64_t compare_op) {
   if (ids.device().type() == torch::kCUDA) {
-    return block_filter_cuda(ids, emb_blocks, threshold, block_size);
+    return block_filter_cuda(ids, emb_blocks, threshold, block_size, compare_op);
   } else {
-    return block_filter_cpu_kernel(ids, emb_blocks, threshold, block_size);
+    return block_filter_cpu_kernel(ids, emb_blocks, threshold, block_size, compare_op);
   }
 }
 
+
 torch::Tensor block_gather_by_range(const torch::Tensor ids,
                                     std::vector<torch::Tensor> emb_blocks,
                                     int64_t block_size, int64_t beg,
@@ -335,7 +375,6 @@ void block_insert_cpu_kernel(const torch::Tensor ids,
   auto num_ids = ids.numel();
   if (num_ids == 0) return;
   int64_t embedding_dim = embedding_blocks[0].size(1);
-  auto ids_data = ids.data_ptr<int64_t>();
   AT_DISPATCH_ALL_TYPES_AND(
       at::ScalarType::Half, embedding.scalar_type(), "block_insert_cpu_impl",
       ([&] {
@@ -376,7 +415,6 @@ void block_insert_with_mask_cpu_kernel(
   auto num_ids = ids.numel();
   if (num_ids == 0) return;
   int64_t embedding_dim = embedding_blocks[0].size(1);
-  auto ids_data = ids.data_ptr<int64_t>();
   AT_DISPATCH_ALL_TYPES_AND(
       at::ScalarType::Half, embedding.scalar_type(),
       "block_insert_with_mask_cpu_impl", ([&] {
diff --git a/csrc/ops/block_ops.h b/csrc/ops/block_ops.h
index 8dab08a..df85a71 100644
--- a/csrc/ops/block_ops.h
+++ b/csrc/ops/block_ops.h
@@ -43,12 +43,14 @@ torch::Tensor gather_cuda(const torch::Tensor ids, const torch::Tensor emb);
 
 torch::Tensor block_filter_cuda(const torch::Tensor ids,
                                 std::vector<torch::Tensor>& emb_blocks,
-                                int64_t block_size, int64_t step);
+                                int64_t block_size, int64_t step,
+                                int64_t compare_op);
 
 torch::Tensor block_filter(const torch::Tensor ids,
                            std::vector<torch::Tensor> emb_blocks,
                            // const torch::Tensor mask,
-                           int64_t block_size, int64_t step);
+                           int64_t block_size, int64_t step,
+                           int64_t compare_op);
 
 }  // namespace functional
 }  // namespace recis
diff --git a/csrc/ops/block_ops_kernel.cu b/csrc/ops/block_ops_kernel.cu
index 3c18f36..63c20bf 100644
--- a/csrc/ops/block_ops_kernel.cu
+++ b/csrc/ops/block_ops_kernel.cu
@@ -8,10 +8,39 @@
 
 #include "cuda/cuda_param.cuh"
 #include "cuda/utils.cuh"
+#include "compare_ops.h"
 
 namespace recis {
 namespace functional {
 
+template <CompareOp op, typename T>
+struct CompareFactory;
+
+template <typename T>
+struct CompareFactory<CompareOp::LT, T> {
+  __host__ __device__ bool operator()(const T a, const T b) const { return a < b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::LE, T> {
+  __host__ __device__ bool operator()(const T a, const T b) const { return a <= b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::GT, T> {
+  __host__ __device__ bool operator()(const T a, const T b) const { return a > b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::GE, T> {
+  __host__ __device__ bool operator()(const T a, const T b) const { return a >= b; }
+};
+
+template <typename T>
+struct CompareFactory<CompareOp::EQ, T> {
+  __host__ __device__ bool operator()(const T a, const T b) const { return a == b; }
+};
+
 template <typename scalar_t>
 __global__ void gather_cuda_kernel(const int64_t* ids, const scalar_t* emb,
                                    int64_t num_ids, int64_t embedding_dim,
@@ -276,11 +305,11 @@ torch::Tensor block_gather_cuda(const torch::Tensor ids,
   return output;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename Factory>
 __global__ void block_filter_cuda_kernel(
     const int64_t* __restrict__ ids, const scalar_t** __restrict__ emb_blocks,
     int64_t num_ids, int64_t block_size, scalar_t threshold,
-    bool* __restrict__ output) {
+    bool* __restrict__ output, Factory factory) {
   int64_t did = blockIdx.x * blockDim.x + threadIdx.x;
   if (did >= num_ids) {
     return;
@@ -289,27 +318,28 @@ __global__ void block_filter_cuda_kernel(
   int64_t block_index = id / block_size;
   int64_t row_index = id % block_size;
   scalar_t ft_val = emb_blocks[block_index][row_index];
-  bool expired = (ft_val < threshold);
+  bool expired = factory(ft_val, threshold);
   output[did] = expired;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename Factory>
 void block_filter_kernel_launcher(const int64_t* ids,
                                   const scalar_t** emb_blocks, int64_t num_ids,
                                   int64_t block_size, scalar_t threshold,
-                                  bool* output) {
+                                  bool* output, Factory factory) {
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   static const int filter_block_size = 128;
   dim3 grids((num_ids + filter_block_size - 1) / filter_block_size);
   dim3 blocks(filter_block_size);
-  block_filter_cuda_kernel<scalar_t><<<grids, blocks, 0, stream>>>(
-      ids, emb_blocks, num_ids, block_size, threshold, output);
+  block_filter_cuda_kernel<scalar_t, Factory><<<grids, blocks, 0, stream>>>(
+      ids, emb_blocks, num_ids, block_size, threshold, output, factory);
 }
 
 torch::Tensor block_filter_cuda(const torch::Tensor ids,
                                 std::vector<torch::Tensor>& emb_blocks,
-                                int64_t threshold, int64_t block_size) {
+                                int64_t threshold, int64_t block_size,
+                                int64_t compare_op) {
   TORCH_CHECK(ids.is_cuda(), "Input must be on CUDA device");
   TORCH_CHECK(emb_blocks.size() == 0 || emb_blocks[0].is_cuda(),
               "Embedding must be on CUDA device");
@@ -324,15 +354,19 @@ torch::Tensor block_filter_cuda(const torch::Tensor ids,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_INDEX_TYPES(
       emb_blocks[0].scalar_type(), "block_filter_cuda_impl", ([&] {
-        recis::cuda::CudaVecParam<index_t*> emb_blocks_ptrs(block_num, stream);
-        for (auto i = 0; i < block_num; ++i) {
-          emb_blocks_ptrs[i] = emb_blocks[i].data_ptr<index_t>();
-        }
-        block_filter_kernel_launcher<index_t>(
-            ids.data_ptr<int64_t>(),
-            const_cast<const index_t**>(emb_blocks_ptrs.data()), num_ids,
-            block_size, static_cast<index_t>(threshold),
-            output.data_ptr<bool>());
+        BF_DISPATCH_COMPARE_TYPES(
+            compare_op, ([&] {
+              using Factory = CompareFactory<compare_t, index_t>;
+              recis::cuda::CudaVecParam<index_t*> emb_blocks_ptrs(block_num, stream);
+              for (auto i = 0; i < block_num; ++i) {
+                emb_blocks_ptrs[i] = emb_blocks[i].data_ptr<index_t>();
+              }
+              block_filter_kernel_launcher<index_t, Factory>(
+                  ids.data_ptr<int64_t>(),
+                  const_cast<const index_t**>(emb_blocks_ptrs.data()), num_ids,
+                  block_size, static_cast<index_t>(threshold),
+                  output.data_ptr<bool>(), Factory());
+            }));
       }));
   cudaError_t err = cudaGetLastError();
   TORCH_CHECK(cudaSuccess == err, cudaGetErrorString(err));
diff --git a/csrc/ops/compare_ops.h b/csrc/ops/compare_ops.h
new file mode 100644
index 0000000..12c2789
--- /dev/null
+++ b/csrc/ops/compare_ops.h
@@ -0,0 +1,43 @@
+#pragma once
+
+namespace recis {
+namespace functional {
+
+enum CompareOp : int
+{
+  LT,
+  LE,
+  GT,
+  GE,
+  EQ,
+};
+
+#define BF_DISPATCH_CASE(enum_type, ...)        \
+  case enum_type: {                             \
+    constexpr auto compare_t = enum_type;       \
+    __VA_ARGS__();                              \
+    break;                                      \
+  }
+
+#define BF_DISPATCH_CASE_COMPARE_TYPES(...)     \
+  BF_DISPATCH_CASE(CompareOp::LT, __VA_ARGS__)  \
+  BF_DISPATCH_CASE(CompareOp::LE, __VA_ARGS__)  \
+  BF_DISPATCH_CASE(CompareOp::GT, __VA_ARGS__)  \
+  BF_DISPATCH_CASE(CompareOp::GE, __VA_ARGS__)  \
+  BF_DISPATCH_CASE(CompareOp::EQ, __VA_ARGS__)
+
+#define BF_DISPATCH_SWITCH(TYPE, ...)           \
+  switch (TYPE) {                               \
+    __VA_ARGS__                                 \
+    default:                                    \
+      TORCH_CHECK_NOT_IMPLEMENTED(              \
+          false,                                \
+          "Invalid TYPE: ",                     \
+          std::to_string(TYPE));                \
+  }
+
+#define BF_DISPATCH_COMPARE_TYPES(TYPE, ...)    \
+  BF_DISPATCH_SWITCH(TYPE, BF_DISPATCH_CASE_COMPARE_TYPES(__VA_ARGS__))
+
+}  // namespace functional
+}  // namespace recis
diff --git a/recis/nn/modules/hashtable_hook_impl.py b/recis/nn/modules/hashtable_hook_impl.py
index 16c6d66..e9c3348 100644
--- a/recis/nn/modules/hashtable_hook_impl.py
+++ b/recis/nn/modules/hashtable_hook_impl.py
@@ -6,6 +6,7 @@ from recis.common.singleton import SingletonMeta
 from recis.nn.hashtable_hook import AdmitHook, FilterHook
 from recis.utils.logger import Logger
+from recis.utils.data_utils import CompareOp
 
 logger = Logger(__name__)
 
 
@@ -547,6 +548,7 @@ def filter(
             filter_values,
             filter_slot.block_size(),
             self._global_step - self._filter_step,
+            CompareOp.LT,
         )
         return torch.masked_select(ids, mask), torch.masked_select(index, mask)
 
diff --git a/recis/utils/data_utils.py b/recis/utils/data_utils.py
index 61d64f0..f0b5983 100644
--- a/recis/utils/data_utils.py
+++ b/recis/utils/data_utils.py
@@ -1,5 +1,16 @@
 import collections
 from dataclasses import fields, is_dataclass
+from enum import IntEnum
+
+
+# NOTE: values must stay in sync with `enum CompareOp` in csrc/ops/compare_ops.h
+# (LT=0 ... EQ=4). Do NOT use auto() here: IntEnum auto() starts at 1, which
+# would shift every op by one on the C++ side (LT would dispatch as LE).
+class CompareOp(IntEnum):
+    LT = 0
+    LE = 1
+    GT = 2
+    GE = 3
+    EQ = 4
 
 
 def copy_data_to_device(data, device, *args, **kwargs):