Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 0 additions & 84 deletions backends/cadence/reference/kernels/kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
#include <algorithm>
#include <cstring>
#include <limits>
#include <numeric>

namespace impl {
namespace reference {
namespace kernels {
Expand Down Expand Up @@ -58,36 +56,6 @@ void dequantize(
}
}

// Requantize the int8_t/uint8_t in value to a uint8_t/int8_t out value.
// The scale and zero_point for requantization are in the args.
template <typename IT, typename OT>
OT requantize(
const IT in,
float in_scale,
int32_t in_zero_point,
float inv_out_scale,
int32_t out_zero_point) {
float dequant = dequantize<IT>(in, in_scale, in_zero_point);
return quantize<OT>(dequant, inv_out_scale, out_zero_point);
}

// Requantize the int8_t/uint8_t in array to a uint8_t/int8_t out array.
// The scale and zero_point for requantization are in the args.
template <typename IT, typename OT>
void requantize(
OT* __restrict__ out,
const IT* __restrict__ in,
float in_scale,
int32_t in_zero_point,
float inv_out_scale,
int32_t out_zero_point,
size_t size) {
for (size_t i = 0; i < size; ++i) {
out[i] = requantize<IT, OT>(
in[i], in_scale, in_zero_point, inv_out_scale, out_zero_point);
}
}

// explicit template instantiation

#define typed_quantize_val(dtype) \
Expand Down Expand Up @@ -136,58 +104,6 @@ typed_dequantize_vec(uint16_t);
typed_dequantize_vec(int32_t);
#undef typed_dequantize_vec

#define typed_requantize_val(itype, otype) \
template otype requantize( \
const itype in, \
float in_scale, \
int32_t in_zero_point, \
float inv_out_scale, \
int32_t out_zero_point);
typed_requantize_val(int8_t, int8_t);
typed_requantize_val(int8_t, uint8_t);
typed_requantize_val(int8_t, int16_t);
typed_requantize_val(int8_t, uint16_t);
typed_requantize_val(uint8_t, int8_t);
typed_requantize_val(uint8_t, uint8_t);
typed_requantize_val(uint8_t, int16_t);
typed_requantize_val(uint8_t, uint16_t);
typed_requantize_val(int16_t, int8_t);
typed_requantize_val(int16_t, uint8_t);
typed_requantize_val(int16_t, int16_t);
typed_requantize_val(int16_t, uint16_t);
typed_requantize_val(uint16_t, int8_t);
typed_requantize_val(uint16_t, uint8_t);
typed_requantize_val(uint16_t, int16_t);
typed_requantize_val(uint16_t, uint16_t);
#undef typed_requantize_val

#define typed_requantize_vec(itype, otype) \
template void requantize( \
otype* __restrict__ out, \
const itype* __restrict__ in, \
float in_scale, \
int32_t in_zero_point, \
float inv_out_scale, \
int32_t out_zero_point, \
size_t size);
typed_requantize_vec(int8_t, int8_t);
typed_requantize_vec(int8_t, uint8_t);
typed_requantize_vec(int8_t, int16_t);
typed_requantize_vec(int8_t, uint16_t);
typed_requantize_vec(uint8_t, int8_t);
typed_requantize_vec(uint8_t, uint8_t);
typed_requantize_vec(uint8_t, int16_t);
typed_requantize_vec(uint8_t, uint16_t);
typed_requantize_vec(int16_t, int8_t);
typed_requantize_vec(int16_t, uint8_t);
typed_requantize_vec(int16_t, int16_t);
typed_requantize_vec(int16_t, uint16_t);
typed_requantize_vec(uint16_t, int8_t);
typed_requantize_vec(uint16_t, uint8_t);
typed_requantize_vec(uint16_t, int16_t);
typed_requantize_vec(uint16_t, uint16_t);
#undef typed_requantize_vec

}; // namespace kernels
}; // namespace reference
}; // namespace impl
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,15 @@ Tensor& requantize_out(
torch::executor::toString(out.scalar_type()),
torch::executor::toString(out_dtype));

#define typed_requantize(ctype, dtype) \
const ctype* input_data = input.const_data_ptr<ctype>(); \
dtype* out_data = out.mutable_data_ptr<dtype>(); \
kernels::requantize<ctype, dtype>( \
out_data, \
input_data, \
in_scale, \
in_zero_point, \
1.0 / out_scale, \
out_zero_point, \
numel);
#define typed_requantize(ctype, dtype) \
const ctype* input_data = input.const_data_ptr<ctype>(); \
dtype* out_data = out.mutable_data_ptr<dtype>(); \
for (size_t i = 0; i < numel; ++i) { \
float dequant = \
kernels::dequantize<ctype>(input_data[i], in_scale, in_zero_point); \
out_data[i] = \
kernels::quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \
};

#define typed_requantize_in(ctype) \
switch (out_dtype) { \
Expand Down Expand Up @@ -190,17 +188,15 @@ Tensor& requantize_per_tensor_out(
torch::executor::toString(out.scalar_type()),
torch::executor::toString(out_dtype));

#define typed_requantize(ctype, dtype) \
const ctype* input_data = input.const_data_ptr<ctype>(); \
dtype* out_data = out.mutable_data_ptr<dtype>(); \
kernels::requantize<ctype, dtype>( \
out_data, \
input_data, \
static_cast<float>(in_scale), \
static_cast<int32_t>(in_zero_point), \
1.0 / static_cast<float>(out_scale), \
static_cast<int32_t>(out_zero_point), \
numel);
#define typed_requantize(ctype, dtype) \
const ctype* input_data = input.const_data_ptr<ctype>(); \
dtype* out_data = out.mutable_data_ptr<dtype>(); \
for (size_t i = 0; i < numel; ++i) { \
float dequant = \
kernels::dequantize<ctype>(input_data[i], in_scale, in_zero_point); \
out_data[i] = \
kernels::quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \
};

#define typed_requantize_in(ctype) \
switch (out_dtype) { \
Expand Down
Loading