Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
e6c84bf
add pipeline. test=develop
heavengate Oct 19, 2021
8f19ea2
fixing compile error.
heavengate Oct 19, 2021
928e2ce
compile success
heavengate Oct 20, 2021
7cdecdd
run success
heavengate Oct 24, 2021
0de3d70
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
heavengate Nov 5, 2021
548bf84
add DataScope
heavengate Nov 9, 2021
9bd247b
refine pipeline manager
heavengate Nov 9, 2021
efec91b
add map_op and map_runner, compile success
heavengate Nov 9, 2021
798331b
add python API and VarType LOD_TENSOR_BLOCKING_QUEUE
heavengate Nov 10, 2021
d66ed31
add Shutdown for MapRunner
heavengate Nov 10, 2021
5ad0af5
add file reader
LielinJiang Oct 28, 2021
8ba9417
add decode
LielinJiang Nov 3, 2021
88b7809
[data op] add random_crop_and_resize_op
ghostxsl Nov 9, 2021
1080a73
fix compile error
heavengate Nov 12, 2021
7e1da1e
output tensor
LielinJiang Nov 12, 2021
ef9b44f
add debug log
heavengate Nov 15, 2021
1d94fb9
add threads pool
LielinJiang Nov 16, 2021
5c1316a
multi-phase nvjpeg decode single thread success
heavengate Nov 21, 2021
61c85ee
multi-phase decode with thread pool
heavengate Nov 22, 2021
c5043dc
polish code
heavengate Nov 23, 2021
7ab1889
why 2nd op not run
heavengate Nov 30, 2021
256d7b8
run success
heavengate Dec 1, 2021
41cebd6
map success
heavengate Dec 6, 2021
b14d92a
fix typo and clean log
heavengate Dec 6, 2021
acc731d
polish code
heavengate Dec 7, 2021
78824a8
polish code
heavengate Dec 7, 2021
e9dd9ed
queue + loop success
heavengate Dec 8, 2021
2e89ad3
fix map input type
heavengate Dec 9, 2021
ff3a7cd
use new CUDADeviceContext in map op
heavengate Dec 15, 2021
18cd907
simplify log
heavengate Dec 15, 2021
8c74403
random flip success
heavengate Dec 21, 2021
6e4b45f
polish log
heavengate Dec 21, 2021
a06c26d
add SetROI in nvjpeg decoder
heavengate Dec 30, 2021
cec2758
fix typo
heavengate Jan 3, 2022
3eddb61
fix exit segmentfault, need update
heavengate Jan 4, 2022
887e749
add label and multi-gpu
LielinJiang Jan 5, 2022
37a502a
Merge branch 'add_pipeline' of https://github.com/heavengate/Paddle i…
LielinJiang Jan 5, 2022
632e8b0
add batch_resize/batch_decode
heavengate Jan 5, 2022
d3e5f9e
Merge branch 'add_pipeline' of https://github.com/heavengate/Paddle i…
heavengate Jan 5, 2022
1fc2a58
add local_rank for batch_decode
heavengate Jan 5, 2022
6a74949
refine shutdown
heavengate Jan 9, 2022
ec2ff93
remove prefetch thread/queue in pipeline
heavengate Jan 9, 2022
b5fef20
map support multi inputs/outputs
heavengate Jan 9, 2022
c5860c9
add SIGSEGV handler for map_runner
heavengate Jan 10, 2022
9e03339
add end of epoch process
heavengate Jan 11, 2022
bd93dab
fix nvjpeg hw bug
heavengate Jan 12, 2022
29f670b
use NVJPEG_OUTPUT_RGBI
heavengate Jan 12, 2022
84684b5
add reader manager
LielinJiang Jan 13, 2022
582e851
pull upstream
LielinJiang Jan 13, 2022
47aa4fc
mv file label reader to data/ , add reader manager
LielinJiang Jan 13, 2022
bf7cb1c
Merge pull request #12 from LielinJiang/add_pipeline
heavengate Jan 13, 2022
2e1a4db
add loader & data_reader, compile success
heavengate Jan 19, 2022
4dd29d5
run success, hang to fix
heavengate Jan 19, 2022
b6c2e1f
fix speed
heavengate Jan 20, 2022
fd14988
support dygraph running
heavengate Jan 24, 2022
f0d9d95
fix import _C_ops
heavengate Jan 24, 2022
9adaeab
refine api
heavengate Jan 25, 2022
6183815
fix drop_last=False hang
heavengate Jan 25, 2022
8668ac8
support program shutdown
heavengate Jan 26, 2022
71ee11c
fix dygraph error
heavengate Jan 26, 2022
72922e6
refine shutdown
heavengate Jan 26, 2022
6ed135c
opencv
LielinJiang Jan 27, 2022
6527936
add opencv decode
LielinJiang Jan 28, 2022
002915e
Merge pull request #13 from LielinJiang/io_add_opencv
LielinJiang Jan 28, 2022
b459b82
fix train speed
heavengate Jan 28, 2022
a30b9fb
add random_flip op
heavengate Feb 13, 2022
f482564
fix decode error and add layout for decode op
LielinJiang Feb 14, 2022
a30d38f
clean code
LielinJiang Feb 14, 2022
19b3a08
Merge pull request #14 from LielinJiang/fix-decode-channel-error
heavengate Feb 14, 2022
baf2f55
add mirror_normalize
heavengate Feb 14, 2022
11592c7
Merge branch 'add_pipeline' of https://github.com/heavengate/Paddle i…
heavengate Feb 14, 2022
da718cd
revert pipeline debug
heavengate Feb 14, 2022
a764fca
fix flip_normalize output error
heavengate Feb 14, 2022
a130086
fix index
LielinJiang Feb 17, 2022
7fcd89c
Merge pull request #15 from LielinJiang/fix_index
heavengate Feb 17, 2022
f867fca
fix memory leak
heavengate Feb 28, 2022
141af90
Merge branch 'add_pipeline' of https://github.com/heavengate/Paddle i…
heavengate Feb 28, 2022
8c59fa8
fix memory leak
heavengate Mar 1, 2022
b5213fe
add barrier
heavengate Mar 3, 2022
b0d88ee
fix training hang
heavengate Mar 23, 2022
018451a
change label data type to int64
heavengate Mar 23, 2022
e396470
add opencv cvtColor
heavengate Mar 25, 2022
7cecb18
merge develop
heavengate Mar 25, 2022
e5bed90
add Pipeline.reset()
heavengate Mar 27, 2022
0065a28
unique shuffle seed in data_reader on multi-device
heavengate Mar 27, 2022
c91e76e
merge develop
heavengate Mar 28, 2022
be35dbb
clean code. test=develop
heavengate Mar 28, 2022
2a2eb8e
merge develop
heavengate Mar 28, 2022
a27a9ce
rename data_io_queue -> dataloader_pass. test=develop
heavengate Mar 28, 2022
71900f4
refine API and add file_label_loader unittest. test=develop
heavengate Mar 29, 2022
d608b84
add mirror normalize unittests. test=develop
heavengate Mar 29, 2022
6281d28
add unittest for random_flip. test=develop
heavengate Mar 29, 2022
cc51bbf
add test_ops_crop_resize. test=develop
heavengate Mar 30, 2022
948e035
add unittest for random_crop_and_resize. test=develop
heavengate Mar 30, 2022
a738cfb
add image decode unittest. test=develop
heavengate Apr 1, 2022
e1bf5f1
lod_tensor_array to list[lod_tensor]
LielinJiang Apr 1, 2022
1eea1d2
pull upstream
LielinJiang Apr 1, 2022
7806cf6
Merge pull request #16 from LielinJiang/tensor_array2list_tensor
heavengate Apr 1, 2022
59d0242
refine map API. test=develop
heavengate Apr 1, 2022
8680a98
merge develop. test=develop
heavengate Apr 1, 2022
52771e6
fix ci compile. test=develop
heavengate Apr 2, 2022
353f759
fix ci compile. test=develop
heavengate Apr 2, 2022
ce4610b
fix ci compile. test=develop
heavengate Apr 2, 2022
a08e487
add test_data_pipeline. test=develop
heavengate Apr 2, 2022
0206f3f
add dynamic unittest for all data pipeline ops. test=develop
heavengate Apr 2, 2022
4c91cd9
fix ci compile. test=develop
heavengate Apr 3, 2022
d4bd597
merge develop
heavengate Apr 5, 2022
676335c
merge develop
heavengate Apr 5, 2022
210f5b5
fix unittest. test=develop
heavengate Apr 5, 2022
a559fb2
complete docs. test=develop
heavengate Apr 5, 2022
a01f870
add C++ docs. test=develop
heavengate Apr 5, 2022
03783f2
add test_data_pipeline dynamic test. test=develop
heavengate Apr 5, 2022
6d20a9f
add NVJPEG error message spider to fix ci build. test=develop
heavengate Apr 6, 2022
b4571db
fix ci. test=develop
heavengate Apr 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ option(NEW_RELEASE_ALL "PaddlePaddle next-level release strategy for all arche
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)
option(WITH_OPENCV "Compile with opencv" OFF)
option(WITH_RECORD_BUILDTIME "Compile PaddlePaddle with record all targets build time" OFF)
option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF)

Expand Down Expand Up @@ -393,6 +394,18 @@ include(third_party) # download, build, install third_party, Contains about 20+

include(flags) # set paddle compile flags

if(WITH_OPENCV)
  # Try OpenCV 4.x first; quietly fall through to a REQUIRED 3.x lookup,
  # so configuration fails only if neither major version is available.
  find_package(OpenCV 4.0 QUIET COMPONENTS core imgproc imgcodecs)
  if(NOT OpenCV_FOUND)
    find_package(OpenCV 3.0 REQUIRED COMPONENTS core imgproc imgcodecs)
  endif()
  message(STATUS "Found OpenCV: ${OpenCV_INCLUDE_DIRS} (found suitable version \"${OpenCV_VERSION}\", minimum required is \"3.0\")")
  # NOTE(review): the SYSTEM and non-SYSTEM include_directories calls below
  # add the same directories twice — the second call looks redundant; confirm
  # and drop one.
  include_directories(SYSTEM ${OpenCV_INCLUDE_DIRS})
  include_directories(${OpenCV_INCLUDE_DIRS})
  # NOTE(review): OpenCV_LIBS holds library names/targets, not directories —
  # link_directories() is probably not the intended command here; verify that
  # consuming targets actually link against ${OpenCV_LIBS}.
  link_directories(${OpenCV_LIBS})
  # Enables the PADDLE_WITH_OPENCV code paths (e.g. the OpenCV decoder).
  add_definitions(-DPADDLE_WITH_OPENCV)
endif()

if(WITH_PROFILER)
find_package(Gperftools REQUIRED)
include_directories(${GPERFTOOLS_INCLUDE_DIR})
Expand Down
2 changes: 2 additions & 0 deletions cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ function(common_link TARGET_NAME)
if (WITH_PROFILER)
target_link_libraries(${TARGET_NAME} gperftools::profiler)
endif()


endfunction()

# find all third_party modules is used for paddle static library
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
fix_op_run_order_pass fuse_gemm_epilogue_pass)
fix_op_run_order_pass fuse_gemm_epilogue_pass dataloader_queue_pass)

if (WITH_CINN)
set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass)
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");

AppendPass("dataloader_queue_pass");

SetCollectiveContext();
}

Expand Down Expand Up @@ -503,6 +505,7 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass);
USE_PASS(dataloader_queue_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/executor_gc_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block,

return type == proto::VarType::LOD_TENSOR ||
type == proto::VarType::SELECTED_ROWS ||
type == proto::VarType::LOD_TENSOR_ARRAY;
type == proto::VarType::LOD_TENSOR_ARRAY ||
type == proto::VarType::LOD_TENSOR_BLOCKING_QUEUE;
}

std::unordered_map<const OperatorBase *, std::vector<std::string>>
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,11 @@ message VarType {
STRINGS = 26;
VOCAB = 27;
FEED_LIST = 28;

// The data type of phi::StringTensor
PSTRING = 29;

LOD_TENSOR_BLOCKING_QUEUE = 31;
}

required Type type = 1;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(mixed_precision_configure_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
pass_library(dataloader_queue_pass base)
target_link_libraries(generate_pass pass_desc_proto)

if(WITH_TENSORRT)
Expand Down
109 changes: 109 additions & 0 deletions paddle/fluid/framework/ir/dataloader_queue_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <map>
#include <set>

#include "glog/logging.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

// Op types whose output variables hold a LoDTensorBlockingQueue rather than
// a plain LoDTensor; the pass retypes their outputs accordingly.
// (const + static: these tables are file-local and never mutated.)
static const std::set<std::string> output_queue_holder_ops = {
    "file_label_reader", "map", "data_reader",
};

// Op types whose input variables are consumed as LoDTensorArray.
static const std::set<std::string> input_array_ops = {
    "random_crop_and_resize", "batch_decode",
};

// Returns true if `op_type` produces queue-holder outputs.
// Takes the type by const reference to avoid copying the string per call.
static bool IsOutputQueueHolderOp(const std::string& op_type) {
  return output_queue_holder_ops.count(op_type) > 0;
}

// Returns true if `op_type` consumes LoDTensorArray inputs.
static bool IsInputArrayOp(const std::string& op_type) {
  return input_array_ops.count(op_type) > 0;
}

// Retypes every output variable of a queue-holder op (file_label_reader /
// map / data_reader) to LOD_TENSOR_BLOCKING_QUEUE and marks it persistable
// so the queue is not garbage-collected between executor steps.
static void ProcessOutputQueueHolderOp(ir::Graph *graph) {
  // Pass 1: collect the names of all output variables of queue-holder ops.
  std::set<std::string> var_names;
  for (const Node *n : graph->Nodes()) {
    if (n->IsOp() && n->Op()) {
      auto *op = n->Op();
      if (IsOutputQueueHolderOp(op->Type())) {
        // Iterate by const reference: the original copied each variable
        // name string on every iteration.
        for (const auto &output : op->Outputs()) {
          for (const auto &var : output.second) var_names.insert(var);
        }
      }
    }
  }

  // Pass 2: rewrite the type of every variable node collected above.
  for (const Node *n : graph->Nodes()) {
    if (n->IsVar() && n->Var()) {
      auto *var = n->Var();
      if (var_names.count(var->Name()) > 0) {
        VLOG(3) << "Change output variable type of " << var->Name()
                << " to queue holder";
        var->SetType(framework::proto::VarType::LOD_TENSOR_BLOCKING_QUEUE);
        var->SetPersistable(true);
      }
    }
  }
}

// Retypes every input variable of an array-consuming op
// (random_crop_and_resize / batch_decode) to LOD_TENSOR_ARRAY.
static void ProcessInputArrayOp(ir::Graph *graph) {
  // Pass 1: collect the names of all input variables of array ops.
  std::set<std::string> var_names;
  for (const Node *n : graph->Nodes()) {
    if (n->IsOp() && n->Op()) {
      auto *op = n->Op();
      if (IsInputArrayOp(op->Type())) {
        // Iterate by const reference to avoid copying each name string.
        for (const auto &input : op->Inputs()) {
          for (const auto &var : input.second) var_names.insert(var);
        }
      }
    }
  }

  // Pass 2: rewrite the type of every variable node collected above.
  for (const Node *n : graph->Nodes()) {
    if (n->IsVar() && n->Var()) {
      auto *var = n->Var();
      if (var_names.count(var->Name()) > 0) {
        // Fixed log text: it previously said "to queue holder" (copy-pasted
        // from ProcessOutputQueueHolderOp) but this path sets
        // LOD_TENSOR_ARRAY.
        VLOG(3) << "Change input variable type of " << var->Name()
                << " to LoDTensorArray";
        var->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
      }
    }
  }
}

// Graph pass for the data pipeline: rewrites variable types so that
// outputs of reader/map ops become LOD_TENSOR_BLOCKING_QUEUE holders and
// inputs of array-consuming ops become LOD_TENSOR_ARRAY.
class DataLoaderQueuePass : public Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override {
    ProcessOutputQueueHolderOp(graph);
    ProcessInputArrayOp(graph);
  }
};

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(dataloader_queue_pass,
paddle::framework::ir::DataLoaderQueuePass);
8 changes: 7 additions & 1 deletion paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place,
RuntimeContext* runtime_ctx) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
auto* dev_ctx = HasAttr("_stream_id")
? platform::AsyncDeviceContextPool::Instance().Get(
place, Attr<int>("_stream_id"))
: nullptr;
if (dev_ctx == nullptr) {
dev_ctx = pool.Get(place);
}

#ifdef PADDLE_WITH_ASCEND_CL
// NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/var_type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
REG_PROTO_VAR_TYPE_TRAIT(operators::reader::LoDTensorBlockingQueueHolder,
proto::VarType::LOD_TENSOR_BLOCKING_QUEUE);
REG_PROTO_VAR_TYPE_TRAIT(FeedList, proto::VarType::FEED_LIST);
REG_PROTO_VAR_TYPE_TRAIT(FetchList, proto::VarType::FETCH_LIST);
REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/variable_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
Expand All @@ -42,6 +43,8 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::LOD_TENSOR_BLOCKING_QUEUE) {
var->GetMutable<operators::reader::LoDTensorBlockingQueueHolder>();
} else if (var_type == proto::VarType::STRINGS) {
var->GetMutable<Strings>();
} else if (var_type == proto::VarType::VOCAB) {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ add_subdirectory(reader)

if (NOT WIN32)
add_subdirectory(nccl)
add_subdirectory(data)
endif()

if (WITH_GPU AND TENSORRT_FOUND)
Expand Down
29 changes: 29 additions & 0 deletions paddle/fluid/operators/data/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
include(operators)

if(WITH_UNITY_BUILD)
  # Load Unity Build rules for operators in paddle/fluid/operators/data/
  include(unity_build_rule.cmake)
endif()

# Pipeline executor and its dataloader entry op.
cc_library(pipeline SRCS pipeline.cc DEPS parallel_executor simple_threadpool scope)
op_library(dataloader_op SRCS dataloader_op.cc dataloader_op.cu.cc DEPS pipeline ${OP_HEADER_DEPS})

op_library(data_reader_op SRCS data_reader_op.cc DEPS ${OP_HEADER_DEPS})

# map op runs user programs on worker threads via map_runner.
cc_library(map_runner SRCS map_runner.cc DEPS parallel_executor simple_threadpool scope)
op_library(map_op SRCS map_op.cc map_op.cu.cc DEPS map_runner ${OP_HEADER_DEPS})

# The ops below have CUDA kernels (nvJPEG decode etc.), so they are only
# built for GPU builds on non-Windows platforms.
if (WITH_GPU AND NOT WIN32)
  op_library(file_label_loader_op SRCS file_label_loader_op.cc DEPS ${OP_HEADER_DEPS})

  cc_library(random_roi_generator SRCS random_roi_generator.cc DEPS ${OP_HEADER_DEPS})
  cc_library(image_decoder SRCS image_decoder.cc DEPS random_roi_generator ${OP_HEADER_DEPS})

  op_library(batch_decode_random_crop_op SRCS batch_decode_random_crop_op.cc batch_decode_random_crop_op.cu DEPS image_decoder ${OP_HEADER_DEPS})
  op_library(batch_decode_op SRCS batch_decode_op.cc batch_decode_op.cu DEPS image_decoder ${OP_HEADER_DEPS})

  op_library(batch_random_crop_and_resize_op SRCS batch_random_crop_and_resize_op.cc batch_random_crop_and_resize_op.cu DEPS ${OP_HEADER_DEPS})
  op_library(batch_resize_op SRCS batch_resize_op.cc batch_resize_op.cu DEPS ${OP_HEADER_DEPS})

  op_library(mirror_normalize_op SRCS mirror_normalize_op.cc mirror_normalize_op.cu DEPS ${OP_HEADER_DEPS})
endif()
98 changes: 98 additions & 0 deletions paddle/fluid/operators/data/batch_decode_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/data/batch_decode_op.h"

namespace paddle {
namespace operators {
namespace data {

// Static-graph operator that decodes a batch of raw JPEG byte tensors
// into uint8 image tensors.
class BatchDecodeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Output shapes are data-dependent (each JPEG carries its own
    // dimensions), so only validate that the duplicable slots are non-empty.
    // Fixed error text: it previously referred to "DecodeJpeg", which is a
    // different op.
    PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
                      platform::errors::InvalidArgument(
                          "Inputs(X) of BatchDecodeOp should not be empty."));
    PADDLE_ENFORCE_GE(
        ctx->Outputs("Out").size(), 1UL,
        platform::errors::InvalidArgument(
            "Outputs(Out) of BatchDecodeOp should not be empty."));
  }

 protected:
  // Decoded images are always uint8; the kernel runs on the op's place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(framework::proto::VarType::UINT8,
                                   ctx.GetPlace());
  }

  // For "X" (the raw JPEG bytes) keep the expected kernel type so the
  // framework does not insert a device transform; other variables keep
  // their current place. Marked `override` (missing in the original).
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "X") {
      return expected_kernel_type;
    }

    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place());
  }
};

// Declares the inputs, outputs and attributes of batch_decode for the
// op registry.
class BatchDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(List[Tensor]) A one dimensional uint8 tensor containing "
             "the raw bytes of the JPEG image. It is a tensor with rank "
             "1.")
        .AsDuplicable();
    AddOutput("Out", "The output tensor of BatchDecodeOp").AsDuplicable();
    AddComment(R"DOC(
This operator decodes a JPEG image into a 3 dimensional RGB Tensor.
The values of the output tensor are uint8 between 0 and 255.
)DOC");
    // Fixed description: it previously read "Path of the file to be
    // readed.", copy-pasted from a file-reader op.
    AddAttr<int>("num_threads",
                 "(int, default 2) "
                 "The number of worker threads used for decoding.")
        .SetDefault(2);
    AddAttr<int>("local_rank",
                 "(int)"
                 "The index of the op to start execution");
    AddAttr<int64_t>("program_id",
                     "(int64_t)"
                     "The unique hash id used as cache key for "
                     "decode thread pool");
    AddAttr<int64_t>(
        "host_memory_padding",
        "(int64, default 0),"
        "pinned memory allocation padding number for Nvjpeg decoding")
        .SetDefault(0);
    AddAttr<int64_t>(
        "device_memory_padding",
        "(int64, default 0),"
        "device memory allocation padding number for Nvjpeg decoding")
        .SetDefault(0);
  }
};

} // namespace data
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
batch_decode, ops::data::BatchDecodeOp, ops::data::BatchDecodeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

REGISTER_OP_CPU_KERNEL(batch_decode, ops::data::CPUBatchDecodeKernel<uint8_t>)
Loading