Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions paddle/fluid/operators/data/batch_decode_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class BatchDecodeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::UINT8,
ctx.GetPlace());
return framework::OpKernelType(
framework::proto::VarType::UINT8, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
Expand Down Expand Up @@ -67,7 +67,8 @@ or 1 dimensional Gray Tensor. Optionally converts the image to the
desired format. The values of the output tensor are uint8 between 0
and 255.
)DOC");
AddAttr<int>("num_threads", "Path of the file to be readed.").SetDefault(2);
AddAttr<int>("num_threads", "Path of the file to be readed.")
.SetDefault(2);
AddAttr<int>("local_rank",
"(int)"
"The index of the op to start execution");
Expand All @@ -77,13 +78,11 @@ and 255.
"decode thread pool");
AddAttr<int64_t>(
"host_memory_padding",
"(int64, default 0),"
"pinned memory allocation padding number for Nvjpeg decoding")
"(int64, default 0), pinned memory allocation padding number for Nvjpeg decoding")
.SetDefault(0);
AddAttr<int64_t>(
"device_memory_padding",
"(int64, default 0),"
"device memory allocation padding number for Nvjpeg decoding")
"(int64, default 0), device memory allocation padding number for Nvjpeg decoding")
.SetDefault(0);
}
};
Expand Down
28 changes: 14 additions & 14 deletions paddle/fluid/operators/data/batch_decode_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@

#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP)

#include "paddle/fluid/operators/data/batch_decode_op.h"
#include "paddle/fluid/operators/data/batch_decode_random_crop_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"

namespace paddle {
namespace operators {
namespace data {

using LoDTensorBlockingQueueHolder =
operators::reader::LoDTensorBlockingQueueHolder;
using LoDTensorBlockingQueueHolder = operators::reader::LoDTensorBlockingQueueHolder;

template <typename T>
class GPUBatchDecodeKernel : public framework::OpKernel<T> {
Expand All @@ -36,12 +34,12 @@ class GPUBatchDecodeKernel : public framework::OpKernel<T> {
auto device_memory_padding = ctx.Attr<int64_t>("device_memory_padding");

// multi-phrase decode thread pool
auto* decode_pool =
ImageDecoderThreadPoolManager::Instance()->GetDecoderThreadPool(
program_id, num_threads, local_rank,
static_cast<size_t>(host_memory_padding),
static_cast<size_t>(device_memory_padding));

auto* decode_pool =
ImageDecoderThreadPoolManager::Instance()->GetDecoderThreadPool(
program_id, num_threads, local_rank,
static_cast<size_t>(host_memory_padding),
static_cast<size_t>(device_memory_padding));
const framework::LoDTensorArray* inputs =
ctx.Input<framework::LoDTensorArray>("X");

Expand All @@ -54,11 +52,13 @@ class GPUBatchDecodeKernel : public framework::OpKernel<T> {
auto* x_data = x.data<T>();
size_t x_numel = static_cast<size_t>(x.numel());

ImageDecodeTask task = {.bit_stream = x_data,
.bit_len = x_numel,
.tensor = &out_array[i],
.roi_generator = nullptr,
.place = ctx.GetPlace()};
ImageDecodeTask task = {
.bit_stream = x_data,
.bit_len = x_numel,
.tensor = &out_array[i],
.roi_generator = nullptr,
.place = ctx.GetPlace()
};
decode_pool->AddTask(std::make_shared<ImageDecodeTask>(task));
}

Expand Down
95 changes: 41 additions & 54 deletions paddle/fluid/operators/data/batch_decode_random_crop_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,64 +23,53 @@ class BatchDecodeRandomCropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
platform::errors::InvalidArgument(
"Inputs(X) of DecodeJpeg should not be empty."));
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of DecodeJpeg should not be empty."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DecodeJpeg");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DecodeJpeg");

auto aspect_ratio_min = ctx->Attrs().Get<float>("aspect_ratio_min");
auto aspect_ratio_max = ctx->Attrs().Get<float>("aspect_ratio_max");
PADDLE_ENFORCE_GT(
aspect_ratio_min, 0.,
platform::errors::InvalidArgument(
PADDLE_ENFORCE_GT(aspect_ratio_min, 0.,
platform::errors::InvalidArgument(
"aspect_ratio_min should be greater than 0, but received "
"%f",
aspect_ratio_min));
PADDLE_ENFORCE_GT(
aspect_ratio_max, 0.,
platform::errors::InvalidArgument(
"%f", aspect_ratio_min));
PADDLE_ENFORCE_GT(aspect_ratio_max, 0.,
platform::errors::InvalidArgument(
"aspect_ratio_max should be greater than 0, but received "
"%f",
aspect_ratio_max));
PADDLE_ENFORCE_GE(
aspect_ratio_max, aspect_ratio_min,
platform::errors::InvalidArgument(
"%f", aspect_ratio_max));
PADDLE_ENFORCE_GE(aspect_ratio_max, aspect_ratio_min,
platform::errors::InvalidArgument(
"aspect_ratio_max should be greater than aspect_ratio_min, "
"but received aspect_ratio_max(%d) < aspect_ratio_min(%d)",
aspect_ratio_max, aspect_ratio_min));

auto area_min = ctx->Attrs().Get<float>("area_min");
auto area_max = ctx->Attrs().Get<float>("area_max");
PADDLE_ENFORCE_GT(area_min, 0.,
platform::errors::InvalidArgument(
"area_minshould be greater than 0, but received "
"%f",
area_min));
platform::errors::InvalidArgument(
"area_minshould be greater than 0, but received "
"%f", area_min));
PADDLE_ENFORCE_GT(area_max, 0.,
platform::errors::InvalidArgument(
"area_max should be greater than 0, but received "
"%f",
area_max));
platform::errors::InvalidArgument(
"area_max should be greater than 0, but received "
"%f", area_max));
PADDLE_ENFORCE_GE(area_max, area_min,
platform::errors::InvalidArgument(
"area_max should be greater than area_min, "
"but received area_max(%f) < area_min(%f)",
area_max, area_min));
platform::errors::InvalidArgument(
"area_max should be greater than area_min, "
"but received area_max(%f) < area_min(%f)",
area_max, area_min));

auto num_attempts = ctx->Attrs().Get<int64_t>("num_attempts");
auto num_attempts= ctx->Attrs().Get<int64_t>("num_attempts");
PADDLE_ENFORCE_GT(num_attempts, 0,
platform::errors::InvalidArgument(
"num_attempts should be a positive integerm, but "
"received %d",
num_attempts));
platform::errors::InvalidArgument(
"num_attempts should be a positive integerm, but "
"received %d", num_attempts));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::UINT8,
ctx.GetPlace());
return framework::OpKernelType(
framework::proto::VarType::UINT8, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
Expand Down Expand Up @@ -108,9 +97,8 @@ class BatchDecodeRandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"A one dimensional uint8 tensor containing the raw bytes "
"of the JPEG image. It is a tensor with rank 1.")
.AsDuplicable();
AddOutput("Out", "The output tensor of DecodeJpeg op").AsDuplicable();
"of the JPEG image. It is a tensor with rank 1.");
AddOutput("Out", "The output tensor of DecodeJpeg op");
AddComment(R"DOC(
This operator decodes a JPEG image into a 3 dimensional RGB Tensor
or 1 dimensional Gray Tensor. Optionally converts the image to the
Expand All @@ -120,14 +108,15 @@ and 255.
AddAttr<int>("local_rank",
"(int64_t)"
"The index of the op to start execution");
AddAttr<int>("num_threads", "Path of the file to be readed.").SetDefault(2);
AddAttr<int64_t>("host_memory_padding",
"(int64, default 0), pinned memory allocation padding "
"number for Nvjpeg decoding")
AddAttr<int>("num_threads", "Path of the file to be readed.")
.SetDefault(2);
AddAttr<int64_t>(
"host_memory_padding",
"(int64, default 0), pinned memory allocation padding number for Nvjpeg decoding")
.SetDefault(0);
AddAttr<int64_t>("device_memory_padding",
"(int64, default 0), device memory allocation padding "
"number for Nvjpeg decoding")
AddAttr<int64_t>(
"device_memory_padding",
"(int64, default 0), device memory allocation padding number for Nvjpeg decoding")
.SetDefault(0);
AddAttr<std::string>(
"data_format",
Expand All @@ -136,8 +125,8 @@ and 255.
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddAttr<float>("aspect_ratio_min", "").SetDefault(3. / 4.);
AddAttr<float>("aspect_ratio_max", "").SetDefault(4. / 3.);
AddAttr<float>("aspect_ratio_min", "").SetDefault(3./4.);
AddAttr<float>("aspect_ratio_max", "").SetDefault(4./3.);
AddAttr<float>("area_min", "").SetDefault(0.08);
AddAttr<float>("area_max", "").SetDefault(1.);
AddAttr<int64_t>("num_attempts", "").SetDefault(10);
Expand All @@ -155,10 +144,8 @@ and 255.
namespace ops = paddle::operators;

REGISTER_OPERATOR(
batch_decode_random_crop, ops::data::BatchDecodeRandomCropOp,
ops::data::BatchDecodeRandomCropOpMaker,
batch_decode_random_crop, ops::data::BatchDecodeRandomCropOp, ops::data::BatchDecodeRandomCropOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

REGISTER_OP_CPU_KERNEL(batch_decode_random_crop,
ops::data::CPUBatchDecodeRandomCropKernel<uint8_t>)
REGISTER_OP_CPU_KERNEL(batch_decode_random_crop, ops::data::CPUBatchDecodeRandomCropKernel<uint8_t>)
82 changes: 45 additions & 37 deletions paddle/fluid/operators/data/batch_decode_random_crop_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ namespace paddle {
namespace operators {
namespace data {

using LoDTensorBlockingQueueHolder =
operators::reader::LoDTensorBlockingQueueHolder;
using LoDTensorBlockingQueueHolder = operators::reader::LoDTensorBlockingQueueHolder;
using DataLayout = framework::DataLayout;

ImageDecoderThreadPool* decode_pool = nullptr;
Expand All @@ -42,17 +41,21 @@ class GPUBatchDecodeRandomCropKernel : public framework::OpKernel<T> {
auto device_memory_padding = ctx.Attr<int64_t>("device_memory_padding");

// multi-phrase decode thread pool
auto* decode_pool =
ImageDecoderThreadPoolManager::Instance()->GetDecoderThreadPool(
program_id, num_threads, local_rank,
static_cast<size_t>(host_memory_padding),
static_cast<size_t>(device_memory_padding));
auto* decode_pool =
ImageDecoderThreadPoolManager::Instance()->GetDecoderThreadPool(
program_id, num_threads, local_rank,
static_cast<size_t>(host_memory_padding),
static_cast<size_t>(device_memory_padding));

auto inputs = ctx.MultiInput<framework::LoDTensor>("X");
int batch_size = inputs.size();
const framework::LoDTensorArray* inputs =
ctx.Input<framework::LoDTensorArray>("X");
int batch_size = inputs->size();

auto out_array = ctx.MultiOutput<framework::LoDTensor>("Out");
auto* out = ctx.OutputVar("Out");
auto dev = platform::CUDAPlace(local_rank);

auto& out_array = *out->GetMutable<framework::LoDTensorArray>();
out_array.resize(batch_size);

const std::string data_format_str = ctx.Attr<std::string>("data_format");
const DataLayout data_format =
Expand All @@ -72,46 +75,52 @@ class GPUBatchDecodeRandomCropKernel : public framework::OpKernel<T> {
AreaRange area_range{area_min, area_max};

auto* generators = GeneratorManager::Instance()->GetGenerators(
program_id, batch_size, aspect_ratio_range, area_range);

for (size_t i = 0; i < inputs.size(); i++) {
const framework::LoDTensor* x = inputs.at(i);
auto* x_data = x->data<T>();
size_t x_numel = static_cast<size_t>(x->numel());

if (data_format == DataLayout::kNCHW) {
ImageDecodeTask task = {.bit_stream = x_data,
.bit_len = x_numel,
.tensor = &temp_array[i],
.roi_generator = generators->at(i).get(),
.place = dev};
program_id, batch_size, aspect_ratio_range,
area_range);

for (size_t i = 0; i < inputs->size(); i++) {
const framework::LoDTensor x = inputs->at(i);
auto* x_data = x.data<T>();
size_t x_numel = static_cast<size_t>(x.numel());

if (data_format == DataLayout::kNCHW){
ImageDecodeTask task = {
.bit_stream = x_data,
.bit_len = x_numel,
.tensor = &temp_array[i],
.roi_generator = generators->at(i).get(),
.place = dev
};
decode_pool->AddTask(std::make_shared<ImageDecodeTask>(task));
} else {
ImageDecodeTask task = {.bit_stream = x_data,
.bit_len = x_numel,
.tensor = out_array[i],
.roi_generator = generators->at(i).get(),
.place = dev};
}
else{
ImageDecodeTask task = {
.bit_stream = x_data,
.bit_len = x_numel,
.tensor = &out_array[i],
.roi_generator = generators->at(i).get(),
.place = dev
};
decode_pool->AddTask(std::make_shared<ImageDecodeTask>(task));
}

}

decode_pool->RunAll(true);

if (data_format == DataLayout::kNCHW) {
if (data_format == DataLayout::kNCHW){
const auto& dev_ctx = ctx.cuda_device_context();
phi::funcs::Transpose<paddle::platform::CUDADeviceContext, T, 3> trans;
std::vector<int> axis = {2, 0, 1};
for (size_t i = 0; i < inputs.size(); i++) {
for (size_t i = 0; i < inputs->size(); i++) {
// Do transpose
const framework::DDim& in_sizes = temp_array[i].dims();
framework::DDim transposed_input_shape = in_sizes.transpose(axis);
std::vector<int64_t> transposed_input_shape_ =
phi::vectorize(transposed_input_shape);

out_array[i]->Resize(transposed_input_shape);
out_array[i]->mutable_data<T>(dev_ctx.GetPlace());
trans(dev_ctx, temp_array[i], out_array[i], axis);
out_array[i].Resize(transposed_input_shape);
out_array[i].mutable_data<T>(dev_ctx.GetPlace());
trans(dev_ctx, temp_array[i], &out_array[i], axis);
}
}
}
Expand All @@ -122,7 +131,6 @@ class GPUBatchDecodeRandomCropKernel : public framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(batch_decode_random_crop,
ops::data::GPUBatchDecodeRandomCropKernel<uint8_t>)
REGISTER_OP_CUDA_KERNEL(batch_decode_random_crop, ops::data::GPUBatchDecodeRandomCropKernel<uint8_t>)

#endif
Loading