|
| 1 | +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#if GOOGLE_CUDA |
| 17 | +#define EIGEN_USE_GPU |
| 18 | +#endif // GOOGLE_CUDA |
| 19 | + |
| 20 | +#include <memory> |
| 21 | + |
| 22 | +#include "tensorflow/core/framework/register_types.h" |
| 23 | +#include "tensorflow/core/framework/tensor.h" |
| 24 | +#include "tensorflow/core/framework/tensor_shape.h" |
| 25 | +#include "tensorflow/core/lib/core/status.h" |
| 26 | +#include "tensorflow/core/platform/logging.h" |
| 27 | +#include "tensorflow/core/util/work_sharder.h" |
| 28 | +#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h" |
| 29 | + |
| 30 | +namespace tensorflow { |
| 31 | + |
| 32 | +typedef Eigen::ThreadPoolDevice CPUDevice; |
| 33 | +typedef Eigen::GpuDevice GPUDevice; |
| 34 | + |
| 35 | +class AdjustHsvInYiqOpBase : public OpKernel { |
| 36 | + protected: |
| 37 | + explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context) |
| 38 | + : OpKernel(context) {} |
| 39 | + |
| 40 | + struct ComputeOptions { |
| 41 | + const Tensor* input = nullptr; |
| 42 | + Tensor* output = nullptr; |
| 43 | + const Tensor* delta_h = nullptr; |
| 44 | + const Tensor* scale_s = nullptr; |
| 45 | + const Tensor* scale_v = nullptr; |
| 46 | + int64 channel_count = 0; |
| 47 | + }; |
| 48 | + |
| 49 | + virtual void DoCompute(OpKernelContext* context, |
| 50 | + const ComputeOptions& options) = 0; |
| 51 | + |
| 52 | + void Compute(OpKernelContext* context) override { |
| 53 | + const Tensor& input = context->input(0); |
| 54 | + const Tensor& delta_h = context->input(1); |
| 55 | + const Tensor& scale_s = context->input(2); |
| 56 | + const Tensor& scale_v = context->input(3); |
| 57 | + OP_REQUIRES(context, input.dims() >= 3, |
| 58 | + errors::InvalidArgument("input must be at least 3-D, got shape", |
| 59 | + input.shape().DebugString())); |
| 60 | + OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()), |
| 61 | + errors::InvalidArgument("delta_h must be scalar: ", |
| 62 | + delta_h.shape().DebugString())); |
| 63 | + OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()), |
| 64 | + errors::InvalidArgument("scale_s must be scalar: ", |
| 65 | + scale_s.shape().DebugString())); |
| 66 | + OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()), |
| 67 | + errors::InvalidArgument("scale_v must be scalar: ", |
| 68 | + scale_v.shape().DebugString())); |
| 69 | + auto channels = input.dim_size(input.dims() - 1); |
| 70 | + OP_REQUIRES( |
| 71 | + context, channels == kChannelSize, |
| 72 | + errors::InvalidArgument("input must have 3 channels but instead has ", |
| 73 | + channels, " channels.")); |
| 74 | + |
| 75 | + Tensor* output = nullptr; |
| 76 | + OP_REQUIRES_OK(context, |
| 77 | + context->allocate_output(0, input.shape(), &output)); |
| 78 | + |
| 79 | + if (input.NumElements() > 0) { |
| 80 | + const int64 channel_count = input.NumElements() / channels; |
| 81 | + ComputeOptions options; |
| 82 | + options.input = &input; |
| 83 | + options.delta_h = &delta_h; |
| 84 | + options.scale_s = &scale_s; |
| 85 | + options.scale_v = &scale_v; |
| 86 | + options.output = output; |
| 87 | + options.channel_count = channel_count; |
| 88 | + DoCompute(context, options); |
| 89 | + } |
| 90 | + } |
| 91 | +}; |
| 92 | + |
| 93 | +template <class Device> |
| 94 | +class AdjustHsvInYiqOp; |
| 95 | + |
| 96 | +template <> |
| 97 | +class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase { |
| 98 | + public: |
| 99 | + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) |
| 100 | + : AdjustHsvInYiqOpBase(context) {} |
| 101 | + |
| 102 | + void DoCompute(OpKernelContext* context, |
| 103 | + const ComputeOptions& options) override { |
| 104 | + const Tensor* input = options.input; |
| 105 | + Tensor* output = options.output; |
| 106 | + const int64 channel_count = options.channel_count; |
| 107 | + auto input_data = input->shaped<float, 2>({channel_count, kChannelSize}); |
| 108 | + const float delta_h = options.delta_h->scalar<float>()(); |
| 109 | + const float scale_s = options.scale_s->scalar<float>()(); |
| 110 | + const float scale_v = options.scale_v->scalar<float>()(); |
| 111 | + auto output_data = output->shaped<float, 2>({channel_count, kChannelSize}); |
| 112 | + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; |
| 113 | + internal::compute_tranformation_matrix<kChannelSize * kChannelSize>( |
| 114 | + delta_h, scale_s, scale_v, tranformation_matrix); |
| 115 | + const int kCostPerChannel = 10; |
| 116 | + const DeviceBase::CpuWorkerThreads& worker_threads = |
| 117 | + *context->device()->tensorflow_cpu_worker_threads(); |
| 118 | + Shard(worker_threads.num_threads, worker_threads.workers, channel_count, |
| 119 | + kCostPerChannel, [&input_data, &output_data, &tranformation_matrix]( |
| 120 | + int64 start_channel, int64 end_channel) { |
| 121 | + // Applying projection matrix to input RGB vectors. |
| 122 | + const float* p = input_data.data() + start_channel * kChannelSize; |
| 123 | + float* q = output_data.data() + start_channel * kChannelSize; |
| 124 | + for (int i = start_channel; i < end_channel; i++) { |
| 125 | + for (int q_index = 0; q_index < kChannelSize; q_index++) { |
| 126 | + q[q_index] = 0; |
| 127 | + for (int p_index = 0; p_index < kChannelSize; p_index++) { |
| 128 | + q[q_index] += |
| 129 | + p[p_index] * |
| 130 | + tranformation_matrix[q_index + kChannelSize * p_index]; |
| 131 | + } |
| 132 | + } |
| 133 | + p += kChannelSize; |
| 134 | + q += kChannelSize; |
| 135 | + } |
| 136 | + }); |
| 137 | + } |
| 138 | +}; |
| 139 | + |
| 140 | +REGISTER_KERNEL_BUILDER( |
| 141 | + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"), |
| 142 | + AdjustHsvInYiqOp<CPUDevice>); |
| 143 | + |
| 144 | +#if GOOGLE_CUDA |
| 145 | +template <> |
| 146 | +class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase { |
| 147 | + public: |
| 148 | + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) |
| 149 | + : AdjustHsvInYiqOpBase(context) {} |
| 150 | + |
| 151 | + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { |
| 152 | + const int64 number_of_elements = options.input->NumElements(); |
| 153 | + if (number_of_elements <= 0) { |
| 154 | + return; |
| 155 | + } |
| 156 | + const float* delta_h = options.delta_h->flat<float>().data(); |
| 157 | + const float* scale_s = options.scale_s->flat<float>().data(); |
| 158 | + const float* scale_v = options.scale_v->flat<float>().data(); |
| 159 | + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, |
| 160 | + delta_h, scale_s, scale_v, options.output); |
| 161 | + } |
| 162 | +}; |
| 163 | + |
| 164 | +REGISTER_KERNEL_BUILDER( |
| 165 | + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"), |
| 166 | + AdjustHsvInYiqOp<GPUDevice>); |
| 167 | +#endif |
| 168 | + |
| 169 | +} // namespace tensorflow |
0 commit comments