Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 53 additions & 15 deletions csrc/ops/block_ops.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
#include "block_ops.h"
#include "compare_ops.h"

namespace recis {
namespace functional {

// Compile-time mapping from a CompareOp value to a comparator functor.
// The primary template is intentionally left undefined so that an
// unsupported CompareOp fails at compile time rather than at runtime.
template <typename T, CompareOp op>
struct CompareFactory;

// CompareOp::LT: true when lhs is strictly less than rhs.
template <typename T>
struct CompareFactory<T, CompareOp::LT> {
  bool operator()(const T lhs, const T rhs) const { return lhs < rhs; }
};

// CompareOp::LE: true when lhs is less than or equal to rhs.
template <typename T>
struct CompareFactory<T, CompareOp::LE> {
  bool operator()(const T lhs, const T rhs) const { return lhs <= rhs; }
};

// CompareOp::GT: true when lhs is strictly greater than rhs.
template <typename T>
struct CompareFactory<T, CompareOp::GT> {
  bool operator()(const T lhs, const T rhs) const { return lhs > rhs; }
};

// CompareOp::GE: true when lhs is greater than or equal to rhs.
template <typename T>
struct CompareFactory<T, CompareOp::GE> {
  bool operator()(const T lhs, const T rhs) const { return lhs >= rhs; }
};

// CompareOp::EQ: true when lhs equals rhs.
template <typename T>
struct CompareFactory<T, CompareOp::EQ> {
  bool operator()(const T lhs, const T rhs) const { return lhs == rhs; }
};

template <class TEmb>
struct GatherFunctor {
GatherFunctor(int64_t embedding_dim, const int64_t *index_vec, TEmb *emb_vec,
Expand Down Expand Up @@ -225,24 +254,26 @@ torch::Tensor block_gather_cpu_kernel(const torch::Tensor ids,
return output;
}

template <class TEmb>
template <class TEmb, typename Factory>
struct BlocksFilterFunctor {
BlocksFilterFunctor(TEmb threshold, int64_t block_size,
const int64_t *index_vec,
std::vector<torch::Tensor> &emb_blocks, bool *mask_out)
std::vector<torch::Tensor> &emb_blocks, bool *mask_out,
Factory factory)
: threshold_(threshold),
block_size_(block_size),
index_vec_(index_vec),
emb_blocks_(emb_blocks),
mask_out_(mask_out) {}
mask_out_(mask_out),
factory_(factory) {}
void operator()(const int64_t beg, const int64_t end) const {
for (auto src_index : c10::irange(beg, end)) {
auto dst_index = index_vec_[src_index];
auto dst_block_index = dst_index / block_size_;
auto dst_row_index = dst_index % block_size_;
TEmb filter_val =
emb_blocks_[dst_block_index].data_ptr<TEmb>()[dst_row_index];
bool to_filter = (filter_val < threshold_);
bool to_filter = factory_(filter_val, threshold_);
mask_out_[src_index] = to_filter;
}
}
Expand All @@ -253,11 +284,13 @@ struct BlocksFilterFunctor {
const int64_t *index_vec_;
std::vector<torch::Tensor> &emb_blocks_;
bool *mask_out_;
Factory factory_;
};

torch::Tensor block_filter_cpu_kernel(const torch::Tensor ids,
std::vector<torch::Tensor> &emb_blocks,
int64_t threshold, int64_t block_size) {
int64_t threshold, int64_t block_size,
int64_t compare_op) {
TORCH_CHECK(ids.device().is_cpu(), "CPU version requires CPU input");
TORCH_CHECK(emb_blocks.size() == 0 || emb_blocks[0].device().is_cpu(),
"CPU version requires CPU emb_blocks");
Expand All @@ -268,11 +301,16 @@ torch::Tensor block_filter_cpu_kernel(const torch::Tensor ids,
if (num_ids == 0) return output;

AT_DISPATCH_INDEX_TYPES(
emb_blocks[0].scalar_type(), "block_filter_cuda_impl", ([&] {
BlocksFilterFunctor<index_t> filter_functor(
static_cast<index_t>(threshold), block_size,
ids.data_ptr<int64_t>(), emb_blocks, output.data_ptr<bool>());
at::parallel_for(0, ids.numel(), 0, filter_functor);
emb_blocks[0].scalar_type(), "block_filter_cpu_impl", ([&] {
BF_DISPATCH_COMPARE_TYPES(
compare_op, ([&] {
using Factory = CompareFactory<index_t, compare_t>;
BlocksFilterFunctor<index_t, Factory> filter_functor(
static_cast<index_t>(threshold), block_size,
ids.data_ptr<int64_t>(), emb_blocks, output.data_ptr<bool>(),
Factory());
at::parallel_for(0, ids.numel(), 0, filter_functor);
}));
}));
return output;
}
Expand Down Expand Up @@ -300,14 +338,16 @@ torch::Tensor block_gather(const torch::Tensor ids,

// Device-dispatching entry point for block filtering.
//
// Routes to the CUDA or CPU kernel based on the device of `ids`.
// `compare_op` selects the comparator (see recis::functional::CompareOp)
// used to compare each stored value against `threshold`; the result is
// a bool mask tensor with one entry per id.
torch::Tensor block_filter(const torch::Tensor ids,
                           std::vector<torch::Tensor> emb_blocks,
                           int64_t block_size, int64_t threshold,
                           int64_t compare_op) {
  const bool on_cuda = (ids.device().type() == torch::kCUDA);
  if (on_cuda) {
    return block_filter_cuda(ids, emb_blocks, threshold, block_size,
                             compare_op);
  }
  return block_filter_cpu_kernel(ids, emb_blocks, threshold, block_size,
                                 compare_op);
}


torch::Tensor block_gather_by_range(const torch::Tensor ids,
std::vector<torch::Tensor> emb_blocks,
int64_t block_size, int64_t beg,
Expand Down Expand Up @@ -335,7 +375,6 @@ void block_insert_cpu_kernel(const torch::Tensor ids,
auto num_ids = ids.numel();
if (num_ids == 0) return;
int64_t embedding_dim = embedding_blocks[0].size(1);
auto ids_data = ids.data_ptr<int64_t>();
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, embedding.scalar_type(), "block_insert_cpu_impl",
([&] {
Expand Down Expand Up @@ -376,7 +415,6 @@ void block_insert_with_mask_cpu_kernel(
auto num_ids = ids.numel();
if (num_ids == 0) return;
int64_t embedding_dim = embedding_blocks[0].size(1);
auto ids_data = ids.data_ptr<int64_t>();
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, embedding.scalar_type(),
"block_insert_with_mask_cpu_impl", ([&] {
Expand Down
6 changes: 4 additions & 2 deletions csrc/ops/block_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ torch::Tensor gather_cuda(const torch::Tensor ids, const torch::Tensor emb);

// Launches the CUDA block-filter kernel: compares the value stored for
// each id against `threshold` using the comparator selected by
// `compare_op` (see recis::functional::CompareOp) and returns a bool
// mask with one entry per id.
// NOTE(review): the parameter names here previously read
// (block_size, step), but the definition in block_ops_kernel.cu takes
// (threshold, block_size) in these positions; the names below are
// aligned with the definition to avoid misleading callers.
torch::Tensor block_filter_cuda(const torch::Tensor ids,
                                std::vector<torch::Tensor>& emb_blocks,
                                int64_t threshold, int64_t block_size,
                                int64_t compare_op);

// Device-dispatching entry point: routes to the CUDA or CPU kernel
// based on the device of `ids`. Note the argument order here is
// (block_size, threshold) — the reverse of block_filter_cuda above.
torch::Tensor block_filter(const torch::Tensor ids,
                           std::vector<torch::Tensor> emb_blocks,
                           int64_t block_size, int64_t threshold,
                           int64_t compare_op);

} // namespace functional
} // namespace recis
68 changes: 51 additions & 17 deletions csrc/ops/block_ops_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,39 @@

#include "cuda/cuda_param.cuh"
#include "cuda/utils.cuh"
#include "compare_ops.h"

namespace recis {
namespace functional {

// Compile-time mapping from a CompareOp value to a comparator functor,
// callable from both host and device code. The primary template is
// intentionally left undefined so that an unsupported CompareOp fails
// at compile time.
template <typename T, CompareOp op>
struct CompareFactory;

// CompareOp::LT: true when lhs is strictly less than rhs.
template <typename T>
struct CompareFactory<T, CompareOp::LT> {
  __host__ __device__ bool operator()(const T lhs, const T rhs) const {
    return lhs < rhs;
  }
};

// CompareOp::LE: true when lhs is less than or equal to rhs.
template <typename T>
struct CompareFactory<T, CompareOp::LE> {
  __host__ __device__ bool operator()(const T lhs, const T rhs) const {
    return lhs <= rhs;
  }
};

// CompareOp::GT: true when lhs is strictly greater than rhs.
template <typename T>
struct CompareFactory<T, CompareOp::GT> {
  __host__ __device__ bool operator()(const T lhs, const T rhs) const {
    return lhs > rhs;
  }
};

// CompareOp::GE: true when lhs is greater than or equal to rhs.
template <typename T>
struct CompareFactory<T, CompareOp::GE> {
  __host__ __device__ bool operator()(const T lhs, const T rhs) const {
    return lhs >= rhs;
  }
};

// CompareOp::EQ: true when lhs equals rhs.
template <typename T>
struct CompareFactory<T, CompareOp::EQ> {
  __host__ __device__ bool operator()(const T lhs, const T rhs) const {
    return lhs == rhs;
  }
};

template <typename scalar_t, typename pack_t>
__global__ void gather_cuda_kernel(const int64_t* ids, const scalar_t* emb,
int64_t num_ids, int64_t embedding_dim,
Expand Down Expand Up @@ -276,11 +305,11 @@ torch::Tensor block_gather_cuda(const torch::Tensor ids,
return output;
}

template <typename scalar_t>
template <typename scalar_t, typename Factory>
__global__ void block_filter_cuda_kernel(
const int64_t* __restrict__ ids, const scalar_t** __restrict__ emb_blocks,
int64_t num_ids, int64_t block_size, scalar_t threshold,
bool* __restrict__ output) {
bool* __restrict__ output, Factory factory) {
int64_t did = blockIdx.x * blockDim.x + threadIdx.x;
if (did >= num_ids) {
return;
Expand All @@ -289,27 +318,28 @@ __global__ void block_filter_cuda_kernel(
int64_t block_index = id / block_size;
int64_t row_index = id % block_size;
scalar_t ft_val = emb_blocks[block_index][row_index];
bool expired = (ft_val < threshold);
bool expired = factory(ft_val, threshold);
output[did] = expired;
}

template <typename scalar_t>
template <typename scalar_t, typename Factory>
void block_filter_kernel_launcher(const int64_t* ids,
const scalar_t** emb_blocks, int64_t num_ids,
int64_t block_size, scalar_t threshold,
bool* output) {
bool* output, Factory factory) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
static const int filter_block_size = 128;
dim3 grids((num_ids + filter_block_size - 1) / filter_block_size);
dim3 blocks(filter_block_size);

block_filter_cuda_kernel<scalar_t><<<grids, blocks, 0, stream>>>(
ids, emb_blocks, num_ids, block_size, threshold, output);
block_filter_cuda_kernel<scalar_t, Factory><<<grids, blocks, 0, stream>>>(
ids, emb_blocks, num_ids, block_size, threshold, output, factory);
}

torch::Tensor block_filter_cuda(const torch::Tensor ids,
std::vector<torch::Tensor>& emb_blocks,
int64_t threshold, int64_t block_size) {
int64_t threshold, int64_t block_size,
int64_t compare_op) {
TORCH_CHECK(ids.is_cuda(), "Input must be on CUDA device");
TORCH_CHECK(emb_blocks.size() == 0 || emb_blocks[0].is_cuda(),
"Embedding must be on CUDA device");
Expand All @@ -324,15 +354,19 @@ torch::Tensor block_filter_cuda(const torch::Tensor ids,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_INDEX_TYPES(
emb_blocks[0].scalar_type(), "block_filter_cuda_impl", ([&] {
recis::cuda::CudaVecParam<index_t*> emb_blocks_ptrs(block_num, stream);
for (auto i = 0; i < block_num; ++i) {
emb_blocks_ptrs[i] = emb_blocks[i].data_ptr<index_t>();
}
block_filter_kernel_launcher<index_t>(
ids.data_ptr<int64_t>(),
const_cast<const index_t**>(emb_blocks_ptrs.data()), num_ids,
block_size, static_cast<index_t>(threshold),
output.data_ptr<bool>());
BF_DISPATCH_COMPARE_TYPES(
compare_op, ([&] {
using Factory = CompareFactory<index_t, compare_t>;
recis::cuda::CudaVecParam<index_t*> emb_blocks_ptrs(block_num, stream);
for (auto i = 0; i < block_num; ++i) {
emb_blocks_ptrs[i] = emb_blocks[i].data_ptr<index_t>();
}
block_filter_kernel_launcher<index_t, Factory>(
ids.data_ptr<int64_t>(),
const_cast<const index_t**>(emb_blocks_ptrs.data()), num_ids,
block_size, static_cast<index_t>(threshold),
output.data_ptr<bool>(), Factory());
}));
}));
cudaError_t err = cudaGetLastError();
TORCH_CHECK(cudaSuccess == err, cudaGetErrorString(err));
Expand Down
43 changes: 43 additions & 0 deletions csrc/ops/compare_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include <string>  // std::to_string used inside BF_DISPATCH_SWITCH

namespace recis {
namespace functional {

// Comparison operator selector shared between the CPU and CUDA
// block-filter kernels. Values are pinned explicitly because this enum
// is a cross-language ABI: the Python-side CompareOp
// (recis/utils/data_utils.py) must carry identical integer values.
enum CompareOp : int
{
  LT = 0,  // a <  b
  LE = 1,  // a <= b
  GT = 2,  // a >  b
  GE = 3,  // a >= b
  EQ = 4,  // a == b
};

// Expands one switch case: binds the enum value to the constexpr name
// `compare_t` (usable as a template argument) and invokes the body.
#define BF_DISPATCH_CASE(enum_type, ...)       \
  case enum_type: {                            \
    constexpr auto compare_t = enum_type;      \
    __VA_ARGS__();                             \
    break;                                     \
  }

// One case per supported CompareOp value.
#define BF_DISPATCH_CASE_COMPARE_TYPES(...)       \
  BF_DISPATCH_CASE(CompareOp::LT, __VA_ARGS__)    \
  BF_DISPATCH_CASE(CompareOp::LE, __VA_ARGS__)    \
  BF_DISPATCH_CASE(CompareOp::GT, __VA_ARGS__)    \
  BF_DISPATCH_CASE(CompareOp::GE, __VA_ARGS__)    \
  BF_DISPATCH_CASE(CompareOp::EQ, __VA_ARGS__)

// Generic runtime-to-compile-time switch. NOTE: the default branch
// expands TORCH_CHECK_NOT_IMPLEMENTED, so the expansion site must have
// the torch/c10 exception headers in scope.
#define BF_DISPATCH_SWITCH(TYPE, ...)       \
  switch (TYPE) {                           \
    __VA_ARGS__                             \
    default:                                \
      TORCH_CHECK_NOT_IMPLEMENTED(          \
          false,                            \
          "Invalid TYPE: ",                 \
          std::to_string(TYPE));            \
  }

// Dispatches a runtime compare_op integer to a constexpr CompareOp,
// making it available to the body as `compare_t`.
#define BF_DISPATCH_COMPARE_TYPES(TYPE, ...) \
  BF_DISPATCH_SWITCH(TYPE, BF_DISPATCH_CASE_COMPARE_TYPES(__VA_ARGS__))

}  // namespace functional
}  // namespace recis
2 changes: 2 additions & 0 deletions recis/nn/modules/hashtable_hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from recis.common.singleton import SingletonMeta
from recis.nn.hashtable_hook import AdmitHook, FilterHook
from recis.utils.logger import Logger
from recis.utils.data_utils import CompareOp


logger = Logger(__name__)
Expand Down Expand Up @@ -547,6 +548,7 @@ def filter(
filter_values,
filter_slot.block_size(),
self._global_step - self._filter_step,
CompareOp.LT,
)
return torch.masked_select(ids, mask), torch.masked_select(index, mask)

Expand Down
9 changes: 9 additions & 0 deletions recis/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import collections
from dataclasses import fields, is_dataclass
from enum import IntEnum, auto


class CompareOp(IntEnum):
    """Comparison operator selector passed to the C++ block_filter kernels.

    Values must match the C++ ``recis::functional::CompareOp`` enum in
    ``csrc/ops/compare_ops.h``, which is zero-based (``LT == 0``).
    ``auto()`` in an ``IntEnum`` starts counting at 1, which would shift
    every operator by one on the C++ side (``LT`` would dispatch as
    ``LE``, etc.), so ``LT`` is pinned to 0 explicitly and ``auto()``
    continues from there.
    """

    LT = 0  # a <  b
    LE = auto()  # a <= b
    GT = auto()  # a >  b
    GE = auto()  # a >= b
    EQ = auto()  # a == b


def copy_data_to_device(data, device, *args, **kwargs):
Expand Down