diff --git a/CMakeLists.txt b/CMakeLists.txt index 6988434996bcc4..9f986dcb59f3fa 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,6 +244,7 @@ option(NEW_RELEASE_ALL "PaddlePaddle next-level release strategy for all arche option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF) option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF) option(WITH_POCKETFFT "Compile with pocketfft support" ON) +option(WITH_OPENCV "Compile with opencv" OFF) option(WITH_RECORD_BUILDTIME "Compile PaddlePaddle with record all targets build time" OFF) option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF) @@ -393,6 +394,18 @@ include(third_party) # download, build, install third_party, Contains about 20+ include(flags) # set paddle compile flags +if(WITH_OPENCV) + find_package(OpenCV 4.0 QUIET COMPONENTS core imgproc imgcodecs) + if(NOT OpenCV_FOUND) + find_package(OpenCV 3.0 REQUIRED COMPONENTS core imgproc imgcodecs) + endif() + message(STATUS "Found OpenCV: ${OpenCV_INCLUDE_DIRS} (found suitable version \"${OpenCV_VERSION}\", minimum required is \"3.0\")") + include_directories(SYSTEM ${OpenCV_INCLUDE_DIRS}) + include_directories(${OpenCV_INCLUDE_DIRS}) + link_directories(${OpenCV_LIBS}) + add_definitions(-DPADDLE_WITH_OPENCV) +endif() + if(WITH_PROFILER) find_package(Gperftools REQUIRED) include_directories(${GPERFTOOLS_INCLUDE_DIR}) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index ba59eae392c663..5a2d5e96ed654d 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -134,6 +134,8 @@ function(common_link TARGET_NAME) if (WITH_PROFILER) target_link_libraries(${TARGET_NAME} gperftools::profiler) endif() + + endfunction() # find all third_party modules is used for paddle static library diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 948eaab40b4f64..f4a49a0c1f7cde 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass - fix_op_run_order_pass fuse_gemm_epilogue_pass) + fix_op_run_order_pass fuse_gemm_epilogue_pass dataloader_queue_pass) if (WITH_CINN) set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index fdf74d2f769fcd..bcc4ed9f7b9272 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -84,6 +84,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Note: This pass is used to check whether the multi_device_graph is right. 
AppendPass("multi_devices_check_pass"); + AppendPass("dataloader_queue_pass"); + SetCollectiveContext(); } @@ -503,6 +505,7 @@ USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); USE_PASS(add_reader_dependency_pass); +USE_PASS(dataloader_queue_pass); #ifdef PADDLE_WITH_CINN USE_PASS(build_cinn_pass); #endif diff --git a/paddle/fluid/framework/executor_gc_helper.cc b/paddle/fluid/framework/executor_gc_helper.cc index 6dc53c9649e9d5..58bb317d16c6c6 100644 --- a/paddle/fluid/framework/executor_gc_helper.cc +++ b/paddle/fluid/framework/executor_gc_helper.cc @@ -76,7 +76,8 @@ static bool VarCanBeDeleted(const std::string &name, const BlockDesc &block, return type == proto::VarType::LOD_TENSOR || type == proto::VarType::SELECTED_ROWS || - type == proto::VarType::LOD_TENSOR_ARRAY; + type == proto::VarType::LOD_TENSOR_ARRAY || + type == proto::VarType::LOD_TENSOR_BLOCKING_QUEUE; } std::unordered_map> diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 0d3e7c2741c17b..2e25b57c88cf78 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -152,8 +152,11 @@ message VarType { STRINGS = 26; VOCAB = 27; FEED_LIST = 28; + // The data type of phi::StringTensor PSTRING = 29; + + LOD_TENSOR_BLOCKING_QUEUE = 31; } required Type type = 1; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8cacf34834a16d..c1a133dfeca481 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -101,6 +101,7 @@ pass_library(matmul_scale_fuse_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(mixed_precision_configure_pass inference) pass_library(generate_pass DEPS pass_desc_proto) +pass_library(dataloader_queue_pass base) target_link_libraries(generate_pass pass_desc_proto) if(WITH_TENSORRT) diff --git a/paddle/fluid/framework/ir/dataloader_queue_pass.cc b/paddle/fluid/framework/ir/dataloader_queue_pass.cc new file mode 100644 index 00000000000000..8f3a902815da7a --- /dev/null +++ b/paddle/fluid/framework/ir/dataloader_queue_pass.cc @@ -0,0 +1,109 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <set>
+#include <string>
+
+#include "glog/logging.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+std::set<std::string> output_queue_holder_ops = {
+    "file_label_reader", "map", "data_reader",
+};
+
+std::set<std::string> input_array_ops = {
+    "random_crop_and_resize", "batch_decode",
+};
+
+static bool IsOutputQueueHolderOp(std::string op_type) {
+  return output_queue_holder_ops.find(op_type) != output_queue_holder_ops.end();
+}
+
+static bool IsInputArrayOp(std::string op_type) {
+  return input_array_ops.find(op_type) != input_array_ops.end();
+}
+
+static void ProcessOutputQueueHolderOp(ir::Graph *graph) {
+  std::set<std::string> var_names;
+  for (const Node *n : graph->Nodes()) {
+    if (n->IsOp() && n->Op()) {
+      auto *op = n->Op();
+      if (IsOutputQueueHolderOp(op->Type())) {
+        auto &outputs = op->Outputs();
+        for (auto iter = outputs.begin(); iter != outputs.end(); iter++) {
+          for (auto var : iter->second) var_names.insert(var);
+        }
+      }
+    }
+  }
+
+  for (const Node *n : graph->Nodes()) {
+    if (n->IsVar() && n->Var()) {
+      auto *var = n->Var();
+      if (var_names.find(var->Name()) != var_names.end()) {
+        VLOG(3) << "Change output variable type of " << var->Name()
+                << " to queue holder";
+        var->SetType(framework::proto::VarType::LOD_TENSOR_BLOCKING_QUEUE);
+        var->SetPersistable(true);
+      }
+    }
+  }
+}
+
+static void ProcessInputArrayOp(ir::Graph *graph) {
+  std::set<std::string> var_names;
+  for (const Node *n : graph->Nodes()) {
+    if (n->IsOp() && n->Op()) {
+      auto *op = n->Op();
+      if (IsInputArrayOp(op->Type())) {
+        auto &inputs = op->Inputs();
+        for (auto iter = inputs.begin(); iter != inputs.end(); iter++) {
+          for (auto var : iter->second) var_names.insert(var);
+        }
+      }
+    }
+  }
+
+  for (const Node *n : graph->Nodes()) {
+    if (n->IsVar() && n->Var()) {
+      auto *var = n->Var();
+      if (var_names.find(var->Name()) != var_names.end()) {
+        VLOG(3) << "Change input variable type of " << var->Name()
+                << " to LoDTensorArray";
+        var->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
+      }
+    }
+  }
+}
+
+class DataLoaderQueuePass : public Pass {
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override {
+    ProcessOutputQueueHolderOp(graph);
+    ProcessInputArrayOp(graph);
+  }
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(dataloader_queue_pass,
+              paddle::framework::ir::DataLoaderQueuePass);
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 49248edd322d29..1439a1a01ceb0f 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1242,7 +1242,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place,
                                  RuntimeContext* runtime_ctx) const {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  auto* dev_ctx = pool.Get(place);
+  auto* dev_ctx = HasAttr("_stream_id")
+                      ? 
platform::AsyncDeviceContextPool::Instance().Get( + place, Attr("_stream_id")) + : nullptr; + if (dev_ctx == nullptr) { + dev_ctx = pool.Get(place); + } #ifdef PADDLE_WITH_ASCEND_CL // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index 9fe67e1dcdff31..ef825b960df6be 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -219,6 +219,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE); REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY); REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST); REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER); +REG_PROTO_VAR_TYPE_TRAIT(operators::reader::LoDTensorBlockingQueueHolder, + proto::VarType::LOD_TENSOR_BLOCKING_QUEUE); REG_PROTO_VAR_TYPE_TRAIT(FeedList, proto::VarType::FEED_LIST); REG_PROTO_VAR_TYPE_TRAIT(FetchList, proto::VarType::FETCH_LIST); REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32); diff --git a/paddle/fluid/framework/variable_helper.cc b/paddle/fluid/framework/variable_helper.cc index 471efc02078357..1613f4859a07c8 100644 --- a/paddle/fluid/framework/variable_helper.cc +++ b/paddle/fluid/framework/variable_helper.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/string_array.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -42,6 +43,8 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) { var->GetMutable(); } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { var->GetMutable(); + } else if (var_type == proto::VarType::LOD_TENSOR_BLOCKING_QUEUE) { + var->GetMutable(); } else if (var_type == proto::VarType::STRINGS) { var->GetMutable(); } else if (var_type == proto::VarType::VOCAB) { diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 3901226216f4d2..aee9a794c15ccb 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -41,6 +41,7 @@ add_subdirectory(reader) if (NOT WIN32) add_subdirectory(nccl) + add_subdirectory(data) endif() if (WITH_GPU AND TENSORRT_FOUND) diff --git a/paddle/fluid/operators/data/CMakeLists.txt b/paddle/fluid/operators/data/CMakeLists.txt new file mode 100644 index 00000000000000..d0ee1bff9e3f2a --- /dev/null +++ b/paddle/fluid/operators/data/CMakeLists.txt @@ -0,0 +1,29 @@ +include(operators) + +if(WITH_UNITY_BUILD) + # Load Unity Build rules for operators in paddle/fluid/operators/data/ + include(unity_build_rule.cmake) +endif() + +cc_library(pipeline SRCS pipeline.cc DEPS parallel_executor simple_threadpool scope) +op_library(dataloader_op SRCS dataloader_op.cc dataloader_op.cu.cc DEPS pipeline ${OP_HEADER_DEPS}) + +op_library(data_reader_op SRCS data_reader_op.cc DEPS ${OP_HEADER_DEPS}) + +cc_library(map_runner SRCS map_runner.cc DEPS parallel_executor simple_threadpool scope) +op_library(map_op SRCS map_op.cc map_op.cu.cc DEPS map_runner ${OP_HEADER_DEPS}) + +if (WITH_GPU AND NOT WIN32) + op_library(file_label_loader_op SRCS file_label_loader_op.cc DEPS ${OP_HEADER_DEPS}) + + cc_library(random_roi_generator SRCS random_roi_generator.cc DEPS ${OP_HEADER_DEPS}) + cc_library(image_decoder SRCS image_decoder.cc DEPS 
random_roi_generator ${OP_HEADER_DEPS}) + + op_library(batch_decode_random_crop_op SRCS batch_decode_random_crop_op.cc batch_decode_random_crop_op.cu DEPS image_decoder ${OP_HEADER_DEPS}) + op_library(batch_decode_op SRCS batch_decode_op.cc batch_decode_op.cu DEPS image_decoder ${OP_HEADER_DEPS}) + + op_library(batch_random_crop_and_resize_op SRCS batch_random_crop_and_resize_op.cc batch_random_crop_and_resize_op.cu DEPS ${OP_HEADER_DEPS}) + op_library(batch_resize_op SRCS batch_resize_op.cc batch_resize_op.cu DEPS ${OP_HEADER_DEPS}) + + op_library(mirror_normalize_op SRCS mirror_normalize_op.cc mirror_normalize_op.cu DEPS ${OP_HEADER_DEPS}) +endif() diff --git a/paddle/fluid/operators/data/batch_decode_op.cc b/paddle/fluid/operators/data/batch_decode_op.cc new file mode 100644 index 00000000000000..bb3367d4c47909 --- /dev/null +++ b/paddle/fluid/operators/data/batch_decode_op.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/data/batch_decode_op.h" + +namespace paddle { +namespace operators { +namespace data { + +class BatchDecodeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, + platform::errors::InvalidArgument( + "Inputs(X) of DecodeJpeg should not be empty.")); + PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, + platform::errors::InvalidArgument( + "Outputs(Out) of DecodeJpeg should not be empty.")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::UINT8, + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + if (var_name == "X") { + return expected_kernel_type; + } + + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place()); + } +}; + +class BatchDecodeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(List[Tensor]) A one dimensional uint8 tensor containing " + "the raw bytes of the JPEG image. It is a tensor with rank " + "1.") + .AsDuplicable(); + AddOutput("Out", "The output tensor of BatchDecodeOp").AsDuplicable(); + AddComment(R"DOC( +This operator decodes a JPEG image into a 3 dimensional RGB Tensor. +The values of the output tensor are uint8 between 0 and 255. 
+)DOC");
+    AddAttr<int>("num_threads",
+                 "(int, default 2) The number of threads used for decoding.")
+        .SetDefault(2);
+    AddAttr<int>("local_rank",
+                 "(int)"
+                 "The index of the op to start execution");
+    AddAttr<int64_t>("program_id",
+                     "(int64_t)"
+                     "The unique hash id used as cache key for "
+                     "decode thread pool");
+    AddAttr<int64_t>(
+        "host_memory_padding",
+        "(int64, default 0),"
+        "pinned memory allocation padding number for Nvjpeg decoding")
+        .SetDefault(0);
+    AddAttr<int64_t>(
+        "device_memory_padding",
+        "(int64, default 0),"
+        "device memory allocation padding number for Nvjpeg decoding")
+        .SetDefault(0);
+  }
+};
+
+}  // namespace data
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(
+    batch_decode, ops::data::BatchDecodeOp, ops::data::BatchDecodeOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
+
+REGISTER_OP_CPU_KERNEL(batch_decode, ops::data::CPUBatchDecodeKernel<uint8_t>)
diff --git a/paddle/fluid/operators/data/batch_decode_op.cu b/paddle/fluid/operators/data/batch_decode_op.cu
new file mode 100644
index 00000000000000..aa5596527e40ec
--- /dev/null
+++ b/paddle/fluid/operators/data/batch_decode_op.cu
@@ -0,0 +1,76 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP) + +#include "paddle/fluid/operators/data/batch_decode_op.h" +#include "paddle/fluid/operators/data/batch_decode_random_crop_op.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { +namespace data { + +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; + +template +class GPUBatchDecodeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int num_threads = ctx.Attr("num_threads"); + auto local_rank = ctx.Attr("local_rank"); + auto program_id = ctx.Attr("program_id"); + auto host_memory_padding = ctx.Attr("host_memory_padding"); + auto device_memory_padding = ctx.Attr("device_memory_padding"); + + // multi-phrase decode thread pool + auto* decode_pool = + ImageDecoderThreadPoolManager::Instance()->GetDecoderThreadPool( + program_id, num_threads, local_rank, + static_cast(host_memory_padding), + static_cast(device_memory_padding)); + + auto inputs = ctx.MultiInput("X"); + int batch_size = inputs.size(); + + auto out_array = ctx.MultiOutput("Out"); + auto dev = platform::CUDAPlace(local_rank); + + for (size_t i = 0; i < batch_size; i++) { + const framework::LoDTensor* x = inputs.at(i); + auto* x_data = x->data(); + size_t x_numel = static_cast(x->numel()); + + ImageDecodeTask task; + task.bit_stream = x_data; + task.bit_len = x_numel; + task.tensor = out_array[i]; + task.roi_generator = nullptr; + task.place = dev; + decode_pool->AddTask(std::make_shared(task)); + } + + decode_pool->RunAll(true); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(batch_decode, ops::data::GPUBatchDecodeKernel) + +#endif diff --git a/paddle/fluid/operators/data/batch_decode_op.h b/paddle/fluid/operators/data/batch_decode_op.h new file mode 100644 index 00000000000000..cb0b4382346adf --- /dev/null +++ b/paddle/fluid/operators/data/batch_decode_op.h @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/data/image_decoder.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { +namespace data { + +template +class CPUBatchDecodeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // TODO(LieLinJiang): add cpu implement. 
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "BatchDecode op only supports GPU now."));
+  }
+};
+
+}  // namespace data
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/data/batch_decode_random_crop_op.cc b/paddle/fluid/operators/data/batch_decode_random_crop_op.cc
new file mode 100644
index 00000000000000..691d738a87a110
--- /dev/null
+++ b/paddle/fluid/operators/data/batch_decode_random_crop_op.cc
@@ -0,0 +1,166 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/data/batch_decode_random_crop_op.h"
+
+namespace paddle {
+namespace operators {
+namespace data {
+
+class BatchDecodeRandomCropOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_GE(
+        ctx->Inputs("X").size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Inputs(X) of BatchDecodeRandomCrop should not be empty."));
+    PADDLE_ENFORCE_GE(
+        ctx->Outputs("Out").size(), 1UL,
+        platform::errors::InvalidArgument(
+            "Outputs(Out) of BatchDecodeRandomCrop should not be empty."));
+    auto aspect_ratio_min = ctx->Attrs().Get<float>("aspect_ratio_min");
+    auto aspect_ratio_max = ctx->Attrs().Get<float>("aspect_ratio_max");
+    PADDLE_ENFORCE_GT(
+        aspect_ratio_min, 0.,
+        platform::errors::InvalidArgument(
+            "aspect_ratio_min should be greater than 0, but received "
+            "%f",
+            aspect_ratio_min));
+    PADDLE_ENFORCE_GT(
+        aspect_ratio_max, 0.,
+        platform::errors::InvalidArgument(
+            "aspect_ratio_max should be greater than 0, but received "
+            "%f",
+            aspect_ratio_max));
+    PADDLE_ENFORCE_GE(
+        aspect_ratio_max, aspect_ratio_min,
+        platform::errors::InvalidArgument(
+            "aspect_ratio_max should be greater than aspect_ratio_min, "
+            "but received aspect_ratio_max(%f) < aspect_ratio_min(%f)",
+            aspect_ratio_max, aspect_ratio_min));
+
+    auto area_min = ctx->Attrs().Get<float>("area_min");
+    auto area_max = ctx->Attrs().Get<float>("area_max");
+    PADDLE_ENFORCE_GT(area_min, 0.,
+                      platform::errors::InvalidArgument(
+                          "area_min should be greater than 0, but received "
+                          "%f",
+                          area_min));
+    PADDLE_ENFORCE_GT(area_max, 0.,
+                      platform::errors::InvalidArgument(
+                          "area_max should be greater than 0, but received "
+                          "%f",
+                          area_max));
+    PADDLE_ENFORCE_GE(area_max, area_min,
+                      platform::errors::InvalidArgument(
+                          "area_max should be greater than area_min, "
+                          "but received area_max(%f) < area_min(%f)",
+                          area_max, area_min));
+
+    auto num_attempts = ctx->Attrs().Get<int>("num_attempts");
+    PADDLE_ENFORCE_GT(num_attempts, 0,
+                      platform::errors::InvalidArgument(
+                          "num_attempts should be a positive integer, but "
+                          "received %d",
+                          num_attempts));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(framework::proto::VarType::UINT8,
+                                   ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& 
var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + if (var_name == "X") { + return expected_kernel_type; + } + + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place()); + } +}; + +class BatchDecodeRandomCropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(List[Tensor]) A one dimensional uint8 tensor containing the " + "raw bytes of the JPEG image. It is a tensor with rank 1.") + .AsDuplicable(); + AddOutput("Out", "The output tensor of BatchDecodeRandomCropOp") + .AsDuplicable(); + AddComment(R"DOC( +This operator decodes a JPEG image into a 3 dimensional RGB Tensor. +Optionally converts the image to the desired format. +The values of the output tensor are uint8 between 0 and 255. +)DOC"); + AddAttr("local_rank", + "(int64_t)" + "The index of the op to start execution"); + AddAttr("num_threads", "Path of the file to be readed.").SetDefault(2); + AddAttr("host_memory_padding", + "(int64, default 0), pinned memory allocation padding " + "number for Nvjpeg decoding") + .SetDefault(0); + AddAttr("device_memory_padding", + "(int64, default 0), device memory allocation padding " + "number for Nvjpeg decoding") + .SetDefault(0); + AddAttr( + "data_format", + "(string, default NCHW) Only used in " + "an optional string from: \"NHWC\", \"NCHW\". " + "Specify that the data format of the input and output data is " + "channel_first or channel_last.") + .SetDefault("NCHW"); + AddAttr("aspect_ratio_min", + "(float) The minimum aspect ratio of random cropping boxes") + .SetDefault(3. / 4.); + AddAttr("aspect_ratio_max", + "(float) The maximum aspect ratio of random cropping boxes") + .SetDefault(4. / 3.); + AddAttr("area_min", + "(float) The min area ratio of random cropping boxes") + .SetDefault(0.08); + AddAttr("area_max", + "(float) The max area ratio of random cropping boxes") + .SetDefault(1.); + AddAttr("num_attempts", + "(int) The max attempt number of random cropping boxes") + .SetDefault(10); + AddAttr("program_id", + "(int64_t)" + "The unique hash id used as cache key for " + "decode thread pool"); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + batch_decode_random_crop, ops::data::BatchDecodeRandomCropOp, + ops::data::BatchDecodeRandomCropOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker) + +REGISTER_OP_CPU_KERNEL(batch_decode_random_crop, + ops::data::CPUBatchDecodeRandomCropKernel) diff --git a/paddle/fluid/operators/data/batch_decode_random_crop_op.cu b/paddle/fluid/operators/data/batch_decode_random_crop_op.cu new file mode 100644 index 00000000000000..dc06f0db496e06 --- /dev/null +++ b/paddle/fluid/operators/data/batch_decode_random_crop_op.cu @@ -0,0 +1,120 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP) + +#include +#include +#include "paddle/fluid/operators/data/batch_decode_random_crop_op.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { +namespace data { + +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; +using DataLayout = framework::DataLayout; + +ImageDecoderThreadPool* decode_pool = nullptr; + +template +class GPUBatchDecodeRandomCropKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int num_threads = ctx.Attr("num_threads"); + auto local_rank = ctx.Attr("local_rank"); + auto program_id = ctx.Attr("program_id"); + auto host_memory_padding = ctx.Attr("host_memory_padding"); + auto device_memory_padding = ctx.Attr("device_memory_padding"); + + // multi-phrase decode thread pool + auto* decode_pool = + ImageDecoderThreadPoolManager::Instance()->GetDecoderThreadPool( + program_id, num_threads, local_rank, + static_cast(host_memory_padding), + static_cast(device_memory_padding)); + + auto inputs = ctx.MultiInput("X"); + int batch_size = inputs.size(); + + auto out_array = ctx.MultiOutput("Out"); + auto dev = platform::CUDAPlace(local_rank); + + const std::string data_format_str = ctx.Attr("data_format"); + const DataLayout data_format = + framework::StringToDataLayout(data_format_str); + + framework::LoDTensorArray temp_array; + if (data_format == DataLayout::kNCHW) { + temp_array.resize(batch_size); + } + + auto aspect_ratio_min = ctx.Attr("aspect_ratio_min"); + auto aspect_ratio_max = ctx.Attr("aspect_ratio_max"); + AspectRatioRange aspect_ratio_range{aspect_ratio_min, aspect_ratio_max}; + + auto area_min = ctx.Attr("area_min"); + auto area_max = ctx.Attr("area_max"); + AreaRange area_range{area_min, area_max}; + + auto* generators = GeneratorManager::Instance()->GetGenerators( + program_id, batch_size, aspect_ratio_range, area_range); + + for (size_t i = 0; i < inputs.size(); i++) { + const framework::LoDTensor* x = inputs.at(i); + auto* x_data = x->data(); + size_t x_numel = static_cast(x->numel()); + + ImageDecodeTask task; + task.bit_stream = x_data; + task.bit_len = x_numel; + task.roi_generator = generators->at(i).get(), task.place = dev; + task.tensor = + data_format == DataLayout::kNCHW ? 
&temp_array[i] : out_array[i]; + decode_pool->AddTask(std::make_shared(task)); + } + + decode_pool->RunAll(true); + + if (data_format == DataLayout::kNCHW) { + const auto& dev_ctx = ctx.cuda_device_context(); + phi::funcs::Transpose trans; + std::vector axis = {2, 0, 1}; + for (size_t i = 0; i < inputs.size(); i++) { + // Do transpose + const framework::DDim& in_sizes = temp_array[i].dims(); + framework::DDim transposed_input_shape = in_sizes.transpose(axis); + std::vector transposed_input_shape_ = + phi::vectorize(transposed_input_shape); + + out_array[i]->Resize(transposed_input_shape); + out_array[i]->mutable_data(dev_ctx.GetPlace()); + trans(dev_ctx, temp_array[i], out_array[i], axis); + } + } + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(batch_decode_random_crop, + ops::data::GPUBatchDecodeRandomCropKernel) + +#endif diff --git a/paddle/fluid/operators/data/batch_decode_random_crop_op.h b/paddle/fluid/operators/data/batch_decode_random_crop_op.h new file mode 100644 index 00000000000000..cda2c39ff89df9 --- /dev/null +++ b/paddle/fluid/operators/data/batch_decode_random_crop_op.h @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/data/image_decoder.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { +namespace data { + +template +class CPUBatchDecodeRandomCropKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // TODO(LieLinJiang): add cpu implement. + PADDLE_THROW(platform::errors::Unimplemented( + "BatchDecodeRandomCrop op only supports GPU now.")); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/batch_random_crop_and_resize_op.cc b/paddle/fluid/operators/data/batch_random_crop_and_resize_op.cc new file mode 100644 index 00000000000000..55a2e23cdab9ec --- /dev/null +++ b/paddle/fluid/operators/data/batch_random_crop_and_resize_op.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/data/batch_random_crop_and_resize_op.h" + +namespace paddle { +namespace operators { +namespace data { + +class BatchRandomCropAndResizeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, + platform::errors::InvalidArgument( + "Inputs(X) of BatchRandomCropAndResize " + "should not be empty.")); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "BatchRandomCropAndResize"); + + auto size = ctx->Attrs().Get>("size"); + PADDLE_ENFORCE_EQ(size.size(), 2, + platform::errors::InvalidArgument( + "The length of Attrs(size) should be 2.")); + PADDLE_ENFORCE_GT(size[0], 0, + platform::errors::InvalidArgument( + "h in Attr(size) of Op(BatchRandomCropAndResize) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(size[1], 0, + platform::errors::InvalidArgument( + "w in Attr(size) of Op(BatchRandomCropAndResize) " + "should be greater than 0.")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::UINT8, + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "X") { + return expected_kernel_type; + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +class BatchRandomCropAndResizeOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(List(Tensor)). A batch of instances to random crop.") + .AsDuplicable(); + AddOutput("Out", "(Tensor). The cropped instance batch."); + AddAttr("aspect_ratio_min", + "(float) The minimum aspect ratio of random cropping boxes") + .SetDefault(3. / 4.); + AddAttr("aspect_ratio_max", + "(float) The maximum aspect ratio of random cropping boxes") + .SetDefault(4. / 3.); + AddAttr("area_min", + "(float) The min area ratio of random cropping boxes") + .SetDefault(0.08); + AddAttr("area_max", + "(float) The max area ratio of random cropping boxes") + .SetDefault(1.); + AddAttr("num_attempts", + "(int) The max attempt number of random cropping boxes") + .SetDefault(10); + AddAttr>( + "size", "expected output size of the crop, for each edge."); + AddAttr("interp_method", + "(string, default \"bilinear\"), interpolation " + "method, can be \"bilinear\" for " + "bilinear interpolation and \"nearest\" for nearest " + "neighbor interpolation.") + .SetDefault("bilinear"); + AddAttr( + "align_corners", + "an optional bool. Defaults to True. " + "If True, the centers of 4 corner pixels of the input and output " + "tensors are aligned, preserving the values at the corner pixels, " + "If False, are not aligned") + .SetDefault(true); + AddAttr("align_mode", + "(int, default \'1\'), optional for bilinear interpolation, " + "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , " + "can be \'1\' for src_idx = scale*dst_index .") + .SetDefault(1); + AddAttr( + "data_format", + "(string, default NCHW) Only used in " + "an optional string from: \"NHWC\", \"NCHW\". 
" + "Specify that the data format of the input and output data is " + "channel_first or channel_last.") + .SetDefault("NCHW"); + AddComment(R"DOC( + Crop the input data to random size and aspect ratio. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 1.33) of the original aspect ratio is made. + After applying crop transfrom, the input data will be resized to given size. + )DOC"); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + batch_random_crop_and_resize, ops::data::BatchRandomCropAndResizeOp, + ops::data::BatchRandomCropAndResizeOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(batch_random_crop_and_resize, + ops::data::BatchRandomCropAndResizeCPUKernel, + ops::data::BatchRandomCropAndResizeCPUKernel, + ops::data::BatchRandomCropAndResizeCPUKernel); diff --git a/paddle/fluid/operators/data/batch_random_crop_and_resize_op.cu b/paddle/fluid/operators/data/batch_random_crop_and_resize_op.cu new file mode 100644 index 00000000000000..bed4409aa777ab --- /dev/null +++ b/paddle/fluid/operators/data/batch_random_crop_and_resize_op.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/data/batch_random_crop_and_resize_op.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { +namespace data { + +using framework::LoDTensor; +using DataLayout = framework::DataLayout; +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; + +template +__global__ void KeNearestNeighborInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const size_t idx_h, const size_t idx_w, const bool align_corners, + const DataLayout data_format) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + // batch size + int out_id_h = tid / output_w; + // single image's index + int out_id_w = tid % output_w; + // input_w or output_w = c * h * w, img_size = h * w + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + // get output c, h, w index + int channel_id, out_img_idy, out_img_idx; + if (data_format == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + // get input h index with offset + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + in_img_idy += idx_h; + // get input w index with offset + int in_img_idx = (align_corners) + ? 
static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + in_img_idx += idx_w; + + if (data_format == DataLayout::kNCHW) { + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } else { + out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + } + } +} + +template +__global__ void KeBilinearInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const size_t idx_h, const size_t idx_w, const bool align_corners, + const int align_mode, const DataLayout data_format) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); + for (; tid < nthreads; tid += stride) { + // batch size + int out_id_h = tid / output_w; + // single image's index + int out_id_w = tid % output_w; + // input_w or output_w = c * h * w, img_size = h * w + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + // get output c, h, w index + int channel_id, out_img_idy, out_img_idx; + if (data_format == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + // get input h index with offset + int in_img_idy = align_flag + ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) + : static_cast(ratio_h * out_img_idy); + in_img_idy = (in_img_idy > 0) ? in_img_idy + idx_h : idx_h; + int h_id = (in_img_idy < in_img_h + idx_h - 1) ? 1 : 0; + T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; + src_h = (src_h > 0) ? src_h + idx_h : idx_h; + T h1lambda = align_flag ? src_h - in_img_idy + : ratio_h * out_img_idy + idx_h - in_img_idy; + T h2lambda = 1.f - h1lambda; + + // get input w index with offset + int in_img_idx = align_flag + ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) + : static_cast(ratio_w * out_img_idx); + in_img_idx = (in_img_idx > 0) ? in_img_idx + idx_w : idx_w; + int w_id = (in_img_idx < in_img_w + idx_w - 1) ? 1 : 0; + T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = (src_w > 0) ? src_w + idx_w : idx_w; + T w1lambda = align_flag ? 
src_w - in_img_idx + : ratio_w * out_img_idx + idx_w - in_img_idx; + T w2lambda = 1.f - w1lambda; + + if (data_format == DataLayout::kNCHW) { + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + + w1lambda * in_pos[h_id * in_img_w + w_id]); + } else { + const T* in_pos = + &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + h2lambda * + (w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] + + w1lambda * in_pos[h_id * in_img_w * num_channels + + w_id * num_channels]); + } + } +} + +template +static void BatchRandomCropAndResizeFwd( + const framework::ExecutionContext& ctx, const framework::LoDTensor& input, + framework::Tensor* output, const std::vector out_size, + const std::string interp_method, const bool align_corners, + const int align_mode, const int img_h, const int img_w, const int c, + const int idx_h, const int idx_w, const int crop_h, const int crop_w, + const DataLayout data_format) { + auto input_data = input.template data(); + int out_h = static_cast(out_size[0]); + int out_w = static_cast(out_size[1]); + + framework::DDim dim_out; + if (data_format == DataLayout::kNCHW) { + dim_out = {c, out_h, out_w}; + } else { + dim_out = {out_h, out_w, c}; + } + auto output_data = output->data(); + + if (img_h == crop_h && img_w == crop_w) { + framework::TensorCopy(input, ctx.GetPlace(), output); + return; + } + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(crop_h - 1) / (out_h - 1) + : static_cast(crop_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? 
static_cast(crop_w - 1) / (out_w - 1) + : static_cast(crop_w) / out_w; + } + + int in_chw = c * crop_h * crop_w; + int out_chw = c * out_h * out_w; + + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_chw); + + if ("nearest" == interp_method) { + KeNearestNeighborInterpFw< + T><<>>( + input_data, crop_h, crop_w, 1, in_chw, output_data, out_h, out_w, 1, + out_chw, c, ratio_h, ratio_w, idx_h, idx_w, align_corners, data_format); + } else if ("bilinear" == interp_method) { + KeBilinearInterpFw<<>>( + input_data, crop_h, crop_w, 1, in_chw, output_data, out_h, out_w, 1, + out_chw, c, ratio_h, ratio_w, idx_h, idx_w, align_corners, align_mode, + data_format); + } +} + +static void GetCropParameters(const int height, const int width, + const std::vector scale, + const std::vector ratio, int* idx_h, + int* idx_w, int* crop_h, int* crop_w, + const int seed, int num_attempts = 10) { + double target_area, aspect_ratio; + double area = height * width; + std::vector log_ratio; + for (int i = 0; i < ratio.size(); i++) + log_ratio.push_back(std::log(ratio[i])); + std::default_random_engine engine(seed); + std::uniform_real_distribution dist_scale(scale[0], scale[1]); + std::uniform_real_distribution dist_log_ratio(log_ratio[0], + log_ratio[1]); + + for (int i = 0; i < num_attempts; i++) { + target_area = dist_scale(engine) * area; + aspect_ratio = std::exp(dist_log_ratio(engine)); + + *crop_w = + static_cast(std::round(std::sqrt(target_area * aspect_ratio))); + *crop_h = + static_cast(std::round(std::sqrt(target_area / aspect_ratio))); + if (*crop_w > 0 && *crop_w <= width && *crop_h > 0 && *crop_h <= height) { + std::uniform_int_distribution dist_crop_h(0, height - *crop_h); + *idx_h = dist_crop_h(engine); + std::uniform_int_distribution dist_crop_w(0, width - *crop_w); + *idx_w = dist_crop_w(engine); + return; + } + } + + // Fallback to central crop + float in_ratio = static_cast(width) / static_cast(height); + float min_ratio = ratio[0] > ratio[1] ? ratio[1] : ratio[0]; + float max_ratio = ratio[0] > ratio[1] ? 
ratio[0] : ratio[1]; + if (in_ratio < min_ratio) { + *crop_w = width; + *crop_h = static_cast(std::round(*crop_w / min_ratio)); + } else if (in_ratio > max_ratio) { + *crop_h = height; + *crop_w = static_cast(std::round(*crop_h * max_ratio)); + } else { + // return whole image + *crop_h = height; + *crop_w = width; + } + *idx_h = (height - *crop_h) / 2; + *idx_w = (width - *crop_w) / 2; +} + +template +class BatchRandomCropAndResizeCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::NotFound("This kernel only runs on GPU device.")); + // get input, output + auto x = ctx.MultiInput("X"); + PADDLE_ENFORCE_GT(x.size(), 0, + platform::errors::InvalidArgument( + "The size of X must be greater than 0.")); + auto* out = ctx.Output("Out"); + + auto aspect_ratio_min = ctx.Attr("aspect_ratio_min"); + auto aspect_ratio_max = ctx.Attr("aspect_ratio_max"); + AspectRatioRange aspect_ratio_range{aspect_ratio_min, aspect_ratio_max}; + + auto area_min = ctx.Attr("area_min"); + auto area_max = ctx.Attr("area_max"); + AreaRange area_range{area_min, area_max}; + + auto* generators = GeneratorManager::Instance()->GetGenerators( + x.size(), x.size(), aspect_ratio_range, area_range); + + const std::vector size = ctx.Attr>("size"); + + // get data_format + const std::string data_format_str = ctx.Attr("data_format"); + const DataLayout data_format = + framework::StringToDataLayout(data_format_str); + // get interpolation method + const std::string interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + auto* img = x.at(0); + int64_t img_c = + data_format == DataLayout::kNCHW ? img->dims()[0] : img->dims()[2]; + + std::vector out_dim; + if (data_format == DataLayout::kNCHW) { + out_dim = {static_cast(x.size()), img_c, size[0], size[1]}; + } else { + out_dim = {static_cast(x.size()), size[0], size[1], img_c}; + } + out->Resize(phi::make_ddim(out_dim)); + out->mutable_data(ctx.GetPlace()); + + int img_h, img_w, idx_h, idx_w, crop_h, crop_w; + for (int i = 0; i < x.size(); i++) { + img = x.at(i); + img_h = + data_format == DataLayout::kNCHW ? img->dims()[1] : img->dims()[0]; + img_w = + data_format == DataLayout::kNCHW ? img->dims()[2] : img->dims()[1]; + ROI roi; + generators->at(i)->GenerateRandomROI(img_w, img_h, &roi); + // GetCropParameters(img_h, img_w, scale, ratio, &idx_h, &idx_w, &crop_h, + // &crop_w, seed); + + auto out_tensor = out->Slice(i, i + 1); + BatchRandomCropAndResizeFwd(ctx, *img, &out_tensor, size, + interp_method, align_corners, align_mode, + img_h, img_w, img_c, roi.y, roi.x, roi.h, + roi.w, data_format); + } + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(batch_random_crop_and_resize, + ops::data::BatchRandomCropAndResizeCUDAKernel, + ops::data::BatchRandomCropAndResizeCUDAKernel, + ops::data::BatchRandomCropAndResizeCUDAKernel); diff --git a/paddle/fluid/operators/data/batch_random_crop_and_resize_op.h b/paddle/fluid/operators/data/batch_random_crop_and_resize_op.h new file mode 100644 index 00000000000000..76c0a334c13efc --- /dev/null +++ b/paddle/fluid/operators/data/batch_random_crop_and_resize_op.h @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_context.h" + +#include "paddle/fluid/operators/data/random_roi_generator.h" + +namespace paddle { +namespace operators { +namespace data { + +template +class BatchRandomCropAndResizeCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // no cpu kernel. + PADDLE_THROW(platform::errors::Unimplemented( + "BatchRandomCropAndResize op only supports GPU now.")); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/batch_resize_op.cc b/paddle/fluid/operators/data/batch_resize_op.cc new file mode 100644 index 00000000000000..564b200bed7f0c --- /dev/null +++ b/paddle/fluid/operators/data/batch_resize_op.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/operators/data/batch_resize_op.h"
+
+namespace paddle {
+namespace operators {
+namespace data {
+
+class BatchResizeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
+                      platform::errors::InvalidArgument(
+                          "Inputs(X) of BatchResize should not be empty."));
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BatchResize");
+
+    auto size = ctx->Attrs().Get<std::vector<int>>("size");
+    PADDLE_ENFORCE_EQ(size.size(), 2,
+                      platform::errors::InvalidArgument(
+                          "The length of Attrs(size) should be 2."));
+    PADDLE_ENFORCE_GT(size[0], 0, platform::errors::InvalidArgument(
+                                      "h in Attr(size) of Op(BatchResize) "
+                                      "should be greater than 0."));
+    PADDLE_ENFORCE_GT(size[1], 0, platform::errors::InvalidArgument(
+                                      "w in Attr(size) of Op(BatchResize) "
+                                      "should be greater than 0."));
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(framework::proto::VarType::UINT8,
+                                   ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "X") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place());
+  }
+};
+
+class BatchResizeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(List[LoDTensor]). A batch of image tensors to resize.")
+        .AsDuplicable();
+    AddOutput("Out", "(Tensor). The resized image batch.");
+    AddAttr<std::vector<int>>(
+        "size", "expected output size of the resized image, for each edge.");
+    AddAttr<std::string>("interp_method",
+                         "(string, default \"bilinear\"), interpolation "
+                         "method, can be \"bilinear\" for "
+                         "bilinear interpolation and \"nearest\" for nearest "
+                         "neighbor interpolation.")
+        .SetDefault("bilinear");
+    AddAttr<bool>(
+        "align_corners",
+        "an optional bool. Defaults to True. "
+        "If True, the centers of 4 corner pixels of the input and output "
+        "tensors are aligned, preserving the values at the corner pixels, "
+        "If False, are not aligned")
+        .SetDefault(true);
+    AddAttr<int>("align_mode",
+                 "(int, default \'1\'), optional for bilinear interpolation, "
+                 "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
+                 "can be \'1\' for src_idx = scale*dst_index .")
+        .SetDefault(1);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCHW) Only used in "
+        "an optional string from: \"NHWC\", \"NCHW\". "
+        "Specify that the data format of the input and output data is "
+        "channel_first or channel_last.")
+        .SetDefault("NCHW");
+    AddComment(R"DOC(
+      Resize a batch of input images to given size.
+ )DOC"); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + batch_resize, ops::data::BatchResizeOp, ops::data::BatchResizeOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(batch_resize, + ops::data::BatchResizeCPUKernel, + ops::data::BatchResizeCPUKernel, + ops::data::BatchResizeCPUKernel) diff --git a/paddle/fluid/operators/data/batch_resize_op.cu b/paddle/fluid/operators/data/batch_resize_op.cu new file mode 100644 index 00000000000000..dc099daaa2ccc3 --- /dev/null +++ b/paddle/fluid/operators/data/batch_resize_op.cu @@ -0,0 +1,268 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/data/batch_resize_op.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { +namespace data { + +using framework::LoDTensor; +using DataLayout = framework::DataLayout; + +template +__global__ void KeNearestNeighborInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_format) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + // batch size + int out_id_h = tid / output_w; + // single image's index + int out_id_w = tid % output_w; + // input_w or output_w = c * h * w, img_size = h * w + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + // get output c, h, w index + int channel_id, out_img_idy, out_img_idx; + if (data_format == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + // get input h index with offset + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + // get input w index with offset + int in_img_idx = (align_corners) + ? 
static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + + if (data_format == DataLayout::kNCHW) { + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } else { + out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + } + } +} + +template +__global__ void KeBilinearInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const int align_mode, + const DataLayout data_format) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + bool align_flag = (align_mode == 0 && !align_corners); + for (; tid < nthreads; tid += stride) { + // batch size + int out_id_h = tid / output_w; + // single image's index + int out_id_w = tid % output_w; + // input_w or output_w = c * h * w, img_size = h * w + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + // get output c, h, w index + int channel_id, out_img_idy, out_img_idx; + if (data_format == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + // get input h index with offset + int in_img_idy = align_flag + ? static_cast(ratio_h * (out_img_idy + 0.5) - 0.5) + : static_cast(ratio_h * out_img_idy); + in_img_idy = in_img_idy > 0 ? in_img_idy : 0; + int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; + float src_h = ratio_h * (out_img_idy + 0.5) - 0.5; + src_h = src_h > 0 ? src_h : 0; + float h1lambda = + align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; + float h2lambda = 1.f - h1lambda; + + // get input w index with offset + int in_img_idx = align_flag + ? static_cast(ratio_w * (out_img_idx + 0.5) - 0.5) + : static_cast(ratio_w * out_img_idx); + in_img_idx = in_img_idx > 0 ? in_img_idx : 0; + int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; + float src_w = ratio_w * (out_img_idx + 0.5) - 0.5; + src_w = src_w > 0 ? src_w : 0; + float w1lambda = + align_flag ? 
src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; + float w2lambda = 1.f - w1lambda; + + if (data_format == DataLayout::kNCHW) { + const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + (T)(h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w] + + w1lambda * in_pos[h_id * in_img_w + w_id])); + } else { + const T* in_pos = + &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + + // bilinear interpolation + out[out_id_h * output_w + out_id_w] = + (T)(h2lambda * (w2lambda * in_pos[0] + + w1lambda * in_pos[w_id * num_channels]) + + h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] + + w1lambda * in_pos[h_id * in_img_w * num_channels + + w_id * num_channels])); + } + } +} + +template +static void ResizeFwd(const framework::ExecutionContext& ctx, + const framework::LoDTensor& input, + framework::Tensor* output, + const std::vector out_size, + const std::string interp_method, const bool align_corners, + const int align_mode, const int img_h, const int img_w, + const int c, const DataLayout data_format) { + auto input_data = input.template data(); + int out_h = static_cast(out_size[0]); + int out_w = static_cast(out_size[1]); + + framework::DDim dim_out; + if (data_format == DataLayout::kNCHW) { + dim_out = {c, out_h, out_w}; + } else { + dim_out = {out_h, out_w, c}; + } + auto output_data = output->data(); + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(img_h - 1) / (out_h - 1) + : static_cast(img_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(img_w - 1) / (out_w - 1) + : static_cast(img_w) / out_w; + } + + int in_chw = c * img_h * img_w; + int out_chw = c * out_h * out_w; + + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_chw); + + if ("nearest" == interp_method) { + KeNearestNeighborInterpFw< + T><<>>( + input_data, img_h, img_w, 1, in_chw, output_data, out_h, out_w, 1, + out_chw, c, ratio_h, ratio_w, align_corners, data_format); + } else if ("bilinear" == interp_method) { + KeBilinearInterpFw<<>>( + input_data, img_h, img_w, 1, in_chw, output_data, out_h, out_w, 1, + out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_format); + } +} + +template +class BatchResizeCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::NotFound("This kernel only runs on GPU device.")); + // get input, output + auto x = ctx.MultiInput("X"); + PADDLE_ENFORCE_GT(x.size(), 0, + platform::errors::InvalidArgument( + "The size of X must be greater than 0.")); + auto* out = ctx.Output("Out"); + + // get size, scale, ratio + auto size = ctx.Attr>("size"); + + const std::string data_format_str = ctx.Attr("data_format"); + const DataLayout data_format = + framework::StringToDataLayout(data_format_str); + // get interpolation method + const std::string interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + auto* img = x.at(0); + int64_t img_c = + data_format == DataLayout::kNCHW ? 
img->dims()[0] : img->dims()[2]; + + std::vector out_dim = {static_cast(x.size()), size[0], + size[1], img_c}; + if (data_format == DataLayout::kNCHW) { + out_dim = {static_cast(x.size()), img_c, size[0], size[1]}; + } + out->Resize(phi::make_ddim(out_dim)); + out->mutable_data(ctx.GetPlace()); + + int img_h, img_w, idx_h, idx_w, crop_h, crop_w; + for (int i = 0; i < x.size(); i++) { + img = x.at(i); + img_h = + data_format == DataLayout::kNCHW ? img->dims()[1] : img->dims()[0]; + img_w = + data_format == DataLayout::kNCHW ? img->dims()[2] : img->dims()[1]; + auto out_tensor = out->Slice(i, i + 1); + ResizeFwd(ctx, *img, &out_tensor, size, interp_method, align_corners, + align_mode, img_h, img_w, img_c, data_format); + } + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(batch_resize, + ops::data::BatchResizeCUDAKernel, + ops::data::BatchResizeCUDAKernel, + ops::data::BatchResizeCUDAKernel); diff --git a/paddle/fluid/operators/data/batch_resize_op.h b/paddle/fluid/operators/data/batch_resize_op.h new file mode 100644 index 00000000000000..cd39a8dd66272f --- /dev/null +++ b/paddle/fluid/operators/data/batch_resize_op.h @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { +namespace data { + +template +class BatchResizeCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // no cpu kernel. + PADDLE_THROW(platform::errors::Unimplemented( + "BatchResize op only supports GPU now.")); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/data_reader_op.cc b/paddle/fluid/operators/data/data_reader_op.cc new file mode 100644 index 00000000000000..af64d5acaf8d51 --- /dev/null +++ b/paddle/fluid/operators/data/data_reader_op.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
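The interpolation kernels in batch_resize_op.cu above map every output pixel back to a source coordinate before sampling. Below is a minimal standalone sketch of that mapping (illustrative only, not part of the patch; the function names are ours, but the arithmetic mirrors the ratio and nearest-neighbor index computation used by ResizeFwd and KeNearestNeighborInterpFw):

// Illustrative sketch of the scale ratio and nearest-neighbor source index
// used by the resize kernels above; not part of the patch.
static float ScaleRatio(int in_size, int out_size, bool align_corners) {
  if (out_size <= 1) return 0.f;
  return align_corners ? static_cast<float>(in_size - 1) / (out_size - 1)
                       : static_cast<float>(in_size) / out_size;
}

static int NearestSrcIndex(int dst_index, float ratio, bool align_corners) {
  // align_corners rounds to the nearest source pixel, otherwise truncates.
  return align_corners ? static_cast<int>(ratio * dst_index + 0.5f)
                       : static_cast<int>(ratio * dst_index);
}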
+ +#include "paddle/fluid/operators/data/data_reader_op.h" + +namespace paddle { +namespace operators { +namespace data { + +// initialization static variables out of ReaderManager +ReaderManager* ReaderManager::rm_instance_ptr_ = nullptr; +std::mutex ReaderManager::m_; + +class DataReaderOp : public framework::OperatorBase { + public: + DataReaderOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext* ctx) const { + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "DataReaderOp"); + } + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + auto outputs = Outputs("Out"); + std::vector output_vars; + output_vars.reserve(outputs.size()); + for (auto& output : outputs) { + output_vars.emplace_back(scope.FindVar(output)); + } + + CheckAndInitOutputQueue(output_vars, /*capacity=*/2); + + auto batch_size = Attr("batch_size"); + auto num_samples = Attr("num_samples"); + auto shuffle = Attr("shuffle"); + auto drop_last = Attr("drop_last"); + auto seed = Attr("seed"); + auto rank = Attr("rank"); + auto world_size = Attr("world_size"); + auto indices_var_name = Attr("indices_var_name"); + auto output_var_names = Attr>("output_var_names"); + auto* reader_block = Attr("reader_block"); + auto reader_id = Attr("reader_id"); + + auto output_queues = GetQueueVecFromVariableVec(output_vars); + ReaderManager::Instance()->StartDataReader( + reader_id, reader_block, &scope, platform::CPUPlace(), indices_var_name, + output_var_names, output_queues, batch_size, num_samples, shuffle, + drop_last, seed, rank, world_size); + } +}; + +class DataReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "MapOp"); + } +}; + +class DataReaderInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const override {} +}; + +class DataReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Out", "The output queue variable of DataReader op") + .AsDuplicable(); + AddAttr("batch_size", "The batch size for reading samples") + .SetDefault(1); + AddAttr("num_samples", "The sample number in dataset"); + AddAttr("shuffle", "Whether shuffle the dataset").SetDefault(false); + AddAttr("drop_last", "Whether drop last incomplete batch") + .SetDefault(false); + AddAttr("seed", "Random seed for shuffle").SetDefault(0); + AddAttr("rank", "The logical rank of current device.").SetDefault(0); + AddAttr("world_size", "The number of running devices.").SetDefault(1); + AddAttr("reader_id", "The unique id to generate and get reader"); + AddAttr("reader_block", + "(BlockDesc *)" + "The global block of executed reader program " + "desc."); + AddAttr("indices_var_name", + "(string)" + "input variable names for sample indices"); + AddAttr>( + "output_var_names", + "(list of string)" + "output variable names for reader program"); + AddComment(R"DOC( + This operator read a file. 
+)DOC"); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators::data; + +REGISTER_OPERATOR(data_reader, ops::DataReaderOp, ops::DataReaderOpMaker, + ops::DataReaderInferShape, ops::DataReaderInferVarType) + +REGISTER_OP_CPU_KERNEL(data_reader, ops::DataReaderCPUKernel) diff --git a/paddle/fluid/operators/data/data_reader_op.h b/paddle/fluid/operators/data/data_reader_op.h new file mode 100644 index 00000000000000..324fa9e1e50f34 --- /dev/null +++ b/paddle/fluid/operators/data/data_reader_op.h @@ -0,0 +1,366 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace operators { +namespace data { + +using Scope = framework::Scope; +using Variable = framework::Variable; +using BlockDesc = framework::BlockDesc; +using LoDTensor = framework::LoDTensor; +using LoDTensorArray = framework::LoDTensorArray; +using LoDTensorBlockingQueue = operators::reader::LoDTensorBlockingQueue; +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; + +class Sampler { + public: + explicit Sampler(const int64_t batch_size, const int64_t num_samples, + const bool shuffle, const bool drop_last, const int64_t seed, + const int rank, const int world_size) + : current_iter_(0), + batch_size_(batch_size), + shuffle_(shuffle), + drop_last_(drop_last), + rank_(rank), + world_size_(world_size) { + int trunc_num_samples; + if (drop_last) { + int total_batch_size = world_size * batch_size; + trunc_num_samples = + floor(num_samples / total_batch_size) * total_batch_size; + sample_ids_.reserve(trunc_num_samples); + } else { + sample_ids_.reserve(num_samples); + trunc_num_samples = num_samples; + } + for (int64_t i = 0; i < trunc_num_samples; i++) { + sample_ids_.emplace_back(i); + } + num_samples_ = sample_ids_.size(); + if (shuffle) { + rnd_.seed(seed); + std::shuffle(sample_ids_.begin(), sample_ids_.end(), rnd_); + } + } + + void GetNextIndices(std::vector* indices) { + int64_t start_idx = batch_size_ * world_size_ * current_iter_ + rank_; + current_iter_++; + + if (start_idx >= num_samples_) return; + + for (int64_t i = 0; i < batch_size_; i++) { + int cur_idx = start_idx + i * world_size_; + if (cur_idx >= num_samples_) return; + indices->emplace_back(sample_ids_[cur_idx]); + } + } + + void Reset() { + if (shuffle_) { + std::shuffle(sample_ids_.begin(), sample_ids_.end(), rnd_); + } + + current_iter_ = 0; + } + + private: + int64_t current_iter_; + const int64_t batch_size_; + const 
bool shuffle_; + int64_t num_samples_; + const bool drop_last_; + const int rank_; + const int world_size_; + + std::mt19937 rnd_; + std::vector sample_ids_; +}; + +class DataReader { + public: + explicit DataReader( + BlockDesc* reader_block, const Scope* scope, const platform::Place place, + const std::string& indices_var_name, + const std::vector& output_var_names, + const std::vector> output_queues, + const int batch_size, const int num_samples, const bool shuffle, + const bool drop_last, const int64_t seed, const int rank, + const int world_size) + : running_(true), + shutdown_(false), + reader_block_(reader_block), + place_(place), + indices_var_name_(indices_var_name), + output_var_names_(output_var_names), + output_queues_(output_queues), + batch_size_(batch_size), + sampler_(batch_size, num_samples, shuffle, drop_last, seed, rank, + world_size) { + StartReaderThread(scope); + } + + void StartReaderThread(const Scope* scope) { + if (reader_thread_.joinable()) { + return; + } + + reader_thread_ = std::thread([this, scope] { + auto& scope_ = scope->NewScope(); + framework::Executor executor(place_); + while (!shutdown_) { + // check running or shutdown + std::unique_lock lock(mutex_); + running_cond_.wait(lock, [this] { return running_ || shutdown_; }); + if (shutdown_) break; + + std::vector indices; + sampler_.GetNextIndices(&indices); + // shutdown reader if indices drained + if (indices.size() == 0) { + for (auto& queue : output_queues_) { + while (queue->Size()) sleep(0.5); + queue->Close(); + } + + running_ = false; + continue; + } + + ShareIndicesIntoScope(&scope_, indices); + + try { + executor.Run(*reader_block_->Program(), &scope_, + static_cast(reader_block_->ID()), false, true, + output_var_names_, false, true); + } catch (...) { + break; + } + + for (size_t i = 0; i < output_var_names_.size(); i++) { + auto* out_var = scope_.FindVar(output_var_names_[i]); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "The output variable %s is not found in DataReader " + "program's internal scope", + output_var_names_[i])); + // CheckOutputVarStatus(*out_var, output_var_names_[i]); + + if (out_var->IsType()) { + framework::LoDTensorArray t_arr(1); + copy_tensor(out_var->Get(), &t_arr[0]); + output_queues_[i]->Push(t_arr); + } else { + auto out_arr = out_var->Get(); + framework::LoDTensorArray t_arr(out_arr.size()); + for (size_t i = 0; i < out_arr.size(); i++) { + copy_tensor(out_arr[i], &t_arr[i]); + } + output_queues_[i]->Push(t_arr); + } + } + } + scope->DeleteScope(&scope_); + }); + } + + void ShutDown() { + for (auto& queue : output_queues_) { + if (queue && !queue->IsClosed()) queue->Close(); + } + + shutdown_ = true; + running_ = false; + running_cond_.notify_all(); + + if (reader_thread_.joinable()) reader_thread_.join(); + } + + void Reset() { + // reopen all output queues + for (auto& queue : output_queues_) queue->ReOpen(); + + // reset sampler to regenerate indices + sampler_.Reset(); + + // set running flag to reset running + running_ = true; + running_cond_.notify_all(); + } + + void ShareIndicesIntoScope(Scope* scope, std::vector indices) { + auto* var = scope->Var(indices_var_name_); + + auto* indices_tensor = var->GetMutable(); + indices_tensor->Resize( + phi::make_ddim({static_cast(indices.size())})); + auto* indices_data = + indices_tensor->mutable_data(platform::CPUPlace()); + + for (size_t i = 0; i < indices.size(); i++) { + indices_data[i] = indices[i]; + } + } + + private: + bool running_; + std::condition_variable running_cond_; + bool 
shutdown_; + std::mutex mutex_; + + std::thread reader_thread_; + + BlockDesc* reader_block_; + platform::Place place_; + + std::string indices_var_name_; + std::vector output_var_names_; + std::vector> output_queues_; + + const int64_t batch_size_; + Sampler sampler_; + + void copy_tensor(const framework::LoDTensor& lod_tensor, + framework::LoDTensor* out) const { + if (lod_tensor.numel() == 0) return; + auto& out_tensor = *out; + framework::TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor); + out_tensor.set_lod(lod_tensor.lod()); + } +}; + +class ReaderManager { + private: + DISABLE_COPY_AND_ASSIGN(ReaderManager); + + static ReaderManager* rm_instance_ptr_; + static std::mutex m_; + + std::map> id_to_reader_; + + public: + static ReaderManager* Instance() { + if (rm_instance_ptr_ == nullptr) { + std::lock_guard lk(m_); + if (rm_instance_ptr_ == nullptr) { + rm_instance_ptr_ = new ReaderManager; + } + } + return rm_instance_ptr_; + } + + void StartDataReader( + const int64_t reader_id, BlockDesc* reader_block, const Scope* scope, + const platform::Place place, const std::string& indices_var_name, + const std::vector& output_var_names, + const std::vector>& output_queues, + const int batch_size, const int num_samples, const bool shuffle, + const bool drop_last, const int64_t seed, const int rank, + const int world_size) { + auto iter = id_to_reader_.find(reader_id); + if (iter == id_to_reader_.end()) { + id_to_reader_[reader_id] = std::unique_ptr(new DataReader( + reader_block, scope, place, indices_var_name, output_var_names, + output_queues, batch_size, num_samples, shuffle, drop_last, seed, + rank, world_size)); + } + } + + void ShutDownReader(const int64_t reader_id) { + auto iter = id_to_reader_.find(reader_id); + if (iter != id_to_reader_.end()) { + if (iter->second.get()) iter->second.get()->ShutDown(); + id_to_reader_.erase(reader_id); + } + } + + void ResetReader(const int64_t reader_id) { + auto iter = id_to_reader_.find(reader_id); + if (iter != id_to_reader_.end()) { + iter->second->Reset(); + } + } + + void ShutDown() { + auto iter = id_to_reader_.begin(); + while (iter != id_to_reader_.end()) { + if (iter->second.get()) { + iter->second.get()->ShutDown(); + } + iter++; + } + id_to_reader_.clear(); + } + + ReaderManager() {} + + ~ReaderManager() { ShutDown(); } +}; + +static void CheckAndInitOutputQueue(const std::vector& vars, + int capacity) { + for (auto var : vars) { + if (var->IsInitialized()) { + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "Output Variables of DataLoaderOp should hold " + "LoDTensorBlockingQueueHolder type")); + auto queue = var->Get().GetQueue(); + if (queue == nullptr) { + auto* holder = var->template GetMutable(); + holder->InitOnce(capacity); + } + } else { + auto* holder = var->GetMutable(); + holder->InitOnce(capacity); + } + } +} + +static std::vector> +GetQueueVecFromVariableVec(const std::vector& vars) { + std::vector> queues; + queues.reserve(vars.size()); + for (size_t i = 0; i < vars.size(); i++) { + queues.push_back(vars[i]->Get().GetQueue()); + } + return queues; +} + +template +class DataReaderCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/dataloader_op.cc b/paddle/fluid/operators/data/dataloader_op.cc new file mode 100644 index 00000000000000..46ede9d2a8fc7e --- /dev/null +++ 
b/paddle/fluid/operators/data/dataloader_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/data/dataloader_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/type_defs.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class DataLoaderOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "DataLoaderOp"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } +}; + +class DataLoaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Out", + "(vector)" + "The output tensors of DataLoader operator, also the fetch " + "targets of the loaded program.") + .AsDuplicable(); + AddAttr("global_block", + "(BlockDesc *)" + "The global block of executed dataloader program " + "desc."); + AddAttr("start_op_index", + "(int64_t)" + "The index of the op to start execution"); + AddAttr("end_op_index", + "(int64_t)" + "The index of the op to stop execution"); + AddAttr("program_id", + "(int64_t)" + "The unique hash id used as cache key for " + "ExecutorInfoCache"); + AddComment(R"DOC( + DataLoader OP + This OP runs DataPipeline programs to start up DataPipeline for + multi-thread and multi-stream data loading. For DataPipeline + program construct with :code:`paddle.io.map` and + :code:`paddle.io.data_reader`, which holds independent threads + and streams, so DataLoader Op simply initialize a ParallelExecutor + to run DataPipeline progran once. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(dataloader, ops::DataLoaderOp, ops::DataLoaderOpMaker); +REGISTER_OP_CPU_KERNEL( + dataloader, + ops::DataLoaderOpKernel); diff --git a/paddle/fluid/operators/data/dataloader_op.cu.cc b/paddle/fluid/operators/data/dataloader_op.cu.cc new file mode 100644 index 00000000000000..52dea24815fe1f --- /dev/null +++ b/paddle/fluid/operators/data/dataloader_op.cu.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/data/dataloader_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + dataloader, + ops::DataLoaderOpKernel); diff --git a/paddle/fluid/operators/data/dataloader_op.h b/paddle/fluid/operators/data/dataloader_op.h new file mode 100644 index 00000000000000..59273d463f759b --- /dev/null +++ b/paddle/fluid/operators/data/dataloader_op.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/data/pipeline.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +class Pipeline; + +template +class DataLoaderOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // Step1: get output vars and attrs + auto output_vars = ctx.MultiOutputVar("Out"); + auto output_var_names = ctx.OutputNames("Out"); + + auto* global_block = ctx.Attr("global_block"); + auto start_op_index = ctx.Attr("start_op_index"); + auto end_op_index = ctx.Attr("end_op_index"); + auto program_id = ctx.Attr("program_id"); + + auto pipeline = data::PipelineManager::Instance()->GetPipeline( + program_id, global_block, ctx.GetPlace(), start_op_index, end_op_index, + output_var_names); + + pipeline->ReadNext(output_vars); + + if (!pipeline->IsRunning()) { + PADDLE_THROW_EOF(); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/file_label_loader_op.cc b/paddle/fluid/operators/data/file_label_loader_op.cc new file mode 100644 index 00000000000000..a09daf5b5e85da --- /dev/null +++ b/paddle/fluid/operators/data/file_label_loader_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
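DataLoaderOpKernel above fetches (or lazily creates) a Pipeline keyed by Attr(program_id) and pops one batch per invocation, raising Paddle's EOF exception once the pipeline has stopped running. A minimal sketch of the host-side loop that would typically drive it follows (illustrative only and not part of this patch; the executor/program setup and the exception type caught are assumptions based on the PADDLE_THROW_EOF call above):

// Illustrative consume loop; assumes `executor`, `program` and `scope` are
// already set up and that PADDLE_THROW_EOF raises platform::EOFException.
while (true) {
  try {
    executor.Run(program, &scope, /*block_id=*/0);  // runs the dataloader op
  } catch (const paddle::platform::EOFException&) {
    break;  // data pipeline drained
  }
}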
+ +#include "paddle/fluid/operators/data/file_label_loader_op.h" + +namespace paddle { +namespace operators { +namespace data { + +class FileLabelLoaderOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE_EQ(ctx->HasInput("Indices"), true, + platform::errors::InvalidArgument( + "Input(Indices) of ReadFileLoaderOp is null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Label"), true, + platform::errors::InvalidArgument( + "Output(Label) of ReadFileLoaderOp is null.")); + + auto dim_indices = ctx->GetInputDim("Indices"); + PADDLE_ENFORCE_EQ(dim_indices.size(), 1, + platform::errors::InvalidArgument( + "Input(Indices) should be a 1-D Tensor")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType(framework::proto::VarType::UINT8, + platform::CPUPlace()); + } +}; + +class FileLabelLoaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Indices", "The batch indices of input samples"); + AddOutput("Image", "The output image tensor of ReadFileLoader op") + .AsDuplicable(); + AddOutput("Label", "The output label tensor of ReadFileLoader op"); + AddAttr("data_root", "Path of root directory of dataset"); + AddComment(R"DOC( + This operator read ImageNet format dataset for :attr:`data_root` with + given indices. + There are 2 outputs: + 1. Image: a list of Tensor which holds the image bytes data + 2. Label: a Tensor with shape [N] and dtype as int64, N is the batch + size, which is the length of input indices. +)DOC"); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators::data; + +REGISTER_OPERATOR( + file_label_loader, ops::FileLabelLoaderOp, ops::FileLabelLoaderOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker) + +REGISTER_OP_CPU_KERNEL(file_label_loader, + ops::FileLabelLoaderCPUKernel) diff --git a/paddle/fluid/operators/data/file_label_loader_op.h b/paddle/fluid/operators/data/file_label_loader_op.h new file mode 100644 index 00000000000000..ceba58293ef8a3 --- /dev/null +++ b/paddle/fluid/operators/data/file_label_loader_op.h @@ -0,0 +1,177 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace data { +using LoDTensor = framework::LoDTensor; +using LoDTensorArray = framework::LoDTensorArray; + +#ifdef _WIN32 +constexpr char DIR_SEP = '\\'; +#else +constexpr char DIR_SEP = '/'; +#endif + +static std::string JoinPath(const std::string path1, const std::string path2) { + // empty check + if (path1.empty()) return path2; + if (path1.empty()) return path1; + + // absolute path check + if (path2[0] == DIR_SEP) return path2; +#ifdef _WIN32 + if (path2[1] == ":") return path2; +#endif + + // concat path + if (path1[path1.length() - 1] == DIR_SEP) return path1 + path2; + return path1 + DIR_SEP + path2; +} + +static void ParseFilesAndLabels( + const std::string data_root, + std::vector>* samples) { + auto* dir = opendir(data_root.c_str()); + PADDLE_ENFORCE_NE(dir, nullptr, platform::errors::InvalidArgument( + "Cannot open directory %s", data_root)); + + // Step 1: parse classes info + std::vector classes; + auto* entry = readdir(dir); + while (entry) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { + entry = readdir(dir); + continue; + } + + auto cls_path = JoinPath(data_root, entry->d_name); + struct stat s; + int ret = stat(cls_path.c_str(), &s); + PADDLE_ENFORCE_EQ(ret, 0, platform::errors::InvalidArgument( + "Directory %s is unaccessiable.", cls_path)); + + if (S_ISDIR(s.st_mode)) classes.emplace_back(entry->d_name); + + entry = readdir(dir); + } + + closedir(dir); + + // sort directories in alphabetic order to generate class order + std::sort(classes.begin(), classes.end()); + + // Step 2: traverse directory to generate samples + for (int class_id = 0; class_id < static_cast(classes.size()); + class_id++) { + auto cur_dir = data_root + DIR_SEP + classes[class_id]; + dir = opendir(cur_dir.c_str()); + entry = readdir(dir); + while (entry) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { + entry = readdir(dir); + continue; + } + + auto file = cur_dir + DIR_SEP + entry->d_name; + samples->emplace_back(std::make_pair(file, class_id)); + + entry = readdir(dir); + } + closedir(dir); + } +} + +std::map>> + root_to_samples_; + +static std::vector>* GetFilesAndLabelsFromCache( + const std::string data_root) { + auto iter = root_to_samples_.find(data_root); + if (iter == root_to_samples_.end()) { + std::vector> samples; + ParseFilesAndLabels(data_root, &samples); + root_to_samples_[data_root] = samples; + } + + return &(root_to_samples_[data_root]); +} + +template +class FileLabelLoaderCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* indices = ctx.Input("Indices"); + auto image_arr = ctx.MultiOutput("Image"); + auto* label_tensor = ctx.Output("Label"); + + auto data_root = ctx.Attr("data_root"); + auto* samples = GetFilesAndLabelsFromCache(data_root); + + auto batch_size = indices->dims()[0]; + const int64_t* indices_data = indices->data(); + + label_tensor->Resize(phi::make_ddim({static_cast(batch_size)})); + auto* label_data = + label_tensor->mutable_data(platform::CPUPlace()); + for (int64_t i = 0; i < batch_size; 
i++) { + int64_t index = static_cast(indices_data[i]); + auto file = samples->at(index).first; + auto label = samples->at(index).second; + std::ifstream input(file.c_str(), + std::ios::in | std::ios::binary | std::ios::ate); + std::streamsize file_size = input.tellg(); + + input.seekg(0, std::ios::beg); + + auto image = image_arr[i]; + std::vector image_len = {file_size}; + image->Resize(phi::make_ddim(image_len)); + + uint8_t* data = image->mutable_data(platform::CPUPlace()); + + input.read(reinterpret_cast(data), file_size); + + label_data[i] = static_cast(label); + } + } + + private: + void copy_tensor(const framework::LoDTensor& lod_tensor, + framework::LoDTensor* out) const { + if (lod_tensor.numel() == 0) return; + auto& out_tensor = *out; + framework::TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor); + out_tensor.set_lod(lod_tensor.lod()); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/image_decoder.cc b/paddle/fluid/operators/data/image_decoder.cc new file mode 100644 index 00000000000000..975c05473774e6 --- /dev/null +++ b/paddle/fluid/operators/data/image_decoder.cc @@ -0,0 +1,337 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/data/image_decoder.h" + +namespace paddle { +namespace operators { +namespace data { + +ImageDecoder::ImageDecoder(int dev_id, size_t host_memory_padding, + size_t device_memory_padding) + : nvjpeg_streams_(2), pinned_buffers_(2), page_id_(0) { + platform::SetDeviceId(dev_id); + + // create nvjpeg handle and stream + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegCreateEx( + NVJPEG_BACKEND_HYBRID, &device_allocator_, &pinned_allocator_, 0, + &handle_)); + + // set pinned/device memory padding + if (host_memory_padding > 0) { + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegSetPinnedMemoryPadding(host_memory_padding, + handle_)); + } + if (device_memory_padding > 0) { + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegSetDeviceMemoryPadding(device_memory_padding, + handle_)); + } + + // create nvjpeg stream + for (size_t i = 0; i < nvjpeg_streams_.size(); i++) { + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegJpegStreamCreate( + handle_, &nvjpeg_streams_[i])); + } + + // create decode params, decoder and state + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegDecodeParamsCreate(handle_, &decode_params_)); + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegDecoderCreate( + handle_, NVJPEG_BACKEND_HYBRID, &decoder_)); + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegDecoderStateCreate(handle_, decoder_, &state_)); + + // create device & pinned buffer + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegBufferDeviceCreate( + handle_, &device_allocator_, &device_buffer_)); + for (size_t i = 0; i < pinned_buffers_.size(); i++) { + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegBufferPinnedCreate( + handle_, &pinned_allocator_, &pinned_buffers_[i])); + } +} + +ImageDecoder::~ImageDecoder() { + // destroy nvjpeg streams + for (size_t i = 0; i < nvjpeg_streams_.size(); i++) { + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegJpegStreamDestroy(nvjpeg_streams_[i])); + } + + // destroy decode params, decoder and state + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegDecodeParamsDestroy(decode_params_)); + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegDecoderDestroy(decoder_)); + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegJpegStateDestroy(state_)); + + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegBufferDeviceDestroy(device_buffer_)); + for (size_t i = 0; i < pinned_buffers_.size(); i++) { + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegBufferPinnedDestroy(pinned_buffers_[i])); + } + + // destroy nvjpeg handle at last + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegDestroy(handle_)); +} + +void ImageDecoder::CPUDecodeRandomCrop(const uint8_t* data, size_t length, + RandomROIGenerator* roi_generator, + unsigned char* workspace, + size_t workspace_size, + framework::LoDTensor* out, + platform::Place place) { +#ifdef PADDLE_WITH_OPENCV + cv::Mat image = cv::imdecode( + cv::Mat(1, length, CV_8UC1, const_cast(data)), + cv::IMREAD_COLOR); + + cv::Mat cropped; + int height = image.rows; + int width = image.cols; + if (roi_generator) { + ROI roi; + roi_generator->GenerateRandomROI(image.cols, image.rows, &roi); + cv::Rect cv_roi; + cv_roi.x = roi.x; + cv_roi.y = roi.y; + cv_roi.width = roi.w; + cv_roi.height = roi.h; + height = roi.h; + width = roi.w; + + image(cv_roi).copyTo(cropped); + } else { + cropped = image; + } + + // allocate cpu tensor and memory + framework::LoDTensor cpu_tensor; + std::vector out_shape = 
{height, width, 3}; + cpu_tensor.Resize(phi::make_ddim(out_shape)); + auto* cpu_data = cpu_tensor.mutable_data(platform::CPUPlace()); + + cv::Mat cpu_mat(height, width, CV_8UC3, const_cast(cpu_data), + cv::Mat::AUTO_STEP); + cv::cvtColor(cropped, cpu_mat, cv::COLOR_BGR2RGB); + + // copy cpu tensor to output gpu tensor + framework::TensorCopySync(cpu_tensor, place, out); +#else + PADDLE_THROW(platform::errors::Fatal( + "Nvjpeg decode failed and Paddle is not compiled with OpenCV")); +#endif +} + +nvjpegStatus_t ImageDecoder::ParseDecodeParams( + const uint8_t* bit_stream, size_t bit_len, framework::LoDTensor* out, + RandomROIGenerator* roi_generator, nvjpegImage_t* out_image, + platform::Place place) { + int components; + nvjpegChromaSubsampling_t subsampling; + int widths[NVJPEG_MAX_COMPONENT]; + int heights[NVJPEG_MAX_COMPONENT]; + + nvjpegStatus_t status = platform::dynload::nvjpegGetImageInfo( + handle_, bit_stream, bit_len, &components, &subsampling, widths, heights); + + if (status != NVJPEG_STATUS_SUCCESS) return status; + + int64_t width = static_cast(widths[0]); + int64_t height = static_cast(heights[0]); + + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegDecodeParamsSetOutputFormat(decode_params_, + NVJPEG_OUTPUT_RGBI)); + + if (roi_generator) { + ROI roi; + roi_generator->GenerateRandomROI(width, height, &roi); + + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegDecodeParamsSetROI( + decode_params_, roi.x, roi.y, roi.w, roi.h)); + height = roi.h; + width = roi.w; + } + + std::vector out_shape = {height, width, 3}; + out->Resize(phi::make_ddim(out_shape)); + + // allocate memory and assign to out_image + auto* data = out->mutable_data(place); + out_image->channel[0] = data; + out_image->pitch[0] = width * 3; + + return NVJPEG_STATUS_SUCCESS; +} + +nvjpegStatus_t ImageDecoder::GPUDecodeRandomCrop(const uint8_t* bit_stream, + size_t bit_len, + nvjpegImage_t* out_image) { + auto buffer = pinned_buffers_[page_id_]; + auto stream = nvjpeg_streams_[page_id_]; + page_id_ ^= 1; + + // decode jpeg in host to pinned buffer + PADDLE_ENFORCE_NVJPEG_SUCCESS(platform::dynload::nvjpegJpegStreamParse( + handle_, bit_stream, bit_len, false, false, stream)); + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegStateAttachPinnedBuffer(state_, buffer)); + nvjpegStatus_t status = platform::dynload::nvjpegDecodeJpegHost( + handle_, decoder_, state_, decode_params_, stream); + if (status != NVJPEG_STATUS_SUCCESS) return status; + + // transfer and decode to device buffer + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegStateAttachDeviceBuffer(state_, device_buffer_)); + PADDLE_ENFORCE_NVJPEG_SUCCESS( + platform::dynload::nvjpegDecodeJpegTransferToDevice( + handle_, decoder_, state_, stream, cuda_stream_)); + status = platform::dynload::nvjpegDecodeJpegDevice(handle_, decoder_, state_, + out_image, nullptr); + return status; +} + +void ImageDecoder::Run(const uint8_t* bit_stream, size_t bit_len, + framework::LoDTensor* out, + RandomROIGenerator* roi_generator, + const platform::Place& place) { + nvjpegImage_t image; + + nvjpegStatus_t status = + ParseDecodeParams(bit_stream, bit_len, out, roi_generator, &image, place); + if (status != NVJPEG_STATUS_SUCCESS) { + CPUDecodeRandomCrop(bit_stream, bit_len, roi_generator, nullptr, 0, out, + place); + return; + } + + status = GPUDecodeRandomCrop(bit_stream, bit_len, &image); + if (status != NVJPEG_STATUS_SUCCESS) { + CPUDecodeRandomCrop(bit_stream, bit_len, roi_generator, nullptr, 0, out, + place); + } +} + 
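+// Decoding strategy of ImageDecoder::Run above: the bit stream is first
+// handed to nvjpeg (ParseDecodeParams allocates the RGBI output tensor and,
+// when a RandomROIGenerator is given, restricts decoding to a random crop
+// via nvjpegDecodeParamsSetROI; GPUDecodeRandomCrop then runs the host-side
+// parse and the device-side decode, alternating between the two pinned
+// buffers and nvjpeg streams through page_id_ so consecutive images can
+// overlap). If any nvjpeg step fails, Run falls back to CPUDecodeRandomCrop,
+// which requires Paddle to be built with WITH_OPENCV and otherwise throws a
+// Fatal error.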
+ImageDecoderThreadPool::ImageDecoderThreadPool( + const int num_threads, const int dev_id, const size_t host_memory_padding, + const size_t device_memory_padding) + : threads_(num_threads), + dev_id_(dev_id), + shutdown_(false), + running_(false), + completed_(false), + outstand_tasks_(0) { + PADDLE_ENFORCE_GT(num_threads, 0, + platform::errors::InvalidArgument( + "num_threads shoule be a positive interger, " + "but got %d", + num_threads)); + for (int i = 0; i < num_threads; i++) { + threads_.emplace_back( + std::thread(std::bind(&ImageDecoderThreadPool::ThreadLoop, this, i, + host_memory_padding, device_memory_padding))); + } +} + +ImageDecoderThreadPool::~ImageDecoderThreadPool() { ShutDown(); } + +void ImageDecoderThreadPool::AddTask(std::shared_ptr task) { + task_queue_.push_back(task); +} + +void ImageDecoderThreadPool::RunAll(const bool wait, const bool sort) { + // Sort images in length desending order + if (sort) SortTaskByLengthDescend(); + + { + std::lock_guard lock(mutex_); + completed_ = false; + running_ = true; + } + running_cond_.notify_all(); + + if (wait) WaitTillTasksCompleted(); +} + +void ImageDecoderThreadPool::WaitTillTasksCompleted() { + std::unique_lock lock(mutex_); + completed_cond_.wait(lock, [this] { return this->completed_; }); + running_ = false; +} + +void ImageDecoderThreadPool::ShutDown() { + std::unique_lock lock(mutex_); + running_ = false; + shutdown_ = true; + running_cond_.notify_all(); + lock.unlock(); + + task_queue_.clear(); + + for (auto& thread : threads_) { + if (thread.joinable()) thread.join(); + } +} + +void ImageDecoderThreadPool::SortTaskByLengthDescend() { + std::lock_guard lock(mutex_); + std::sort(task_queue_.begin(), task_queue_.end(), + [](const std::shared_ptr a, + const std::shared_ptr b) { + return b->bit_len < a->bit_len; + }); +} + +void ImageDecoderThreadPool::ThreadLoop(const int thread_idx, + const size_t host_memory_padding, + const size_t device_memory_padding) { + ImageDecoder* decoder = + new ImageDecoder(dev_id_, host_memory_padding, device_memory_padding); + while (!shutdown_) { + std::unique_lock lock(mutex_); + running_cond_.wait(lock, [this] { + return (running_ && !task_queue_.empty()) || shutdown_; + }); + if (shutdown_) break; + + auto task = task_queue_.front(); + task_queue_.pop_front(); + outstand_tasks_++; + lock.unlock(); + + decoder->Run(task->bit_stream, task->bit_len, task->tensor, + task->roi_generator, task->place); + + lock.lock(); + outstand_tasks_--; + if (outstand_tasks_ == 0 && task_queue_.empty()) { + completed_ = true; + lock.unlock(); + completed_cond_.notify_one(); + } + } +} + +// initialization static variables out of ImageDecoderThreadPoolManager +ImageDecoderThreadPoolManager* ImageDecoderThreadPoolManager::pm_instance_ptr_ = + nullptr; +std::mutex ImageDecoderThreadPoolManager::m_; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/image_decoder.h b/paddle/fluid/operators/data/image_decoder.h new file mode 100644 index 00000000000000..6882b7c9364901 --- /dev/null +++ b/paddle/fluid/operators/data/image_decoder.h @@ -0,0 +1,195 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#ifdef PADDLE_WITH_OPENCV +#include +#endif + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/dynload/nvjpeg.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/stream/cuda_stream.h" + +#include "paddle/fluid/operators/data/random_roi_generator.h" + +namespace paddle { +namespace operators { +namespace data { + +static int dev_malloc(void** p, size_t s) { + return static_cast(cudaMalloc(p, s)); +} + +static int dev_free(void* p) { return static_cast(cudaFree(p)); } + +static int host_malloc(void** p, size_t s, unsigned int f) { + return static_cast(cudaHostAlloc(p, s, f)); +} + +static int host_free(void* p) { return static_cast(cudaFreeHost(p)); } + +struct ImageDecodeTask { + const uint8_t* bit_stream; + size_t bit_len; + framework::LoDTensor* tensor; + RandomROIGenerator* roi_generator; + platform::Place place; +}; + +class ImageDecoder { + public: + ImageDecoder(int dev_id, size_t host_memory_padding = 0, + size_t device_memory_padding = 0); + + ~ImageDecoder(); + + void Run(const uint8_t* bit_stream, size_t bit_len, framework::LoDTensor* out, + RandomROIGenerator* roi_generator, const platform::Place& place); + + private: + DISABLE_COPY_AND_ASSIGN(ImageDecoder); + + void CPUDecodeRandomCrop(const uint8_t* data, size_t length, + RandomROIGenerator* roi_generator, + unsigned char* workspace, size_t workspace_size, + framework::LoDTensor* out, platform::Place place); + + nvjpegStatus_t ParseDecodeParams(const uint8_t* bit_stream, size_t bit_len, + framework::LoDTensor* out, + RandomROIGenerator* roi_generator, + nvjpegImage_t* out_image, + platform::Place place); + + nvjpegStatus_t GPUDecodeRandomCrop(const uint8_t* bit_stream, size_t bit_len, + nvjpegImage_t* out_image); + + cudaStream_t cuda_stream_ = nullptr; + std::vector nvjpeg_streams_; + + nvjpegHandle_t handle_ = nullptr; + nvjpegJpegState_t state_ = nullptr; + nvjpegJpegDecoder_t decoder_ = nullptr; + nvjpegDecodeParams_t decode_params_ = nullptr; + + nvjpegPinnedAllocator_t pinned_allocator_ = {&host_malloc, &host_free}; + nvjpegDevAllocator_t device_allocator_ = {&dev_malloc, &dev_free}; + std::vector pinned_buffers_; + nvjpegBufferDevice_t device_buffer_ = nullptr; + + int page_id_; +}; + +class ImageDecoderThreadPool { + public: + ImageDecoderThreadPool(const int num_threads, const int dev_id, + size_t host_memory_padding, + size_t device_memory_padding); + + ~ImageDecoderThreadPool(); + + void AddTask(std::shared_ptr task); + + void RunAll(const bool wait, const bool sort = true); + + void WaitTillTasksCompleted(); + + void ShutDown(); + + private: + DISABLE_COPY_AND_ASSIGN(ImageDecoderThreadPool); + + void SortTaskByLengthDescend(); + + void ThreadLoop(const int thread_idx, const size_t host_memory_padding, + const size_t device_memory_padding); + + std::vector threads_; + int dev_id_; + + std::deque> task_queue_; + std::mutex mutex_; + + bool shutdown_; + std::condition_variable running_cond_; + bool running_; + std::condition_variable completed_cond_; + bool completed_; + + int 
outstand_tasks_; +}; + +class ImageDecoderThreadPoolManager { + private: + DISABLE_COPY_AND_ASSIGN(ImageDecoderThreadPoolManager); + + static ImageDecoderThreadPoolManager* pm_instance_ptr_; + static std::mutex m_; + + std::map> prog_id_to_pool_; + + public: + static ImageDecoderThreadPoolManager* Instance() { + if (pm_instance_ptr_ == nullptr) { + std::lock_guard lk(m_); + if (pm_instance_ptr_ == nullptr) { + pm_instance_ptr_ = new ImageDecoderThreadPoolManager; + } + } + return pm_instance_ptr_; + } + + ImageDecoderThreadPool* GetDecoderThreadPool( + const int64_t program_id, const int num_threads, const int dev_id, + const size_t host_memory_padding, const size_t device_memory_padding) { + auto iter = prog_id_to_pool_.find(program_id); + if (iter == prog_id_to_pool_.end()) { + prog_id_to_pool_[program_id] = + std::unique_ptr(new ImageDecoderThreadPool( + num_threads, dev_id, host_memory_padding, device_memory_padding)); + } + return prog_id_to_pool_[program_id].get(); + } + + void ShutDownDecoder(const int64_t program_id) { + auto iter = prog_id_to_pool_.find(program_id); + if (iter != prog_id_to_pool_.end()) { + iter->second.get()->ShutDown(); + prog_id_to_pool_.erase(program_id); + } + } + + void ShutDown() { + if (prog_id_to_pool_.empty()) return; + + std::lock_guard lk(m_); + auto iter = prog_id_to_pool_.begin(); + for (; iter != prog_id_to_pool_.end(); iter++) { + if (iter->second.get()) iter->second.get()->ShutDown(); + } + } + + ImageDecoderThreadPoolManager() {} + + ~ImageDecoderThreadPoolManager() { ShutDown(); } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/map_op.cc b/paddle/fluid/operators/data/map_op.cc new file mode 100644 index 00000000000000..519b97df48407f --- /dev/null +++ b/paddle/fluid/operators/data/map_op.cc @@ -0,0 +1,124 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/fluid/operators/data/map_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/type_defs.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class MapOp : public framework::OperatorBase { + public: + MapOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext* ctx) const { + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "MapOp"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + // Step1: get output vars and attrs + auto inputs = Inputs("In"); + std::vector input_vars; + input_vars.reserve(inputs.size()); + for (auto& input : inputs) { + input_vars.emplace_back(scope.FindVar(input)); + } + + auto outputs = Outputs("Out"); + std::vector output_vars; + output_vars.reserve(outputs.size()); + for (auto& output : outputs) { + output_vars.emplace_back(scope.FindVar(output)); + } + + CheckInputQueueStatus(input_vars); + CheckAndInitOutputQueue(output_vars, /*capacity=*/2); + + auto input_var_names = Attr>("input_var_names"); + auto output_var_names = Attr>("output_var_names"); + auto* map_block = Attr("map_block"); + auto program_id = Attr("program_id"); + + auto input_queues = GetQueueVecFromVariableVec(input_vars); + auto output_queues = GetQueueVecFromVariableVec(output_vars); + data::MapRunnerManager::Instance()->StartMapRunner( + map_block, program_id, &scope, dev_place, input_var_names, + output_var_names, input_queues, output_queues); + } +}; + +class MapInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out", "MapOp"); + } +}; + +class MapInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const override {} +}; + +class MapOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("In", + "(LoDTensorBlockingQueueHolder)" + "The output tensors of Map operator") + .AsDuplicable(); + AddOutput("Out", + "(LoDTensorBlockingQueueHolder)" + "The output tensors of Map operator") + .AsDuplicable(); + AddAttr("map_block", + "(BlockDesc *)" + "The global block of executed map program " + "desc."); + AddAttr("program_id", + "(int64_t)" + "The unique hash id used as cache key for " + "ExecutorInfoCache"); + AddAttr>("input_var_names", + "(list of string)" + "input variable names for map program"); + AddAttr>("output_var_names", + "(list of string)" + "output variable names for map program"); + AddComment(R"DOC( + This OP used to split data loading stages of DataPipeline, the + map function will be run in independent C++ thread and stream. 
+ )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(map, ops::MapOp, ops::MapOpMaker, ops::MapInferShape, + ops::MapInferVarType); +REGISTER_OP_CPU_KERNEL( + map, ops::MapOpKernel); diff --git a/paddle/fluid/operators/data/map_op.cu.cc b/paddle/fluid/operators/data/map_op.cu.cc new file mode 100644 index 00000000000000..ac9009bfced1e0 --- /dev/null +++ b/paddle/fluid/operators/data/map_op.cu.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/data/map_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + map, ops::MapOpKernel); diff --git a/paddle/fluid/operators/data/map_op.h b/paddle/fluid/operators/data/map_op.h new file mode 100644 index 00000000000000..7a2266d06c57aa --- /dev/null +++ b/paddle/fluid/operators/data/map_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/data/map_runner.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { + +using Variable = framework::Variable; +using LoDTensor = framework::LoDTensor; +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; + +static void CheckInputQueueStatus(const std::vector& vars) { + for (auto var : vars) { + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "Input Variables of MapOp should hold " + "LoDTensorBlockingQueueHolder type")); + auto queue = var->Get().GetQueue(); + PADDLE_ENFORCE_NE(queue, nullptr, + platform::errors::InvalidArgument( + "Input LoDTensorBlockingQueue is not initialized")); + } +} + +static void CheckAndInitOutputQueue(const std::vector& vars, + int capacity) { + for (auto var : vars) { + if (var->IsInitialized()) { + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "Output Variables of MapOp should hold " + "LoDTensorBlockingQueueHolder type")); + auto queue = var->Get().GetQueue(); + if (queue == nullptr) { + auto* holder = var->template GetMutable(); + holder->InitOnce(capacity); + } + } else { + auto* holder = var->GetMutable(); + holder->InitOnce(capacity); + } + } +} + +static std::vector> +GetQueueVecFromVariableVec(const std::vector& vars) { + std::vector> queues; + queues.reserve(vars.size()); + for (size_t i = 0; i < vars.size(); i++) { + queues.push_back(vars[i]->Get().GetQueue()); + } + return queues; +} + +template +class MapOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/map_runner.cc b/paddle/fluid/operators/data/map_runner.cc new file mode 100644 index 00000000000000..300cfdd246fdcd --- /dev/null +++ b/paddle/fluid/operators/data/map_runner.cc @@ -0,0 +1,220 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/
+
+#include <csignal>
+
+#include "paddle/fluid/framework/executor_cache.h"
+#include "paddle/fluid/operators/data/map_runner.h"
+#include "paddle/fluid/platform/timer.h"
+
+namespace paddle {
+namespace operators {
+namespace data {
+
+MapRunner::MapRunner(
+    const std::shared_ptr<BlockDesc> map_block, const int64_t program_id,
+    const Scope* scope, const platform::Place& place,
+    const std::vector<std::string>& input_var_names,
+    const std::vector<std::string>& output_var_names,
+    const std::vector<std::shared_ptr<LoDTensorBlockingQueue>> input_queues,
+    const std::vector<std::shared_ptr<LoDTensorBlockingQueue>> output_queues)
+    : running_(true),
+      shutdown_(false),
+      map_block_(map_block),
+      program_id_(program_id),
+      place_(place),
+      input_var_names_(input_var_names),
+      output_var_names_(output_var_names),
+      input_queues_(input_queues),
+      output_queues_(output_queues) {
+  PADDLE_ENFORCE_EQ(
+      input_var_names_.size(), input_queues_.size(),
+      platform::errors::InvalidArgument(
+          "input_var_names length should be equal to input_queues length, "
+          "but receive %d != %d.",
+          input_var_names_.size(), input_queues_.size()));
+  PADDLE_ENFORCE_EQ(
+      output_var_names_.size(), output_queues_.size(),
+      platform::errors::InvalidArgument(
+          "output_var_names length should be equal to output_queues length, "
+          "but receive %d != %d.",
+          output_var_names_.size(), output_queues_.size()));
+
+  StartMapThread(scope);
+}
+
+bool MapRunner::ShareInputsIntoScope(Scope* scope) {
+  for (size_t i = 0; i < input_queues_.size(); i++) {
+    // If the input queue is closed, namely EOE (end of epoch) arrived
+    // from the dataset reader, the read fails
+    auto queue = input_queues_[i];
+
+    // read LoDTensorArray from queue
+    bool success = true;
+    auto tensor_arr = queue->Pop(&success);
+    if (!success) return false;
+
+    if (tensor_arr.size() == 1) {
+      // input array length = 1, treat input type as LoDTensor
+      // FIXME(dkp): this may incur error if batch size = 1
+      auto tensor = tensor_arr[0];
+      if (!tensor.IsInitialized()) return false;
+
+      // get dst variable from scope and check status
+      auto name = input_var_names_[i];
+      auto* var = scope->Var(name);
+
+      // share input tensor to dst variable
+      auto* dst_tensor = var->GetMutable<LoDTensor>();
+      dst_tensor->ShareDataWith(tensor);
+      dst_tensor->set_lod(tensor.lod());
+    } else {
+      // input array length > 1, treat input type as LoDTensorArray
+      for (auto tensor : tensor_arr) {
+        if (!tensor.IsInitialized()) return false;
+      }
+
+      // get dst variable from scope and check status
+      auto name = input_var_names_[i];
+      auto* var = scope->Var(name);
+
+      // share input tensors to dst variable
+      auto& dst_tensor_arr = *(var->GetMutable<LoDTensorArray>());
+      for (auto& tensor : dst_tensor_arr) tensor.clear();
+      dst_tensor_arr.clear();
+      dst_tensor_arr.reserve(tensor_arr.size());
+      for (size_t i = 0; i < tensor_arr.size(); i++) {
+        dst_tensor_arr.emplace_back(tensor_arr[i]);
+      }
+    }
+  }
+  return true;
+}
+
+void signal_handler(int sig_num) { _exit(-1); }
+
+void MapRunner::StartMapThread(const Scope* scope) {
+  map_thread_ = std::thread([this, scope]() -> void {
+    // MapThread may crash with a SIGSEGV signal in Executor::Prepare
+    // when the Python program breaks and exits, so catch the SIGSEGV
+    // signal and exit the thread silently
+    signal(SIGSEGV, signal_handler);
+
+    auto& scope_ = scope->NewScope();
+    framework::Executor executor(place_);
+    while (!shutdown_) {
+      // check running or shutdown
+      std::unique_lock<std::mutex> lock(mutex_);
+      running_cond_.wait(lock, [this] { return running_ || shutdown_; });
+      if (shutdown_) break;
+
+      // Step 1: get input LoDTensor and share into Scope
+      bool success = ShareInputsIntoScope(&scope_);
+      if (!success) {
+        for (auto& queue
: output_queues_) { + while (queue->Size()) sleep(0.5); + queue->Close(); + } + running_ = false; + continue; + } + + // Step 2: run ops by executor without fetch + try { + executor.Run(*map_block_->Program(), &scope_, + static_cast(map_block_->ID()), false, true, + output_var_names_, false, true); + } catch (...) { + break; + } + + // Step 3: fetch output variable to LoDTensor vector + // and push to output queue + for (size_t i = 0; i < output_var_names_.size(); i++) { + auto* out_var = scope_.FindVar(output_var_names_[i]); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "The output variable %s is not found in Map " + "program's internal scope", + output_var_names_[i])); + CheckOutputVarStatus(*out_var, output_var_names_[i]); + + if (out_var->IsType()) { + framework::LoDTensorArray t_arr(1); + copy_tensor(out_var->Get(), &t_arr[0]); + output_queues_[i]->Push(t_arr); + } else { + auto out_arr = out_var->Get(); + framework::LoDTensorArray t_arr(out_arr.size()); + for (size_t i = 0; i < out_arr.size(); i++) { + copy_tensor(out_arr[i], &t_arr[i]); + } + output_queues_[i]->Push(t_arr); + } + } + } + scope->DeleteScope(&scope_); + }); +} + +void MapRunner::CheckOutputVarStatus(const Variable& var, + const std::string& var_name) { + // only LoDTensor & LoDTensorArray variable type support currently + if (var.IsType()) { + PADDLE_ENFORCE_EQ(var.Get().IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor in output variable %s get from Map" + "program's internal scope is not initialized.", + var_name)); + } else if (var.IsType()) { + auto tensor_array = var.Get(); + for (auto tensor : tensor_array) { + PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor in LoDTensorArray of output " + "variable %s get from Map program's internal " + "scope is not initialized.", + var_name)); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "MapOp can only support LoDTensor or LoDTensorArray")); + } +} + +void MapRunner::ShutDown() { + // close all output queue, op after this op can shutdown itself + for (auto queue : output_queues_) { + if (queue && !queue->IsClosed()) queue->Close(); + } + + shutdown_ = true; + running_ = false; + running_cond_.notify_all(); + + if (map_thread_.joinable()) map_thread_.join(); +} + +void MapRunner::Reset() { + for (auto queue : output_queues_) queue->ReOpen(); + + running_ = true; + running_cond_.notify_all(); +} + +// initialization static variables out of MapRunnerManager +MapRunnerManager* MapRunnerManager::pm_instance_ptr_ = nullptr; +std::mutex MapRunnerManager::m_; + +} // data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/map_runner.h b/paddle/fluid/operators/data/map_runner.h new file mode 100644 index 00000000000000..2d33bdf79a2581 --- /dev/null +++ b/paddle/fluid/operators/data/map_runner.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include +#include +#include +#include + +#include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { + +using BlockDesc = framework::BlockDesc; +using Scope = framework::Scope; + +using Variable = framework::Variable; +using LoDTensor = framework::LoDTensor; +using LoDTensorArray = framework::LoDTensorArray; +using LoDTensorBlockingQueue = operators::reader::LoDTensorBlockingQueue; +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; + +namespace data { + +class MapRunner { + public: + MapRunner( + const std::shared_ptr map_block, const int64_t program_id, + const Scope *scope, const platform::Place &place, + const std::vector &input_var_names, + const std::vector &output_var_names, + const std::vector> input_queues, + const std::vector> output_queues); + + ~MapRunner() { ShutDown(); } + + void ShutDown(); + + void Reset(); + + inline bool IsRunning() { return running_; } + + private: + void copy_tensor(const framework::LoDTensor &lod_tensor, + framework::LoDTensor *out) const { + if (lod_tensor.numel() == 0) return; + auto &out_tensor = *out; + framework::TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor); + out_tensor.set_lod(lod_tensor.lod()); + } + + bool ShareInputsIntoScope(Scope *scope); + + void StartMapThread(const Scope *scope); + + void CheckInputVarStatus(const Variable &var, const std::string &var_name); + void CheckOutputVarStatus(const Variable &var, const std::string &var_name); + + std::thread map_thread_; + bool running_; + std::condition_variable running_cond_; + bool shutdown_; + std::mutex mutex_; + + std::shared_ptr map_block_; + int64_t program_id_; + platform::Place place_; + + std::vector input_var_names_; + std::vector output_var_names_; + std::vector> input_queues_; + std::vector> output_queues_; +}; + +class MapRunnerManager { + // MapRunnerManager is a signleton manager for MapRunner, we + // create single MapRunner for a program id + private: + DISABLE_COPY_AND_ASSIGN(MapRunnerManager); + + static MapRunnerManager *pm_instance_ptr_; + static std::mutex m_; + + std::map> prog_id_to_runner_; + + public: + static MapRunnerManager *Instance() { + if (pm_instance_ptr_ == nullptr) { + std::lock_guard lk(m_); + if (pm_instance_ptr_ == nullptr) { + pm_instance_ptr_ = new MapRunnerManager; + } + } + return pm_instance_ptr_; + } + + void StartMapRunner( + BlockDesc *map_block, const int64_t program_id, const Scope *scope, + const platform::Place &place, + const std::vector &input_var_names, + const std::vector &output_var_names, + const std::vector> &input_queues, + const std::vector> + &output_queues) { + auto iter = prog_id_to_runner_.find(program_id); + if (iter == prog_id_to_runner_.end()) { + prog_id_to_runner_[program_id] = std::unique_ptr(new MapRunner( + std::shared_ptr(map_block), program_id, scope, place, + input_var_names, output_var_names, input_queues, output_queues)); + } + } + + void ShutDownMapRunner(int program_id) { + std::lock_guard lk(m_); + auto iter = prog_id_to_runner_.find(program_id); + if (iter != prog_id_to_runner_.end()) { + if (iter->second.get()) iter->second.get()->ShutDown(); + prog_id_to_runner_.erase(iter); + } + } + + void ResetMapRunner(int program_id) { + std::lock_guard lk(m_); + auto iter = prog_id_to_runner_.find(program_id); + if (iter != prog_id_to_runner_.end()) { + iter->second.get()->Reset(); + } + } + + void ShutDown() { + if (prog_id_to_runner_.empty()) 
return;
+
+    std::lock_guard<std::mutex> lk(m_);
+    auto iter = prog_id_to_runner_.begin();
+    for (; iter != prog_id_to_runner_.end(); iter++) {
+      if (iter->second.get()) iter->second.get()->ShutDown();
+    }
+  }
+
+  MapRunnerManager() {}
+
+  ~MapRunnerManager() { ShutDown(); }
+};
+
+}  // namespace data
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/data/mirror_normalize_op.cc b/paddle/fluid/operators/data/mirror_normalize_op.cc
new file mode 100644
index 00000000000000..6608d217ab8f47
--- /dev/null
+++ b/paddle/fluid/operators/data/mirror_normalize_op.cc
@@ -0,0 +1,113 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/data/mirror_normalize_op.h"
+
+namespace paddle {
+namespace operators {
+namespace data {
+
+class MirrorNormalizeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::NotFound(
+                          "Input(X) of MirrorNormalizeOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Mirror"), true,
+        platform::errors::NotFound(
+            "Input(Mirror) of MirrorNormalizeOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::NotFound(
+            "Output(Out) of MirrorNormalizeOp should not be null."));
+
+    auto x_dims = ctx->GetInputDim("X");
+    if (ctx->IsRuntime()) {
+      PADDLE_ENFORCE_EQ(
+          x_dims.size(), 4,
+          platform::errors::NotFound(
+              "Input(X) of MirrorNormalizeOp should be a 4-D Tensor"));
+
+      auto c = x_dims[1];
+      auto mean = ctx->Attrs().Get<std::vector<float>>("mean");
+      auto std = ctx->Attrs().Get<std::vector<float>>("std");
+      PADDLE_ENFORCE_EQ(
+          mean.size(), c,
+          platform::errors::NotFound(
+              "The channel number of Input(X) should equal to length of mean"));
+      PADDLE_ENFORCE_EQ(
+          std.size(), c,
+          platform::errors::NotFound(
+              "The channel number of Input(X) should equal to length of std"));
+    }
+
+    std::vector<int64_t> output_dims(x_dims.size());
+    for (int i = 0; i < x_dims.size(); ++i) {
+      output_dims[i] = x_dims[i];
+    }
+    ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
+    ctx->ShareLoD("X", "Out");
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
+class MirrorNormalizeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of mirror_normalize op.");
+    AddInput("Mirror",
+             "(Tensor), The mirror vector for random flip, the "
+             "shape is {N, 1}, N is the batch size of input X");
+    AddOutput("Out",
+              "(Tensor), The output tensor in the same shape as "
+              "input X.");
+    AddAttr<std::vector<float>>("mean",
+                                "The mean values used to normalize data");
+    AddAttr<std::vector<float>>("std",
+                                "The std values used to normalize data");
+    AddComment(R"DOC(
+      This OP performs
+      horizontal flipping on the input Tensor. Mirror is used to indicate
+      whether flipping is needed for the given sample.
+      )DOC");
+  }
+};
+
+class MirrorNormalizeOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
+  }
+};
+
+}  // namespace data
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators::data;
+namespace plat = paddle::platform;
+REGISTER_OPERATOR(mirror_normalize, ops::MirrorNormalizeOp,
+                  ops::MirrorNormalizeOpMaker,
+                  ops::MirrorNormalizeOpInferVarType);
+
+REGISTER_OP_CPU_KERNEL(mirror_normalize, ops::MirrorNormalizeCPUKernel<float>,
+                       ops::MirrorNormalizeCPUKernel<double>);
diff --git a/paddle/fluid/operators/data/mirror_normalize_op.cu b/paddle/fluid/operators/data/mirror_normalize_op.cu
new file mode 100644
index 00000000000000..8bca0640d75991
--- /dev/null
+++ b/paddle/fluid/operators/data/mirror_normalize_op.cu
@@ -0,0 +1,93 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/data/mirror_normalize_op.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+
+namespace paddle {
+namespace operators {
+namespace data {
+
+using framework::LoDTensor;
+
+template <typename T>
+__global__ void KeMirrorNormalize(const int numel, const T* in_data,
+                                  const bool* mirrors, T* out_data,
+                                  const float* mean, const float* std,
+                                  const int chw, const int hw, const int w) {
+  CUDA_KERNEL_LOOP(idx, numel) {
+    int ni = idx / chw;
+    int ci = (idx % chw) / hw;
+    int wi = idx % w;
+
+    int out_idx = idx;
+    if (mirrors[ni]) out_idx = idx - 2 * wi + w - 1;
+    out_data[out_idx] = (in_data[idx] - mean[ci]) / std[ci];
+  }
+}
+
+template <typename T>
+class MirrorNormalizeCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* mirror = ctx.Input<LoDTensor>("Mirror");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    auto mean = ctx.Attr<std::vector<float>>("mean");
+    auto std = ctx.Attr<std::vector<float>>("std");
+
+    auto numel = x->numel();
+    auto n = x->dims()[0];
+    auto c = x->dims()[1];
+    auto h = x->dims()[2];
+    auto w = x->dims()[3];
+    auto hw = h * w;
+    auto chw = c * hw;
+
+    const T* x_data = x->data<T>();
+    const bool* mirror_data = mirror->data<bool>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.cuda_device_context();
+    const auto cplace = platform::CPUPlace();
+    int bytes = sizeof(float) * mean.size();
+
+    auto mean_ptr = memory::Alloc(dev_ctx, bytes);
+    float* mean_data = reinterpret_cast<float*>(mean_ptr->ptr());
+    memory::Copy(ctx.GetPlace(), mean_data, cplace, mean.data(), bytes,
+                 dev_ctx.stream());
+    auto std_ptr = memory::Alloc(dev_ctx, bytes);
+    float* std_data = reinterpret_cast<float*>(std_ptr->ptr());
+    memory::Copy(ctx.GetPlace(), std_data, cplace, std.data(), bytes,
+                 dev_ctx.stream());
+
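+    // Each GPU thread handles one input element at NCHW linear index idx,
+    // where wi = idx % w is the column. For a sample whose Mirror flag is
+    // set, the kernel writes the normalized value to the horizontally
+    // flipped column (w - 1 - wi), i.e. out_idx = idx - 2 * wi + (w - 1);
+    // otherwise out_idx = idx. The mean/std buffers staged on the device
+    // above provide the per-channel normalization constants.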
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, numel);
+    KeMirrorNormalize<T><<<config.block_per_grid, config.thread_per_block, 0,
+                           dev_ctx.stream()>>>(
+        numel, x_data, mirror_data, out_data, mean_data, std_data, chw, hw, w);
+  }
+};
+
+}  // namespace data
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(mirror_normalize,
+                        ops::data::MirrorNormalizeCUDAKernel<float>,
+                        ops::data::MirrorNormalizeCUDAKernel<double>);
diff --git a/paddle/fluid/operators/data/mirror_normalize_op.h b/paddle/fluid/operators/data/mirror_normalize_op.h
new file mode 100644
index 00000000000000..fce477c527dc84
--- /dev/null
+++ b/paddle/fluid/operators/data/mirror_normalize_op.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+namespace data {
+
+template <typename T>
+class MirrorNormalizeCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    // no cpu kernel.
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "mirror_normalize op only supports GPU now."));
+  }
+};
+
+}  // namespace data
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/data/pipeline.cc b/paddle/fluid/operators/data/pipeline.cc
new file mode 100644
index 00000000000000..216b41ea6da274
--- /dev/null
+++ b/paddle/fluid/operators/data/pipeline.cc
@@ -0,0 +1,117 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
*/ + +#include "paddle/fluid/operators/data/pipeline.h" +#include "paddle/fluid/framework/executor_cache.h" + +namespace paddle { +namespace operators { +namespace data { + +Pipeline::Pipeline(const std::shared_ptr global_block, + const platform::Place &place, int64_t start_op_index, + int64_t end_op_index, int64_t program_id, + const std::vector &output_var_names) + : running_(true), + global_block_(global_block), + place_(place), + start_op_index_(start_op_index), + end_op_index_(end_op_index), + program_id_(program_id), + output_var_names_(output_var_names) { + PADDLE_ENFORCE_GT(end_op_index_, start_op_index_, + platform::errors::InvalidArgument( + "end_op_index should be greater than start_op_index, " + "but recieve %d <= %d.", + end_op_index_, start_op_index_)); + + // Step1: prepare executor + auto *program = global_block_->Program(); + auto cache_info = framework::GetExecutorInfoFromCache( + *program, place_, start_op_index_, end_op_index_, + /*is_grad=*/false, program_id, &scope_); + auto ¶llel_executor = cache_info.first; + + // Step2: parset persistable variables + auto &skip_eager_delete_vars = + framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( + program_id, /*is_grad=*/false); + if (cache_info.second /*is_new_created*/) { + // DataLoader program do not has input variables, not need to + // skip memory reuse for input variables here + skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), + output_var_names.begin(), + output_var_names.end()); + framework::details::ParseSafeEagerDeletionSkipVars( + *program, end_op_index, output_var_names, &skip_eager_delete_vars); + } + + // Step3: start prefetch thread + parallel_executor->RunWithoutFetch(skip_eager_delete_vars); +} + +void Pipeline::CheckOutputVarStatus(const Variable &var, + const std::string &var_name) { + // only LoDTensor variable type support currently + PADDLE_ENFORCE_EQ(var.IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor in output variable %s get from DataLoader " + "program's internal scope is not initialized.", + var_name)); + PADDLE_ENFORCE_EQ( + var.IsType(), true, + platform::errors::InvalidArgument( + "The output variable %s get from DataLoader program's " + "internal scope holds wrong type. 
Expect type is " + "LoDTensor, but receive type is %s.", + var_name, platform::demangle(framework::ToTypeName(var.Type())))); +} + +void Pipeline::ReadNext(std::vector &out_vars) { + PADDLE_ENFORCE_EQ( + out_vars.size(), output_var_names_.size(), + platform::errors::InvalidArgument( + "Out variable number should equal to output variable name " + "number, but receive %d != %d", + out_vars.size(), output_var_names_.size())); + for (size_t i = 0; i < output_var_names_.size(); i++) { + auto *out_var = scope_.FindVar(output_var_names_[i]); + PADDLE_ENFORCE_NOT_NULL( + out_var, platform::errors::NotFound( + "The output variable %s is not found in DataLoader " + "program's internal scope", + output_var_names_[i])); + auto out_queue = out_var->Get().GetQueue(); + if (out_queue->IsClosed()) { + running_.store(false); + return; + } + + bool success = true; + auto outputs = out_queue->Pop(&success); + PADDLE_ENFORCE_EQ(success, true, platform::errors::PreconditionNotMet( + "Read from output queue %s failed", + output_var_names_[i])); + + CheckOutputVarStatus(*(out_vars[i]), output_var_names_[i]); + copy_tensor(outputs.at(0), out_vars[i]->GetMutable()); + for (auto &output : outputs) output.clear(); + outputs.clear(); + } +} + +// initialization static variables out of PipelineManager +PipelineManager *PipelineManager::pm_instance_ptr_ = nullptr; +std::mutex PipelineManager::m_; + +} // data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/pipeline.h b/paddle/fluid/operators/data/pipeline.h new file mode 100644 index 00000000000000..0c673b52eb785a --- /dev/null +++ b/paddle/fluid/operators/data/pipeline.h @@ -0,0 +1,132 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include +#include +#include +#include +#include "ThreadPool.h" + +#include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" + +namespace paddle { +namespace operators { + +using BlockDesc = framework::BlockDesc; +using Scope = framework::Scope; +using ParallelExecutor = framework::ParallelExecutor; + +using Variable = framework::Variable; +using LoDTensor = framework::LoDTensor; +using LoDTensorBlockingQueue = operators::reader::LoDTensorBlockingQueue; +using LoDTensorBlockingQueueHolder = + operators::reader::LoDTensorBlockingQueueHolder; + +namespace data { + +class Pipeline { + public: + Pipeline(const std::shared_ptr global_block, + const platform::Place &place, int64_t start_op_index, + int64_t end_op_index, int64_t program_id, + const std::vector &output_var_names); + + ~Pipeline() {} + + void ReadNext(std::vector &out_vars); + + inline bool IsRunning() { return running_.load(); } + + void Reset() { running_.store(true); } + + private: + void CheckOutputVarStatus(const Variable &var, const std::string &var_name); + + void copy_tensor(const framework::LoDTensor &lod_tensor, + framework::LoDTensor *out) const { + if (lod_tensor.numel() == 0) return; + auto &out_tensor = *out; + framework::TensorCopy(lod_tensor, lod_tensor.place(), &out_tensor); + out_tensor.set_lod(lod_tensor.lod()); + } + + std::atomic running_; + + Scope scope_; + std::shared_ptr global_block_; + platform::Place place_; + int64_t start_op_index_; + int64_t end_op_index_; + int64_t program_id_; + + std::vector output_var_names_; +}; + +class PipelineManager { + // PipelineManager is a signleton manager for Pipeline, we + // create single Pipeline for a program id + private: + DISABLE_COPY_AND_ASSIGN(PipelineManager); + + static PipelineManager *pm_instance_ptr_; + static std::mutex m_; + + std::map> prog_id_to_pipeline_; + + public: + static PipelineManager *Instance() { + if (pm_instance_ptr_ == nullptr) { + std::lock_guard lk(m_); + if (pm_instance_ptr_ == nullptr) { + pm_instance_ptr_ = new PipelineManager; + } + } + return pm_instance_ptr_; + } + + Pipeline *GetPipeline(int64_t program_id, BlockDesc *global_block, + const platform::Place &place, int64_t start_op_index, + int64_t end_op_index, + const std::vector &output_var_names) { + auto iter = prog_id_to_pipeline_.find(program_id); + if (iter == prog_id_to_pipeline_.end()) { + prog_id_to_pipeline_[program_id] = std::unique_ptr(new Pipeline( + std::shared_ptr(global_block), place, start_op_index, + end_op_index, program_id, output_var_names)); + return prog_id_to_pipeline_[program_id].get(); + } else { + return iter->second.get(); + } + } + + void ShutDownPipeline(int64_t program_id) { + prog_id_to_pipeline_.erase(program_id); + } + + void ResetPipeline(int64_t program_id) { + auto iter = prog_id_to_pipeline_.find(program_id); + if (iter != prog_id_to_pipeline_.end()) { + iter->second.get()->Reset(); + } + } + + void ShutDown() { prog_id_to_pipeline_.clear(); } + + PipelineManager() {} + + ~PipelineManager() { ShutDown(); } +}; + +} // data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/random_roi_generator.cc b/paddle/fluid/operators/data/random_roi_generator.cc new file mode 100644 index 00000000000000..af512218a39980 --- /dev/null +++ b/paddle/fluid/operators/data/random_roi_generator.cc @@ -0,0 +1,105 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/data/random_roi_generator.h" + +namespace paddle { +namespace operators { +namespace data { + +RandomROIGenerator::RandomROIGenerator(AspectRatioRange aspect_ratio_range, + AreaRange area_range, int64_t seed, + int64_t num_attempts) + : aspect_ratio_range_(aspect_ratio_range), + area_range_(area_range), + random_generator_(seed), + seed_(seed), + num_attempts_(num_attempts) {} + +void RandomROIGenerator::GenerateRandomROI(const int64_t width, + const int64_t height, ROI* roi) { + if (width <= 0 || height <= 0) return; + + float min_wh_ratio = aspect_ratio_range_.first; + float max_wh_ratio = aspect_ratio_range_.second; + float max_hw_ratio = 1 / aspect_ratio_range_.first; + float min_area = width * height * area_distribution_.a(); + auto max_width = std::max(1, height * max_wh_ratio); + auto max_height = std::max(1, width * max_hw_ratio); + + // process max_width/height cannot satisfy min_area restriction firstly + if (height * max_width < min_area) { + roi->w = max_width; + roi->h = height; + } else if (width * max_height < min_area) { + roi->w = width; + roi->h = max_height; + } else { + int64_t attempts = num_attempts_; + while (attempts-- > 0) { + // calc ROI area + float scale = area_distribution_(random_generator_); + float roi_area = scale * height * width; + + // calc ROI width/height + float ratio = std::exp(aspect_ratio_distribution_(random_generator_)); + auto w = static_cast(std::roundf(sqrtf(roi_area * ratio))); + auto h = static_cast(std::roundf(sqrtf(roi_area / ratio))); + w = std::max(1, w); + h = std::max(1, h); + + // check restrictions + ratio = static_cast(w) / h; + if (w <= width && h <= height && ratio >= min_wh_ratio && + ratio <= max_hw_ratio) { + roi->w = w; + roi->h = h; + break; + } + } + + if (attempts <= 0) { + float max_area = area_distribution_.b() * width * height; + float ratio = static_cast(width) / height; + int64_t w, h; + if (ratio > max_wh_ratio) { + w = max_width; + h = height; + } else if (ratio < min_wh_ratio) { + w = width; + h = max_height; + } else { + w = width; + h = height; + } + float scale = std::min(1.f, max_area / (w * h)); + roi->w = std::max(1, w * sqrtf(scale)); + roi->h = std::max(1, h * sqrtf(scale)); + } + + // generate random left top coordination x, y + roi->x = std::uniform_int_distribution( + 0, width - roi->w)(random_generator_); + roi->y = std::uniform_int_distribution( + 0, height - roi->h)(random_generator_); + } +} + +// initialization static variables out of GeneratorManager +GeneratorManager* GeneratorManager::gm_instance_ptr_ = nullptr; +std::mutex GeneratorManager::m_; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/random_roi_generator.h b/paddle/fluid/operators/data/random_roi_generator.h new file mode 100644 index 00000000000000..88b65045826381 --- /dev/null +++ b/paddle/fluid/operators/data/random_roi_generator.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace operators { +namespace data { + +using AspectRatioRange = std::pair; +using AreaRange = std::pair; + +struct ROI { + // left top coordination (x, y) + int64_t x; + int64_t y; + // width/height of crop window (w, h) + int64_t w; + int64_t h; +}; + +class RandomROIGenerator { + public: + explicit RandomROIGenerator(AspectRatioRange aspect_ratio_range, + AreaRange area_range, int64_t seed = time(0), + int64_t num_attempts = 10); + + void GenerateRandomROI(const int64_t width, const int64_t height, ROI* roi); + + private: + AspectRatioRange aspect_ratio_range_; + AreaRange area_range_; + + std::uniform_real_distribution aspect_ratio_distribution_; + std::uniform_real_distribution area_distribution_; + std::mt19937 random_generator_; + + int64_t seed_; + int64_t num_attempts_; +}; + +class GeneratorManager { + using Generators = std::vector>; + + private: + static GeneratorManager* gm_instance_ptr_; + static std::mutex m_; + + std::map> prog_id_to_generators_; + + public: + static GeneratorManager* Instance() { + if (gm_instance_ptr_ == nullptr) { + std::lock_guard lk(m_); + if (gm_instance_ptr_ == nullptr) { + gm_instance_ptr_ = new GeneratorManager; + } + } + return gm_instance_ptr_; + } + + Generators* GetGenerators(const int64_t program_id, const int batch_size, + const AspectRatioRange aspect_ratio_range, + const AreaRange area_range) { + auto iter = prog_id_to_generators_.find(program_id); + if (iter == prog_id_to_generators_.end()) { + prog_id_to_generators_[program_id] = + std::unique_ptr(new Generators(batch_size)); + + std::seed_seq rand_seq{static_cast(time(0))}; + std::vector rands(batch_size); + rand_seq.generate(rands.begin(), rands.end()); + + for (int i = 0; i < batch_size; i++) { + prog_id_to_generators_[program_id]->at(i).reset( + new RandomROIGenerator(aspect_ratio_range, area_range, rands[i])); + } + } + return prog_id_to_generators_[program_id].get(); + } +}; + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/data/unity_build_rule.cmake b/paddle/fluid/operators/data/unity_build_rule.cmake new file mode 100644 index 00000000000000..354e611aa570bd --- /dev/null +++ b/paddle/fluid/operators/data/unity_build_rule.cmake @@ -0,0 +1,27 @@ +# This file records the Unity Build compilation rules. +# The source files in a `register_unity_group` called are compiled in a unity +# file. +# Generally, the combination rules in this file do not need to be modified. +# If there are some redefined error in compiling with the source file which +# in combination rule, you can remove the source file from the following rules. 
+register_unity_group(cc + pipeline.cc + map_runner.cc + random_roi_generator.cc + nvjpeg_decoder.cc + dataloader_op.cc + map_op.cc + batch_decode_random_crop_op.cc + batch_decode_op.cc + batch_resize_op.cc + mirror_normalize_op.cc + batch_random_crop_and_resize_op.cc) + +register_unity_group(cu + dataloader_op.cu.cc + map_op.cu.cc + batch_decode_random_crop_op.cu + batch_decode_op.cu + batch_resize_op.cu + mirror_normalize_op.cu + batch_random_crop_and_resize_op.cu) diff --git a/paddle/fluid/operators/data/utils.h b/paddle/fluid/operators/data/utils.h new file mode 100644 index 00000000000000..7b63cb0d3d2867 --- /dev/null +++ b/paddle/fluid/operators/data/utils.h @@ -0,0 +1,76 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/operators/data/data_reader_op.h" +#include "paddle/fluid/operators/data/map_runner.h" +#include "paddle/fluid/operators/data/pipeline.h" + +#ifdef PADDLE_WITH_GPU +#include "paddle/fluid/operators/data/image_decoder.h" +#endif + +namespace paddle { +namespace operators { +namespace data { + +#ifdef PADDLE_WITH_GPU +extern ImageDecoderThreadPool* decode_pool; +#endif + +void ShutDownAllDataLoaders() { + // step 1: shutdown reader + ReaderManager::Instance()->ShutDown(); + +#ifdef PADDLE_WITH_GPU + // step 2: shutdown decoder + if (decode_pool) decode_pool->ShutDown(); +#endif + + // step 3: shutdown MapRunner + MapRunnerManager::Instance()->ShutDown(); +} + +void ShutDownReadersAndDecoders(const int64_t program_id) { + // step 1: shutdown reader + ReaderManager::Instance()->ShutDownReader(program_id); + +#ifdef PADDLE_WITH_GPU + // step 2: shutdown decoder + ImageDecoderThreadPoolManager::Instance()->ShutDownDecoder(program_id); +#endif +} + +void ShutDownPipeline(const int64_t program_id) { + PipelineManager::Instance()->ShutDownPipeline(program_id); +} + +void ResetDataLoader(const int64_t reader_id, + const std::vector map_ids, + const int64_t pipeline_id) { + // step 1: reset readers + ReaderManager::Instance()->ResetReader(reader_id); + + // step 2: reset maps + for (auto& map_id : map_ids) { + MapRunnerManager::Instance()->ResetMapRunner(map_id); + } + + // step3: reset pipeline + PipelineManager::Instance()->ResetPipeline(pipeline_id); +} + +} // namespace data +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu new file mode 100644 index 00000000000000..0ee26752aebc34 --- /dev/null +++ b/paddle/fluid/operators/math/math_function.cu @@ -0,0 +1,293 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/math_function_impl.h" +#include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { +namespace math { + +using float16 = paddle::platform::float16; +using bfloat16 = paddle::platform::bfloat16; + +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; + +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose, RANK>; \ + template struct Transpose, RANK>; + +DEFINE_GPU_TRANS(1); +DEFINE_GPU_TRANS(2); +DEFINE_GPU_TRANS(3); +DEFINE_GPU_TRANS(4); +DEFINE_GPU_TRANS(5); +DEFINE_GPU_TRANS(6); + +#define REINTERPRET(T, DST_PTR, SRC_PTR) \ + T* DST_PTR = reinterpret_cast(SRC_PTR) + +template +__global__ void TransposeNormalKernel(const T* in_ptr, T* out_ptr, + int64_t element, + const int64_t* in_stride_ptr, + const int64_t* out_stride_ptr, + const int64_t* axis_ptr, int rank) { + CUDA_KERNEL_LOOP(out_idx, element) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride_ptr[i]; + tmp_idx -= coordinate * out_stride_ptr[i]; + in_idx += coordinate * in_stride_ptr[axis_ptr[i]]; + } + out_ptr[out_idx] = in_ptr[in_idx]; + } +} + +template +struct TransposeNormal { + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = framework::stride(in.dims()); + auto out_stride = framework::stride(out->dims()); + auto* in_ptr = in.data(); + auto* out_ptr = out->data(); + + // copy in_stride, out_stride, axis to gpu device + const platform::CUDAPlace& cuda_place = + BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()); + platform::CPUPlace cpu_place = platform::CPUPlace(); + size_t size = 3 * rank * sizeof(int64_t); + auto cpu_buf_holder = memory::AllocShared(cpu_place, size); + auto cuda_buf_holder = memory::AllocShared(cuda_place, size); + REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr()); + REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr()); + for (int i = 0; i < rank; ++i) { + cpu_buf[i] = in_stride[i]; + cpu_buf[rank + i] = out_stride[i]; + cpu_buf[2 * rank + i] = axis[i]; + } + memory::Copy(cuda_place, cuda_buf, cpu_place, cpu_buf, size, + context.stream()); + REINTERPRET(const 
int64_t, in_stride_ptr, cuda_buf); + REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank); + REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank); + + const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock(); + const int MAX_GRID_DIM = + context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM; + int64_t elements = in.numel(); + int block_size = (elements >= MAX_BLOCK_DIM) + ? MAX_BLOCK_DIM + : (1 << static_cast(std::log2(elements))); + int grid_size = elements / block_size; + grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size; + TransposeNormalKernel<<>>( + in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, + rank); + } +}; + +// define transpose normal +#define DEFINE_GPU_TRANS_NORMAL(TYPE) \ + template struct TransposeNormal + +DEFINE_GPU_TRANS_NORMAL(float16); +DEFINE_GPU_TRANS_NORMAL(bfloat16); +DEFINE_GPU_TRANS_NORMAL(float); +DEFINE_GPU_TRANS_NORMAL(double); +DEFINE_GPU_TRANS_NORMAL(int); +DEFINE_GPU_TRANS_NORMAL(int64_t); +DEFINE_GPU_TRANS_NORMAL(bool); +DEFINE_GPU_TRANS_NORMAL(int16_t); +DEFINE_GPU_TRANS_NORMAL(uint8_t); +DEFINE_GPU_TRANS_NORMAL(int8_t); +DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex); +DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex); + +struct TensorSetConstantGPU { + TensorSetConstantGPU(const platform::DeviceContext& context, + framework::Tensor* tensor, float value) + : context_(context), tensor_(tensor), value_(value) {} + + template + void apply() const { + SetConstant functor; + functor(reinterpret_cast(context_), + tensor_, static_cast(value_)); + } + + const platform::DeviceContext& context_; + framework::Tensor* tensor_; + float value_; +}; + +template <> +void set_constant_with_place( + const platform::DeviceContext& context, framework::Tensor* tensor, + float value) { + framework::VisitDataType(tensor->type(), + TensorSetConstantGPU(context, tensor, value)); +} + +template +__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width, + int num) { + T tmp = 1.0 / width; + CUDA_KERNEL_LOOP(i, num) { + int h = i * tmp; + int w = i - h * width; + c[i] = a[i] + b[w]; + } +} + +template +struct RowwiseAdd { + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& vector, framework::Tensor* output) { + auto in_dims = input.dims(); + auto out_dims = output->dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ( + vector.numel(), size, + platform::errors::InvalidArgument( + "The input vector size" + " should be equal to the size of each row of input tensor." + " Expected vector size=%d, but received %d", + size, vector.numel())); + const char* in_dims_cstr = in_dims.to_str().c_str(); + const char* out_dims_cstr = out_dims.to_str().c_str(); + PADDLE_ENFORCE_EQ( + out_dims, in_dims, + platform::errors::InvalidArgument( + "The output tensor shape should be same as the input tensor" + " shape. Expected output tensor shape: %s," + " but received %s", + in_dims_cstr, out_dims_cstr)); + int blocks = 512; + int grids = (input.numel() + blocks - 1) / blocks; + RowwiseAddKernel<<>>( + input.data(), vector.data(), output->data(), + static_cast(in_dims[1]), static_cast(input.numel())); + } +}; + +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; +template struct ColwiseSum; +// template struct ColwiseSum; +// The ColwiseSum failed in debug mode, +// and only failed for this case. So reimplemented it. 
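+// The double specialization below computes the column-wise sum as a
+// matrix-vector product: a length-in_dims[0] vector of ones is applied to
+// the input with a transposed GEMV, so vector[j] = sum_i input[i][j].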
+template <> +void ColwiseSum::operator()( + const platform::CUDADeviceContext& context, const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector->numel(), size, + platform::errors::InvalidArgument( + "The size of input vector" + " should be equal to the size of input tensor column" + " dimension. Expected vector size=%d, but received %d", + size, vector->numel())); + framework::Tensor one; + one.mutable_data({in_dims[0]}, context.GetPlace()); + SetConstant set; + set(context, &one, static_cast(1.0)); + GetBlas(context).GEMV( + true, static_cast(in_dims[0]), static_cast(in_dims[1]), 1.0, + input.data(), one.data(), 0.0, vector->data()); +} + +template struct RowwiseSum; +// template struct RowwiseSum; +// TODO(zcd): Following ColwiseSum format, need to confirm. +// The RowwiseSum failed in debug mode, +// and only failed for this case. So reimplemented it. +template <> +void RowwiseSum::operator()( + const platform::CUDADeviceContext& context, const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0], + platform::errors::InvalidArgument( + "The size of input vector" + " should be equal to the size of input tensor row" + " dimension. Expected vector size=%d, but received %d", + in_dims[0], vector->numel())); + framework::Tensor one; + one.mutable_data({size}, context.GetPlace()); + SetConstant set; + set(context, &one, static_cast(1.0)); + GetBlas(context).GEMV( + true, static_cast(in_dims[1]), static_cast(in_dims[0]), 1.0, + one.data(), input.data(), 0.0, vector->data()); +} + +template struct RowwiseMean; +template struct RowwiseMean; + +template +struct ElementwiseAddTo { + void operator()(platform::CUDADeviceContext* ctx, + const framework::Tensor& src, framework::Tensor* dst) { + auto in = framework::EigenVector::Flatten(src); + auto out = framework::EigenVector::Flatten(*dst); + auto& place = *(ctx->eigen_device()); + out.device(place) = out + in; + } +}; + +template struct ElementwiseAddTo; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index f3934c7d8713b2..6ca2ac36309517 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -1,4 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +CPUDeviceContext::CPUDeviceContext() { + eigen_device_.reset(new Eigen::DefaultDevice()); +} + Copyright (c) 2022 NVIDIA Corporation. All rights reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); @@ -278,6 +282,95 @@ DeviceContextPool::DeviceContextPool( } } +template +inline void EmplaceAsyncDeviceContext( + std::map>>* map_ptr, + platform::Place p, const int64_t stream_id) { + using PtrType = std::unique_ptr; + + auto* dev_ctx = new DevCtx(p); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + auto* cuda_ctx = dynamic_cast(dev_ctx); + PADDLE_ENFORCE_NOT_NULL( + cuda_ctx, platform::errors::InvalidArgument( + "Failed to dynamic_cast dev_ctx into CUDADeviceContext.")); + dev_ctx->SetAllocator( + memory::allocation::AllocatorFacade::Instance().GetAllocator(p).get()); + dev_ctx->SetPinnedAllocator( + memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CUDAPinnedPlace()) + .get()); + + cuda_ctx->PartialInitWithAllocator(); + dev_ctx->SetGenerator( + framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get()); +#endif + + dev_ctx->SetHostGenerator(framework::DefaultCPUGenerator().get()); + dev_ctx->SetHostAllocator(memory::allocation::AllocatorFacade::Instance() + .GetAllocator(platform::CPUPlace()) + .get()); + dev_ctx->SetZeroAllocator(memory::allocation::AllocatorFacade::Instance() + .GetZeroAllocator(p) + .get()); + + (*map_ptr)[p].emplace(stream_id, PtrType(dev_ctx)); +} + +AsyncDeviceContextPool* AsyncDeviceContextPool::pool = nullptr; + +platform::DeviceContext* AsyncDeviceContextPool::Get( + const platform::Place& place, const int64_t stream_id) { + VLOG(6) << "AsyncDeviceContextPool Get: " << place << ", " << stream_id; + if (!platform::is_gpu_place(place)) return nullptr; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + auto place_it = device_contexts_.find(place); + if (place_it == device_contexts_.end()) { + PADDLE_THROW(platform::errors::Unimplemented( + "Place %s is not supported. Please check that your paddle compiles " + "with WITH_GPU, WITH_XPU or WITH_ASCEND_CL option or check that " + "your train process set the correct device id if you use Executor.", + place)); + } + + if (device_contexts_[place].count(stream_id) > 0) { + return device_contexts_[place][stream_id].get(); + } else { + // auto* dev_ctx = dynamic_cast + // device_contexts_[place].emplace(stream_id, + // std::unique_ptr( + // new platform::CUDADeviceContext(place))); + EmplaceAsyncDeviceContext(&device_contexts_, place, + stream_id); + return device_contexts_[place][stream_id].get(); + } +#else + return nullptr; +#endif +} + +AsyncDeviceContextPool::AsyncDeviceContextPool( + const std::vector& places) { + PADDLE_ENFORCE_GT( + places.size(), 0, + platform::errors::InvalidArgument("The number of platform places should " + "be larger than 0. But received %d.", + places.size())); + std::set set; + for (auto& p : places) { + set.insert(p); + } + for (auto& p : set) { + if (platform::is_gpu_place(p)) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + device_contexts_.emplace( + p, std::map>()); +#endif + } + } +} + CPUDeviceContext::CPUDeviceContext() : phi::CPUContext() { phi::CPUContext::Init(); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 2c5f24d28c6d6b..cd2d2fa6eda50a 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -928,5 +928,37 @@ class DeviceContextPool { DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; +/*! 
\brief async device context pool singleton */ +class AsyncDeviceContextPool { + public: + explicit AsyncDeviceContextPool(const std::vector& places); + + static AsyncDeviceContextPool& Instance() { + PADDLE_ENFORCE_NOT_NULL(pool, + platform::errors::PreconditionNotMet( + "Need to Create DeviceContextPool firstly!")); + return *pool; + } + + /*! \brief Create should only called by Init function */ + static AsyncDeviceContextPool& Init( + const std::vector& places) { + if (pool == nullptr) { + pool = new AsyncDeviceContextPool(places); + } + return *pool; + } + + /*! \brief Return handle of single device context. */ + platform::DeviceContext* Get(const platform::Place& place, + const int64_t stream_id); + + private: + static AsyncDeviceContextPool* pool; + std::map>> + device_contexts_; + DISABLE_COPY_AND_ASSIGN(AsyncDeviceContextPool); +}; + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/nvjpeg.h b/paddle/fluid/platform/dynload/nvjpeg.h index 8aaf672fe67b9f..a6649086fc18f7 100644 --- a/paddle/fluid/platform/dynload/nvjpeg.h +++ b/paddle/fluid/platform/dynload/nvjpeg.h @@ -24,11 +24,34 @@ namespace dynload { using DynLoad__##__name = phi::dynload::DynLoad__##__name; \ extern DynLoad__##__name __name -#define NVJPEG_RAND_ROUTINE_EACH(__macro) \ - __macro(nvjpegCreateSimple); \ - __macro(nvjpegJpegStateCreate); \ - __macro(nvjpegGetImageInfo); \ - __macro(nvjpegJpegStateDestroy); \ +#define NVJPEG_RAND_ROUTINE_EACH(__macro) \ + __macro(nvjpegCreateSimple); \ + __macro(nvjpegCreateEx); \ + __macro(nvjpegSetDeviceMemoryPadding); \ + __macro(nvjpegSetPinnedMemoryPadding); \ + __macro(nvjpegJpegStateCreate); \ + __macro(nvjpegJpegStreamCreate); \ + __macro(nvjpegDecodeParamsCreate); \ + __macro(nvjpegDecoderCreate); \ + __macro(nvjpegDecoderStateCreate); \ + __macro(nvjpegBufferDeviceCreate); \ + __macro(nvjpegBufferPinnedCreate); \ + __macro(nvjpegDecodeParamsSetOutputFormat); \ + __macro(nvjpegDecodeParamsSetROI); \ + __macro(nvjpegStateAttachPinnedBuffer); \ + __macro(nvjpegStateAttachDeviceBuffer); \ + __macro(nvjpegJpegStreamParse); \ + __macro(nvjpegDecodeJpegHost); \ + __macro(nvjpegDecodeJpegTransferToDevice); \ + __macro(nvjpegDecodeJpegDevice); \ + __macro(nvjpegJpegStreamDestroy); \ + __macro(nvjpegDecodeParamsDestroy); \ + __macro(nvjpegDecoderDestroy); \ + __macro(nvjpegBufferDeviceDestroy); \ + __macro(nvjpegBufferPinnedDestroy); \ + __macro(nvjpegGetImageInfo); \ + __macro(nvjpegJpegStateDestroy); \ + __macro(nvjpegDestroy); \ __macro(nvjpegDecode); NVJPEG_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_NVJPEG_WRAP); diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index c7a6bdc3cefae8..ffbd93b101a080 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -71,6 +71,7 @@ limitations under the License. 
*/ #include "paddle/phi/backends/dynload/port.h" #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/dynload/nvjpeg.h" #include "paddle/phi/backends/dynload/cublas.h" #include "paddle/phi/backends/dynload/cudnn.h" #include "paddle/phi/backends/dynload/curand.h" @@ -240,6 +241,7 @@ DEFINE_EXTERNAL_API_TYPE(cusparseStatus_t, CUSPARSE_STATUS_SUCCESS, CUSPARSE); DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER); DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT); DEFINE_EXTERNAL_API_TYPE(CUresult, CUDA_SUCCESS, CU); +DEFINE_EXTERNAL_API_TYPE(nvjpegStatus_t, NVJPEG_STATUS_SUCCESS, NVJPEG); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL); @@ -280,10 +282,15 @@ inline const char* GetErrorMsgUrl(T status) { break; case platform::proto::ApiType::CUFFT: return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult"; + break; case platform::proto::ApiType::CUSPARSE: return "https://docs.nvidia.com/cuda/cusparse/" "index.html#cusparseStatus_t"; break; + case platform::proto::ApiType::NVJPEG: + return "https://docs.nvidia.com/cuda/nvjpeg/" + "index.html#nvjpeg-api-return-codes"; + break; default: return "Unknown type of External API, can't get error message URL!"; break; @@ -455,6 +462,37 @@ inline std::string build_nvidia_error_msg(cufftResult_t stat) { return sout.str(); } +/**************** NVJPEG ERROR ****************/ +inline bool is_error(nvjpegStatus_t stat) { + return stat != NVJPEG_STATUS_SUCCESS; +} + +inline std::string get_nvjpeg_error_str(nvjpegStatus_t stat) { + switch (stat) { + case NVJPEG_STATUS_SUCCESS: + return "NVJPEG_STATUS_SUCCESS"; + case NVJPEG_STATUS_NOT_INITIALIZED: + return "NVJPEG_STATUS_NOT_INITIALIZED"; + case NVJPEG_STATUS_INVALID_PARAMETER: + return "NVJPEG_STATUS_INVALID_PARAMETER"; + case NVJPEG_STATUS_BAD_JPEG: + return "NVJPEG_STATUS_BAD_JPEG"; + case NVJPEG_STATUS_JPEG_NOT_SUPPORTED: + return "NVJPEG_STATUS_JPEG_NOT_SUPPORTED"; + case NVJPEG_STATUS_ALLOCATOR_FAILURE: + return "NVJPEG_STATUS_ALLOCATOR_FAILURE"; + case NVJPEG_STATUS_EXECUTION_FAILED: + return "NVJPEG_STATUS_EXECUTION_FAILED"; + case NVJPEG_STATUS_ARCH_MISMATCH: + return "NVJPEG_STATUS_ARCH_MISMATCH"; + case NVJPEG_STATUS_INTERNAL_ERROR: + return "NVJPEG_STATUS_INTERNAL_ERROR"; + case NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED: + return "NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED"; + } + return "Invalid nvjpeg status code"; +} + /*************** CUresult ERROR ***************/ inline bool is_error(CUresult stat) { return stat != CUDA_SUCCESS; } @@ -514,6 +552,21 @@ inline std::string build_nvidia_error_msg(ncclResult_t nccl_result) { } \ } while (0) +#define PADDLE_ENFORCE_NVJPEG_SUCCESS(COND) \ + do { \ + auto __cond__ = (COND); \ + using __NVJPEG_STATUS_TYPE__ = decltype(__cond__); \ + constexpr auto __success_type__ = \ + ::paddle::platform::details::ExternalApiType< \ + __NVJPEG_STATUS_TYPE__>::kSuccess; \ + if (UNLIKELY(__cond__ != __success_type__)) { \ + auto __summary__ = ::paddle::platform::errors::External( \ + "Nvjpeg failed: %s", \ + ::paddle::platform::get_nvjpeg_error_str(__cond__)); \ + __THROW_ERROR_INTERNAL__(__summary__); \ + } \ + } while (0) + inline void retry_sleep(unsigned milliseconds) { #ifdef _WIN32 Sleep(milliseconds); diff --git a/paddle/fluid/platform/external_error.proto b/paddle/fluid/platform/external_error.proto index 8861c2c2ff4fb6..7e454b2f4bec19 100644 --- a/paddle/fluid/platform/external_error.proto +++ b/paddle/fluid/platform/external_error.proto @@ -27,6 
+27,7 @@ enum ApiType { CUFFT = 6; CU = 7; CUSPARSE = 8; + NVJPEG = 9; } message MessageDesc { diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 293a71dbd968c6..7cd8b0db585624 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -248,6 +248,7 @@ void InitDevices(const std::vector devices) { places.emplace_back(platform::CPUPlace()); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) places.emplace_back(platform::CUDAPinnedPlace()); + platform::AsyncDeviceContextPool::Init(places); #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE const char *custom_kernel_root_p = std::getenv("CUSTOM_DEVICE_ROOT"); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index d9aab3dbb04ce4..a24e2e1bb4b867 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -76,6 +76,7 @@ std::map> op_ins_map = { {"Param", "Grad", "Velocity", "Index", "LearningRate", "MasterParam"}}, {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, {"run_program", {"X", "Params"}}, + {"map", {"X"}}, {"fused_feedforward", {"Dropout1Seed", "Dropout2Seed", "Linear1Bias", "Linear2Bias", "Ln1Scale", "Ln1Bias", "Ln2Scale", "Ln2Bias"}}, @@ -239,6 +240,9 @@ std::map> op_passing_outs_map = { {"Out", "OutScale", "OutAccum", "OutState"}}, {"rnn", {"DropoutState"}}, {"run_program", {"Out", "DOut", "OutScope"}}, + {"dataloader", {"Out"}}, + {"map", {"Out"}}, + {"file_label_loader", {"Image"}}, {"clear_float_status", {"FloatStatusOut"}}, {"get_float_status", {"FloatStatusOut"}}, {"assign", {"Out"}}, diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 66bf8c95179afb..78f0e88e6232bf 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -235,6 +235,8 @@ void BindVarDsec(pybind11::module *m) { .value("LOD_TENSOR_ARRAY", pd::proto::VarType::LOD_TENSOR_ARRAY) .value("PLACE_LIST", pd::proto::VarType::PLACE_LIST) .value("READER", pd::proto::VarType::READER) + .value("LOD_TENSOR_BLOCKING_QUEUE", + pd::proto::VarType::LOD_TENSOR_BLOCKING_QUEUE) .value("RAW", pd::proto::VarType::RAW) .value("STRING", pd::proto::VarType::STRING) .value("STRINGS", pd::proto::VarType::STRINGS) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 96d86ee1a31004..3b251a504d68dd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -70,6 +70,7 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/operators/data/utils.h" #include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_info.h" @@ -769,6 +770,18 @@ PYBIND11_MODULE(core_noavx, m) { m.def("_promote_types_if_complex_exists", &paddle::framework::PromoteTypesIfComplexExists); + m.def("_shutdown_all_dataloaders", + &paddle::operators::data::ShutDownAllDataLoaders); + m.def("_shutdown_readers_and_decoders", + &paddle::operators::data::ShutDownReadersAndDecoders); + m.def("_shutdown_pipeline", &paddle::operators::data::ShutDownPipeline); + m.def("_reset_dataloader", [](const int64_t reader_id, + const std::vector map_ids, + const int64_t pipeline_id) { + paddle::operators::data::ResetDataLoader(reader_id, map_ids, pipeline_id); + + }); + py::class_ custom_op_kernel_ctx( m, "CustomOpKernelContext", R"DOC()DOC"); g_custom_op_kernel_ctx_pytype = diff --git a/paddle/phi/backends/dynload/nvjpeg.h b/paddle/phi/backends/dynload/nvjpeg.h index 13bb8a5698f152..fd40bb795f3911 100644 --- a/paddle/phi/backends/dynload/nvjpeg.h +++ b/paddle/phi/backends/dynload/nvjpeg.h @@ -36,11 +36,34 @@ extern void *nvjpeg_dso_handle; }; \ extern DynLoad__##__name __name -#define NVJPEG_RAND_ROUTINE_EACH(__macro) \ - __macro(nvjpegCreateSimple); \ - __macro(nvjpegJpegStateCreate); \ - __macro(nvjpegGetImageInfo); \ - __macro(nvjpegJpegStateDestroy); \ +#define NVJPEG_RAND_ROUTINE_EACH(__macro) \ + __macro(nvjpegCreateSimple); \ + __macro(nvjpegCreateEx); \ + __macro(nvjpegSetDeviceMemoryPadding); \ + __macro(nvjpegSetPinnedMemoryPadding); \ + __macro(nvjpegJpegStateCreate); \ + __macro(nvjpegJpegStreamCreate); \ + __macro(nvjpegDecodeParamsCreate); \ + __macro(nvjpegDecoderCreate); \ + __macro(nvjpegDecoderStateCreate); \ + __macro(nvjpegBufferDeviceCreate); \ + __macro(nvjpegBufferPinnedCreate); \ + __macro(nvjpegDecodeParamsSetOutputFormat); \ + __macro(nvjpegDecodeParamsSetROI); \ + __macro(nvjpegStateAttachPinnedBuffer); \ + __macro(nvjpegStateAttachDeviceBuffer); \ + __macro(nvjpegJpegStreamParse); \ + __macro(nvjpegDecodeJpegHost); \ + __macro(nvjpegDecodeJpegTransferToDevice); \ + __macro(nvjpegDecodeJpegDevice); \ + __macro(nvjpegJpegStreamDestroy); \ + __macro(nvjpegDecodeParamsDestroy); \ + __macro(nvjpegDecoderDestroy); \ + __macro(nvjpegBufferDeviceDestroy); \ + __macro(nvjpegBufferPinnedDestroy); \ + __macro(nvjpegGetImageInfo); \ + __macro(nvjpegJpegStateDestroy); \ + __macro(nvjpegDestroy); \ __macro(nvjpegDecode); NVJPEG_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVJPEG_WRAP); diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index df2af82d551ee0..ce5b7b5df37040 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -86,6 +86,9 @@ template struct SetConstant; \ + template struct Transpose; \ template struct Transpose; \ diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 625728c0fcef20..0d8cf47e9e252c 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -281,6 +281,10 @@ def to_list(s): from .core_avx import _get_current_stream from .core_avx import _Profiler, _ProfilerResult, _RecordEvent from .core_avx import _set_current_stream + from .core_avx import _shutdown_all_dataloaders + from .core_avx import _shutdown_readers_and_decoders 
+ from .core_avx import _shutdown_pipeline + from .core_avx import _reset_dataloader if sys.platform != 'win32': from .core_avx import _set_process_pids from .core_avx import _erase_process_pids @@ -337,6 +341,10 @@ def to_list(s): from .core_noavx import _device_synchronize from .core_noavx import _get_current_stream from .core_noavx import _set_current_stream + from .core_noavx import _shutdown_all_dataloaders + from .core_noavx import _shutdown_readers_and_decoders + from .core_noavx import _shutdown_pipeline + from .core_noavx import _reset_dataloader from .core_noavx import _Profiler, _ProfilerResult, _RecordEvent if sys.platform != 'win32': from .core_noavx import _set_process_pids diff --git a/python/paddle/fluid/dataloader/__init__.py b/python/paddle/fluid/dataloader/__init__.py index 597f1f217483cc..5ad110908fab4f 100644 --- a/python/paddle/fluid/dataloader/__init__.py +++ b/python/paddle/fluid/dataloader/__init__.py @@ -26,7 +26,15 @@ from . import sampler from .sampler import * +from . import pipeline +from .pipeline import * + +from . import ops +from .ops import * + __all__ = dataset.__all__ \ + batch_sampler.__all__ \ + dataloader_iter.__all__ \ - + sampler.__all__ + + sampler.__all__ \ + + pipeline.__all__ \ + + ops.__all__ diff --git a/python/paddle/fluid/dataloader/ops.py b/python/paddle/fluid/dataloader/ops.py new file mode 100755 index 00000000000000..a7a6288483cf3b --- /dev/null +++ b/python/paddle/fluid/dataloader/ops.py @@ -0,0 +1,343 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle + +from ...fluid import core, framework, Program, program_guard, unique_name +from ...fluid.layers.utils import _hash_with_id +from ..layer_helper import LayerHelper +from ...fluid.framework import _non_static_mode + +from collections.abc import Sequence, Mapping + +__all__ = ["map", "data_reader"] + + +def _to_list(l): + if isinstance(l, (list, tuple)): + return l + return [l] + + +class _ProgramGuard(object): + def __init__(self, main_program): + if not isinstance(main_program, Program): + raise TypeError("MapGuard should init with a Program") + self._main_program = main_program + + def __enter__(self): + self._main_program._create_block() + + def __exit__(self, exc_type, exc_val, exc_tb): + self._main_program._rollback() + return exc_type is None + + +class _StreamIDGenerator(object): + def __init__(self): + self.stream_id = 0 + + def get_stream_id(self): + self.stream_id += 1 + return self.stream_id - 1 + + +_stream_id_generator = _StreamIDGenerator() + + +def _generate_stream_id(): + return _stream_id_generator.get_stream_id() + + +def map(map_func, *args, **kwargs): + """ + This API used to split data loading stages of :attr:`DataPipeline`, the + map function will be run in independent C++ thread and stream. + + Args: + map_func (callable): A callable function construct of data + preprocess OPs. + + Returns: + The output of map function + + Examples: + .. 
code-block:: python + + import os + import paddle + from paddle.utils.download import get_path_from_url + + DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") + DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" + DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + BATCH_SIZE = 100 + + data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + def imagenet_pipeline(): + image, label = paddle.vision.reader.file_label_reader( + data_root, batch_size=BATCH_SIZE) + + def decode(image): + image = paddle.vision.ops.image_decode_random_crop(image, num_threads=4) + return image + def resize(image): + image = paddle.vision.ops.image_resize(image, size=224) + return image + def flip_normalize(image): + mirror = paddle.vision.ops.random_flip(image, prob=0.5) + image = paddle.vision.ops.mirror_normalize(image, mirror) + return image + + image = paddle.io.map(decode, image) + image = paddle.io.map(resize, image) + image = paddle.io.map(flip_normalize, image) + + return {'image': image, 'label': label} + + dataloader = paddle.io.DataLoader(imagenet_pipeline) + for data in dataloader: + print(data['image'].shape, data['label'].shape) + + """ + if _non_static_mode(): + return map_func(*args, **kwargs) + + helper = LayerHelper("map", **locals()) + + # NOTE: map_func can take List(Tensor) (while batch_size > 1) as + # inputs or outputs, which means we need to keep the structure + # info when calling map_func, _build_program_inputs used to + # generate 3 kinds of infos: + # 1. return value: holds variables in map_block, and keeps the + # structure info of map inputs, will be used to call map_func + # 2. input_vars: holds variables in map_block in flatten format, + # will be used to generate input_var_names + # 3. flat_inputs: holds variables in main_program/global_block in + # flatten format, will be used as inputs for appendding map OP + # and _parse_program_outputs follows similar logic + def _build_program_inputs(inputs, map_block, input_vars=[], flat_inputs=[]): + if isinstance(inputs, Sequence): + return [ + _build_program_inputs(inp, map_block, input_vars, flat_inputs) + for inp in inputs + ] + elif isinstance(inputs, Mapping): + return { + k: _build_program_inputs(v, map_block, input_vars, flat_inputs) + for k, v in inputs.items() + } + else: + var = map_block.create_var( + name=unique_name.generate("map_sub"), + type=inputs.desc.type(), + dtype=inputs.desc.dtype(), + persistable=False) + input_vars.append(var) + flat_inputs.append(inputs) + return var + + def _parse_program_outputs(outputs, output_vars=[], flat_outputs=[]): + if isinstance(outputs, Sequence): + return [ + _parse_program_outputs(outp, output_vars, flat_outputs) + for outp in outputs + ] + elif isinstance(outputs, Mapping): + return { + k: _parse_program_outputs(v, output_vars, flat_outputs) + for outp in outputs + } + else: + var = helper.create_variable( + name=unique_name.generate("map"), + type=outputs.desc.type(), + dtype=outputs.desc.dtype(), + persistable=True) + flat_outputs.append(var) + output_vars.append(outputs) + return var + + # build map block + main_program = helper.main_program + with _ProgramGuard(main_program): + program_id = _hash_with_id(main_program, map_func) + map_block = main_program.current_block() + + input_vars, flat_inputs = [], [] + program_inputs_args = _build_program_inputs(args, map_block, input_vars, + flat_inputs) + program_inputs_kwargs = _build_program_inputs(kwargs, map_block, + input_vars, flat_inputs) + + program_outputs = 
map_func(*program_inputs_args, + **program_inputs_kwargs) + + # NOTE: _parse_program_outputs create main_program variables, so + # we need to call it outside of _ProgramGuard + output_vars, flat_outputs = [], [] + outputs = _parse_program_outputs(program_outputs, output_vars, flat_outputs) + input_var_names = [v.name for v in input_vars] + output_var_names = [v.name for v in output_vars] + + attrs = { + "map_block": map_block, + "program_id": program_id, + "input_var_names": input_var_names, + "output_var_names": output_var_names + } + + stream_id = _generate_stream_id() + for idx in range(map_block.desc.op_size()): + map_block.desc.op(idx)._set_attr('_stream_id', stream_id) + + helper.append_op( + type="map", + inputs={"In": flat_inputs}, + outputs={"Out": flat_outputs}, + attrs=attrs) + + return outputs + + +def data_reader(reader_func, + batch_size=1, + num_samples=1, + shuffle=False, + drop_last=False, + seed=None): + """ + This API used to auto loading dataset in :attr:`DataPipeline`, the + reader function will be run in independent C++ thread. + + Args: + reader_func (callable): A callable function construct of a data + loader OP. + batch_size (int): The batch size of a mini-batch. Default 1. + shuffle (bool): Whether to shuffle samples. Default False. + drop_last (bool): Whether to drop the last incomplete batch. Default False. + seed (int, optional): The seed for sample shuffling. Default None. + + Returns: + The output of reader function + + Examples: + .. code-block:: python + + import os + import paddle + from paddle.utils.download import get_path_from_url + + DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") + DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" + DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + BATCH_SIZE = 100 + NUM_SAMPLES = 100 + + data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + def imagenet_pipeline(): + def imagenet_reader(indices): + return paddle.vision.reader.file_label_loader( + data_root, indices, BATCH_SIZE) + + outs = paddle.io.data_reader(imagenet_reader, + BATCH_SIZE, NUM_SAMPLES) + image = outs[:-1] + label = outs[-1] + + def decode(image): + image = paddle.vision.ops.image_decode_random_crop(image, num_threads=4) + return image + def resize(image): + image = paddle.vision.ops.image_resize(image, size=224) + return image + + image = paddle.io.map(decode, image) + image = paddle.io.map(resize, image) + + return {'image': image, 'label': label} + + dataloader = paddle.io.DataLoader(imagenet_pipeline) + for data in dataloader: + print(data['image'].shape, data['label'].shape) + + """ + assert not _non_static_mode(), \ + "paddle.io.data_reader can only be used in static mode" + helper = LayerHelper("data_reader", **locals()) + + # build reader block + main_program = helper.main_program + with _ProgramGuard(main_program): + reader_block = main_program.current_block() + + indices_var = reader_block.create_var( + name=unique_name.generate("data_reader_sub"), + type=core.VarDesc.VarType.LOD_TENSOR, + dtype="int64", + persistable=False) + program_outputs = reader_func(indices_var) + program_outputs = _to_list(program_outputs) + + indices_var_name = indices_var.name + output_var_names = [] + for outs in program_outputs: + if isinstance(outs, (list, tuple)): + for out in outs: + output_var_names.append(out.name) + else: + output_var_names.append(outs.name) + + outputs = [] + for outps in program_outputs: + if isinstance(outps, (list, tuple)): + for outp in outps: + outputs.append( + 
helper.create_variable( + name=unique_name.generate("data_reader"), + type=outp.desc.type(), + dtype=outp.desc.dtype(), + persistable=True)) + else: + outputs.append( + helper.create_variable( + name=unique_name.generate("data_reader"), + type=outps.desc.type(), + dtype=outps.desc.dtype(), + persistable=True)) + + attrs = { + "reader_id": _hash_with_id(main_program), + "reader_block": reader_block, + "indices_var_name": indices_var_name, + "output_var_names": output_var_names, + "batch_size": batch_size, + "num_samples": num_samples, + "shuffle": shuffle, + "drop_last": drop_last, + "seed": 0 if seed is None else seed, + "rank": paddle.distributed.get_rank(), + "world_size": paddle.distributed.get_world_size() + } + + helper.append_op( + type="data_reader", inputs={}, outputs={"Out": outputs}, attrs=attrs) + + return outputs diff --git a/python/paddle/fluid/dataloader/pipeline.py b/python/paddle/fluid/dataloader/pipeline.py new file mode 100755 index 00000000000000..1aee03e4b93e45 --- /dev/null +++ b/python/paddle/fluid/dataloader/pipeline.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid + +from paddle import _C_ops +from paddle.fluid import core, framework +from paddle.fluid.layers.utils import _hash_with_id +from ..multiprocess_utils import CleanupFuncRegistrar + +from collections.abc import Sequence, Mapping + +__all__ = ["DataPipeline"] + +CleanupFuncRegistrar.register(core._shutdown_all_dataloaders) + +AVAILABLE_OP_TYPES = ['data_reader', 'map'] + + +class DataPipeline(object): + """ + Data pipeline + + Args: + queue_depth(int): queue depth for caching data between OPs + """ + + def __init__(self, queue_depth=2): + assert isinstance(queue_depth, int), \ + "queue_depth should be an integer" + self._queue_depth = queue_depth + self._init_programs() + + self.is_shutdown = False + + if paddle.distributed.ParallelEnv().nranks > 1: + paddle.set_device('gpu:%d' % + paddle.distributed.ParallelEnv().dev_id) + + def _init_programs(self): + self._main_program = fluid.Program() + self._out_vars = [] + self._out_names = [] + self._is_built = False + + def __enter__(self): + # switch main and startup program + paddle.enable_static() + self._main_program = framework.switch_main_program(self._main_program) + return self + + def __exit__(self, exception_type, exception_value, traceback): + self._main_program = framework.switch_main_program(self._main_program) + + local_rank = paddle.distributed.get_rank() + paddle.disable_static("gpu:" + str(local_rank)) + + self._check_op_type() + + def _check_op_type(self): + for op in self._main_program.block(0).ops: + if op.type not in ['data_reader', 'map']: + raise RuntimeError( + "pipeline given to DataLoader.from_pipeline should be " + "composed of reader OPs and map OP, other OPs(e.g. 
" + "decoder OPs or Paddle OPs) should be run under " + "paddle.io.map") + + def set_outputs(self, outputs): + if isinstance(outputs, Sequence): + for var in outputs: + self._out_vars.append(output) + elif isinstance(outputs, Mapping): + for name, var in outputs.items(): + self._out_vars.append(var) + self._out_names.append(name) + else: + assert isinstance(outputs, fluid.Variable), \ + "outputs should be list, dict or Variable" + + def build(self): + global_block = self._main_program.desc.block(0) + self._program_id = _hash_with_id(self._main_program, self) + + self._attrs = ('global_block', global_block, 'start_op_index', 0, + 'end_op_index', global_block.op_size(), 'program_id', + self._program_id) + self._is_built = True + + def _prepare_output_vars(self): + output_vars = [] + for var in self._out_vars: + if isinstance(var, (list, tuple)): + var = var[0] + assert isinstance(var, framework.Variable), \ + "output of DataLoader program should be Variable" + var_desc = var.desc + output_var = core.VarBase(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), var_desc.type(), False) + output_vars.append(output_var) + + return output_vars + + def __iter__(self): + return self + + def __next__(self): + assert self._is_built, \ + "Pipeline not built, please call build() firstly" + self._output_vars = self._prepare_output_vars() + + try: + _C_ops.dataloader(self._output_vars, *self._attrs) + except KeyboardInterrupt: + pass + except: + raise StopIteration + + if paddle.distributed.ParallelEnv().nranks > 1: + paddle.distributed.barrier() + return {k: v for k, v in zip(self._out_names, self._output_vars)} + + # Python 2 compatable + def next(self): + return self.__next__() + + def reset(self): + reader_id = _hash_with_id(self._main_program) + + map_ids = [] + for op in self._main_program.block(0).ops: + if op.type == "map" and op.has_attr('program_id'): + map_ids.append(op.attr('program_id')) + + core._reset_dataloader(reader_id, map_ids, self._program_id) + + def shutdown(self): + if not self.is_shutdown: + try: + program_id = _hash_with_id(self._main_program) + core._shutdown_readers_and_decoders(program_id) + core._shutdown_pipeline(program_id) + del self._main_program + finally: + self.is_shutdown = True + + def __del__(self): + self.shutdown() diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 0f5f217442135f..907fe62a1c9f15 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -22,7 +22,7 @@ from .executor import global_scope from .data_feeder import DataFeeder, BatchedTensorProvider from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler -from .dataloader import BatchSampler, Dataset, IterableDataset +from .dataloader import BatchSampler, Dataset, IterableDataset, DataPipeline from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, _DatasetKind, default_collate_fn from .dataloader.batch_sampler import _InfiniteIterableSampler from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer @@ -184,9 +184,17 @@ class DataLoader(object): Args: - dataset(Dataset): the dataset to load data from, should be an - instance of subclass of :code:`paddle.io.Dataset` or - :code:`paddle.io.IterableDataset`. + dataset(Dataset|callable): the dataset to load data from, there + are 2 available types: + 1. 
an instance of subclass of :code:`paddle.io.Dataset` or + :code:`paddle.io.IterableDataset` for Python multi-process + DataLoader. + 2. a callable function constructed with + :code:`paddle.io.data_reader`, :code:`paddle.io.map` and other + data processing OPs from :code:`paddle.vision.ops` for C++ + multi-thread and multi-stream DataLoader. Only support data + preprocessing of ImageNet dataset currently. Please see + :code:`paddle.io.map` for example codes. feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list. The Tensors should be created by :code:`paddle.static.data()`. :attr:`feed_list` must be set if :attr:`return_list` is @@ -327,6 +335,17 @@ def __init__(self, timeout=0, worker_init_fn=None, persistent_workers=False): + + # Whether use multi-stream/thread GPU DataLoader + self._use_data_pipeline = False + if callable(dataset): + self._use_data_pipeline = True + with DataPipeline() as self._data_pipeline: + outputs = dataset() + self._data_pipeline.set_outputs(outputs) + self._data_pipeline.build() + return + self.return_list = return_list self.collate_fn = collate_fn self.use_buffer_reader = use_buffer_reader @@ -420,6 +439,11 @@ def __len__(self): return len(self.dataset) def __iter__(self): + # use DataPipeline + if self._use_data_pipeline: + return self._data_pipeline + + # use multi-process DataLoader if self.num_workers == 0: return _DataLoaderIterSingleProcess(self) elif self._persistent_workers: @@ -434,6 +458,13 @@ def __iter__(self): def __call__(self): return self.__iter__() + def reset(self): + assert self._use_data_pipeline, \ + "reset() can only be used in DataPipeline mode, "\ + "which takes callabe function as dataset input "\ + "instead of paddle.io.Dataset" + self._data_pipeline.reset() + @staticmethod def from_generator(feed_list=None, capacity=None, diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py index 5781f78c6e4e4a..b487a3a4dfce37 100755 --- a/python/paddle/io/__init__.py +++ b/python/paddle/io/__init__.py @@ -29,6 +29,8 @@ from ..fluid.dataloader import WeightedRandomSampler # noqa: F401 from ..fluid.dataloader import Subset # noqa: F401 from ..fluid.dataloader import random_split # noqa: F401 +from ..fluid.dataloader import map # noqa: F401 +from ..fluid.dataloader import data_reader # noqa: F401 __all__ = [ #noqa 'Dataset', @@ -45,5 +47,7 @@ 'RandomSampler', 'WeightedRandomSampler', 'random_split', - 'Subset' + 'Subset', + 'map', + 'data_reader', ] diff --git a/python/paddle/tests/CMakeLists.txt b/python/paddle/tests/CMakeLists.txt index bc9f402ed9686d..dd4761e58654be 100644 --- a/python/paddle/tests/CMakeLists.txt +++ b/python/paddle/tests/CMakeLists.txt @@ -8,6 +8,16 @@ foreach(TEST_OP ${DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() +if (WIN32) + LIST(REMOVE_ITEM TEST_OPS test_data_pipeline_static) + LIST(REMOVE_ITEM TEST_OPS test_data_pipeline_dynamic) + LIST(REMOVE_ITEM TEST_OPS test_ops_file_label_loader) + LIST(REMOVE_ITEM TEST_OPS test_ops_decode) + LIST(REMOVE_ITEM TEST_OPS test_ops_crop_resize) + LIST(REMOVE_ITEM TEST_OPS test_ops_file_label_loader) + LIST(REMOVE_ITEM TEST_OPS test_ops_mirror_normalize) + LIST(REMOVE_ITEM TEST_OPS test_data_apis) +endif() if(NOT WITH_COVERAGE) LIST(REMOVE_ITEM TEST_OPS test_hapi_hub) endif() diff --git a/python/paddle/tests/test_data_apis.py b/python/paddle/tests/test_data_apis.py new file mode 100644 index 00000000000000..43ed9e6ff7c9c7 --- /dev/null +++ b/python/paddle/tests/test_data_apis.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import numpy as np + +import paddle +import paddle.fluid.core as core +from paddle.vision.ops import random_flip + + +class TestRandomFlip(unittest.TestCase): + def test_errors(self): + try: + data = paddle.ones([16, 3, 32, 32], dtype="float32") + out = random_flip(data, 1.5) + + # should not execute following lines + assert False + except ValueError: + pass + + try: + data = paddle.ones([16, 3, 32, 32], dtype="float32") + out = random_flip(data, -0.5) + + # should not execute following lines + assert False + except ValueError: + pass + + def test_output_dynamic(self): + data = paddle.ones([16, 3, 32, 32], dtype="float32") + out = random_flip(data, 0.5) + + assert out.dtype == paddle.bool + assert out.shape == [16, 1] + + def test_output_static(self): + paddle.enable_static() + input_data = paddle.static.data( + shape=[16, 3, 32, 32], dtype="float32", name="input") + out_data = random_flip(input_data, 0.5) + + places = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + + for place in places: + exe = paddle.static.Executor(place) + out, = exe.run( + paddle.static.default_main_program(), + feed={"input": np.ones( + [16, 3, 32, 32], dtype="float32")}, + fetch_list=[out_data]) + assert out.dtype == np.bool + assert out.shape == (16, 1) + paddle.disable_static() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_data_pipeline_dynamic.py b/python/paddle/tests/test_data_pipeline_dynamic.py new file mode 100644 index 00000000000000..dfd7ef11ed065c --- /dev/null +++ b/python/paddle/tests/test_data_pipeline_dynamic.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
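Note: the flip masks validated in test_data_apis.py above are consumed by mirror_normalize inside the pipeline tests that follow, which only check the result indirectly by undoing the normalization. As a reviewer aid, here is a minimal NumPy sketch of the assumed semantics (horizontal flip where the per-sample mask is set, then per-channel (x - mean) / std in NCHW layout); the helper name np_mirror_normalize and the exact flip axis are illustrative assumptions, not part of this patch.

```python
import numpy as np

def np_mirror_normalize(images, mirror, mean, std):
    # images: float32 [N, C, H, W]; mirror: bool [N, 1] (True => flip along W)
    mean = np.asarray(mean, dtype="float32").reshape([1, -1, 1, 1])
    std = np.asarray(std, dtype="float32").reshape([1, -1, 1, 1])
    out = images.copy()
    flipped = np.where(mirror.reshape([-1]))[0]
    out[flipped] = out[flipped, :, :, ::-1]  # per-sample horizontal flip
    return (out - mean) / std

# mirrors the bound check used by the pipeline tests: undoing the
# normalization must land back inside the raw uint8 pixel range
mean, std = [123.675, 116.28, 103.53], [58.395, 57.120, 57.375]
imgs = np.random.randint(0, 256, [4, 3, 8, 8]).astype("float32")
mask = np.random.rand(4, 1) < 0.5
normed = np_mirror_normalize(imgs, mask, mean, std)
restored = normed * np.reshape(std, [1, 3, 1, 1]) + np.reshape(mean, [1, 3, 1, 1])
assert np.all(restored > -1.) and np.all(restored < 256.)
```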
+ +import os +import cv2 +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.utils.download import get_path_from_url +from paddle.vision.datasets import DatasetFolder +from paddle.vision.ops import image_decode_random_crop, image_resize, \ + random_flip, mirror_normalize +from paddle.vision.reader import file_label_reader + +import test_data_pipeline_static +from test_data_pipeline_static import DATASET_HOME, DATASET_URL, \ + DATASET_MD5, IMAGE_NUM + +DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") +DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" +DATASET_MD5 = "c7110519124a433901cf005a4a91b607" +IMAGE_NUM = 100 + + +class TestDataPipelineDynamicCase1( + test_data_pipeline_static.TestDataPipelineStaticCase1): + def test_output(self): + # NOTE: only supoort CUDA kernel currently + if not core.is_compiled_with_cuda(): + return + + data = self.reader() + + image = data['image'].numpy() + assert image.shape[0] == self.batch_size + assert image.shape[1] == 3 + assert image.shape[2] == self.target_size + assert image.shape[3] == self.target_size + assert image.dtype == np.float32 + + restore_image = image * self.std_np + self.mean_np + assert np.all(restore_image > -1.) + assert np.all(restore_image < 256.) + + label = data['label'].numpy() + assert label.shape[0] == self.batch_size + assert label.dtype == np.int64 + assert np.all(label >= 0) + assert np.all(label <= 1) + + +class TestDataPipelineDynamicCase2(TestDataPipelineDynamicCase1): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.num_epoches = 1 + self.batch_size = 16 + self.num_threads = 4 + self.host_memory_padding = 0 + self.device_memory_padding = 0 + + self.shuffle = True + self.drop_last = True + self.calc_iter_info() + + self.target_size = 128 + self.flip_prob = 0.5 + self.mean = [123.675, 116.28, 103.53] + self.std = [58.395, 57.120, 57.375] + + self.mean_np = np.array(self.mean).reshape([1, 3, 1, 1]) + self.std_np = np.array(self.std).reshape([1, 3, 1, 1]) + + self.build_reader() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_data_pipeline_static.py b/python/paddle/tests/test_data_pipeline_static.py new file mode 100644 index 00000000000000..3db5755c953785 --- /dev/null +++ b/python/paddle/tests/test_data_pipeline_static.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
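Note: both pipeline test files derive their expected per-epoch iteration counts from the 100-image stub dataset, the batch size, and drop_last via the calc_iter_info helper defined below. A standalone sketch of that arithmetic, using only values that appear in these tests:

```python
def iter_info(num_samples, batch_size, drop_last):
    # mini-batches per epoch and the size of the final batch
    if drop_last:
        return num_samples // batch_size, batch_size
    num_iters = (num_samples + batch_size - 1) // batch_size
    return num_iters, (num_samples % batch_size) or batch_size

assert iter_info(100, 16, drop_last=True) == (6, 16)   # 96 of 100 images used
assert iter_info(100, 16, drop_last=False) == (7, 4)   # trailing batch of 4
assert iter_info(100, 32, drop_last=True) == (3, 32)   # Case2 configuration
```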
+ +import os +import cv2 +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.utils.download import get_path_from_url +from paddle.vision.datasets import DatasetFolder +from paddle.vision.ops import image_decode_random_crop, image_resize, \ + random_flip, mirror_normalize +from paddle.vision.reader import file_label_reader + +DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") +DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" +DATASET_MD5 = "c7110519124a433901cf005a4a91b607" +IMAGE_NUM = 100 + + +class TestDataPipelineStaticCase1(unittest.TestCase): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.num_epoches = 2 + self.batch_size = 16 + self.num_threads = 2 + self.host_memory_padding = 1000000 + self.device_memory_padding = 1000000 + + self.shuffle = False + self.drop_last = True + self.calc_iter_info() + + self.target_size = 224 + self.flip_prob = 0.5 + self.mean = [123.675, 116.28, 103.53] + self.std = [58.395, 57.120, 57.375] + + self.mean_np = np.array(self.mean).reshape([1, 3, 1, 1]) + self.std_np = np.array(self.std).reshape([1, 3, 1, 1]) + + self.build_reader() + + def calc_iter_info(self): + if self.drop_last: + self.num_iters = IMAGE_NUM // self.batch_size + else: + self.num_iters = (IMAGE_NUM + self.batch_size - 1) \ + // self.batch_size + + if self.drop_last: + self.last_batch_size = self.batch_size + else: + self.last_batch_size = IMAGE_NUM % self.batch_size + if self.last_batch_size == 0: + self.last_batch_size = self.batch_size + + def build_reader(self): + def imagenet_reader(): + image, label = file_label_reader( + self.data_root, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_last=self.drop_last) + + def decode(image): + image = image_decode_random_crop( + image, num_threads=self.num_threads) + return image + + def resize(image): + image = image_resize(image, size=self.target_size) + return image + + def flip_normalize(image): + mirror = random_flip(image, prob=self.flip_prob) + image = mirror_normalize( + image, mirror, mean=self.mean, std=self.std) + return image + + image = paddle.io.map(decode, image) + image = paddle.io.map(resize, image) + image = paddle.io.map(flip_normalize, image) + + return {'image': image, 'label': label} + + self.reader = imagenet_reader + + def test_output(self): + # NOTE: only supoort CUDA kernel currently + if not core.is_compiled_with_cuda(): + return + + loader = paddle.io.DataLoader(self.reader) + + for eid in range(self.num_epoches): + num_iters = 0 + for data in loader: + image = data['image'].numpy() + assert image.shape[0] == self.batch_size + assert image.shape[1] == 3 + assert image.shape[2] == self.target_size + assert image.shape[3] == self.target_size + assert image.dtype == np.float32 + + restore_image = image * self.std_np + self.mean_np + assert np.all(restore_image > -1.) + assert np.all(restore_image < 256.) 
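+                # the ImageNet stub bundled for these tests has two classes,
+                # so decoded labels must be int64 values in {0, 1}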
+ + label = data['label'].numpy() + assert label.shape[0] == self.batch_size + assert label.dtype == np.int64 + assert np.all(label >= 0) + assert np.all(label <= 1) + + num_iters += 1 + + assert num_iters == self.num_iters + if eid < self.num_epoches - 1: + loader.reset() + + +class TestDataPipelineStaticCase2(TestDataPipelineStaticCase1): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.num_epoches = 1 + self.batch_size = 32 + self.num_threads = 4 + self.host_memory_padding = 0 + self.device_memory_padding = 0 + + self.shuffle = True + self.drop_last = True + self.calc_iter_info() + + self.target_size = 128 + self.flip_prob = 0.5 + self.mean = [123.675, 116.28, 103.53] + self.std = [58.395, 57.120, 57.375] + + self.mean_np = np.array(self.mean).reshape([1, 3, 1, 1]) + self.std_np = np.array(self.std).reshape([1, 3, 1, 1]) + + self.build_reader() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_ops_crop_resize.py b/python/paddle/tests/test_ops_crop_resize.py new file mode 100644 index 00000000000000..87df7a2eb37689 --- /dev/null +++ b/python/paddle/tests/test_ops_crop_resize.py @@ -0,0 +1,545 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
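Note: the NumPy reference interpolators defined below follow the usual align_corners convention: with align_corners=True the sampling ratio is (in - 1) / (out - 1), so the first and last output samples coincide with the input corners; otherwise the ratio is in / out, with align_mode == 0 additionally applying a half-pixel offset in the bilinear path. A small sketch of how the bilinear reference derives a single source coordinate under each convention (the helper name src_index is illustrative only):

```python
def src_index(i, ratio, align_corners, align_mode):
    # source coordinate used by the bilinear reference in this file
    if align_mode == 0 and not align_corners:
        return max(ratio * (i + 0.5) - 0.5, 0.0)  # half-pixel convention
    return ratio * i                              # corner convention

in_h, out_h = 32, 20
ratio_corners = (in_h - 1.0) / (out_h - 1.0)      # ~1.632
ratio_plain = 1.0 * in_h / out_h                  # 1.6
print(src_index(out_h - 1, ratio_corners, True, 1))  # 31.0: last output row maps to last input row
print(src_index(out_h - 1, ratio_plain, False, 0))   # 30.7: stays inside the 32-row input
```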
+ +import os +import cv2 +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.vision.ops import image_resize, random_crop_and_resize + + +def np_nearest_interp(image, size, align_corners=True, data_format='NCHW'): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if isinstance(size, int): + size = (size, size) + + if data_format == "NHWC": + image = np.transpose(image, (2, 0, 1)) # HWC => CHW + + channel, in_h, in_w = image.shape + out_h, out_w = size + + ratio_h = ratio_w = 0.0 + if (out_h > 1): + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if (out_w > 1): + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((channel, out_h, out_w)) + + if align_corners: + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, i, j] = image[:, in_i, in_j] + else: + for i in range(out_h): + in_i = int(ratio_h * i) + for j in range(out_w): + in_j = int(ratio_w * j) + out[:, i, j] = image[:, in_i, in_j] + + if data_format == "NHWC": + out = np.transpose(out, (1, 2, 0)) # CHW => HWC + + return out.astype(image.dtype) + + +def np_bilinear_interp(image, + size, + align_corners=True, + align_mode=0, + data_format='NCHW'): + """bilinear interpolation implement in shape [N, C, H, W]""" + if isinstance(size, int): + size = (size, size) + + if data_format == "NHWC": + image = np.transpose(image, (2, 0, 1)) # HWC => CHW + + channel, in_h, in_w = image.shape + out_h, out_w = size + + ratio_h = ratio_w = 0.0 + if out_h > 1: + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((channel, out_h, out_w)) + + for i in range(out_h): + if (align_mode == 0 and not align_corners): + h = int(ratio_h * (i + 0.5) - 0.5) + else: + h = int(ratio_h * i) + + h = max(0, h) + hid = 1 if h < in_h - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_h = max(ratio_h * (i + 0.5) - 0.5, 0) + h1lambda = idx_src_h - h + else: + h1lambda = ratio_h * i - h + h2lambda = 1.0 - h1lambda + for j in range(out_w): + if (align_mode == 0 and not align_corners): + w = int(ratio_w * (j + 0.5) - 0.5) + else: + w = int(ratio_w * j) + w = max(0, w) + wid = 1 if w < in_w - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0) + w1lambda = idx_src_w - w + else: + w1lambda = ratio_w * j - w + w2lambda = 1.0 - w1lambda + + out[:, i, j] = h2lambda*(w2lambda * image[:, h, w] + + w1lambda * image[:, h, w+wid]) + \ + h1lambda*(w2lambda * image[:, h+hid, w] + + w1lambda * image[:, h+hid, w+wid]) + + if data_format == "NHWC": + out = np.transpose(out, (1, 2, 0)) # CHW => HWC + + return out.astype(image.dtype) + + +def np_image_resize(images, + size, + interp_method, + align_corners=True, + align_mode=1, + data_format="NCHW"): + if isinstance(size, int): + size = (size, size) + + results = [] + if interp_method == "nearest": + for image in images: + results.append( + np_nearest_interp( + image, + size=size, + align_corners=align_corners, + data_format=data_format)) + elif interp_method == "bilinear": + for image in images: + results.append( + np_bilinear_interp( + image, + size=size, + align_corners=align_corners, + 
align_mode=align_mode, + data_format=data_format)) + else: + raise ValueError("unknown interp_method") + + return np.stack(results, axis=0) + + +class TestImageResizeNearestNCHW(unittest.TestCase): + def setUp(self): + self.image_shape1 = [3, 32, 32] + self.image_shape2 = [3, 16, 16] + self.size = (20, 30) + self.interp_method = "nearest" + self.data_format = "NCHW" + self.align_corners = False + self.align_mode = 1 + + self.build_np_data() + + def build_np_data(self): + self.image1 = np.random.randint( + 0, 256, self.image_shape1, dtype="uint8") + self.image2 = np.random.randint( + 0, 256, self.image_shape2, dtype="uint8") + self.np_result = np_image_resize( + [self.image1, self.image2], + size=self.size, + interp_method=self.interp_method, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format) + + def test_output_dynamic(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.disable_static() + + images = paddle.tensor.create_array(dtype="uint8") + images = paddle.tensor.array_write( + paddle.to_tensor(self.image1), paddle.to_tensor(0), images) + images = paddle.tensor.array_write( + paddle.to_tensor(self.image2), paddle.to_tensor(1), images) + + result = image_resize( + images, + self.size, + interp_method=self.interp_method, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format) + assert np.allclose(result.numpy(), self.np_result, rtol=1) + + def test_output_static(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.enable_static() + + image1 = fluid.layers.assign(self.image1.astype('int32')) + image1 = fluid.layers.cast(image1, dtype='uint8') + + image2 = fluid.layers.assign(self.image2.astype('int32')) + image2 = fluid.layers.cast(image2, dtype='uint8') + + out = image_resize( + [image1, image2], + self.size, + interp_method=self.interp_method, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format) + + exe = paddle.static.Executor(paddle.CUDAPlace(0)) + result, = exe.run(paddle.static.default_main_program(), + fetch_list=[out]) + assert np.allclose(result, self.np_result, rtol=1) + + paddle.disable_static() + + +class TestImageResizeNearestNHWC(TestImageResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [32, 32, 3] + self.image_shape2 = [16, 16, 3] + self.size = 20 + self.interp_method = "nearest" + self.data_format = "NHWC" + self.align_corners = True + self.align_mode = 1 + + self.build_np_data() + + +class TestImageResizeNearestNCHWAlignCorner(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [3, 32, 32] + self.image_shape2 = [3, 16, 16] + self.size = 30 + self.interp_method = "nearest" + self.data_format = "NCHW" + self.align_corners = True + self.align_mode = 1 + + self.build_np_data() + + +class TestImageResizeNearestNHWCAlignCorner(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [32, 32, 3] + self.image_shape2 = [16, 16, 3] + self.size = (20, 30) + self.interp_method = "nearest" + self.data_format = "NHWC" + self.align_corners = True + self.align_mode = 1 + + self.build_np_data() + + +class TestImageResizeBilinearNCHW(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [3, 32, 32] + self.image_shape2 = [3, 16, 16] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NCHW" + self.align_corners = False + self.align_mode = 1 + + 
self.build_np_data() + + +class TestImageResizeBilinearNHWC(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [32, 32, 3] + self.image_shape2 = [16, 16, 3] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NHWC" + self.align_corners = False + self.align_mode = 1 + + self.build_np_data() + + +class TestImageResizeBilinearNCHWAlignMode0(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [3, 32, 32] + self.image_shape2 = [3, 16, 16] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NCHW" + self.align_corners = False + self.align_mode = 0 + + self.build_np_data() + + +class TestImageResizeBilinearNHWCAlignMode0(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [32, 32, 3] + self.image_shape2 = [16, 16, 3] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NHWC" + self.align_corners = False + self.align_mode = 0 + + self.build_np_data() + + +class TestImageResizeBilinearNCHWAlignCorner(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [3, 32, 32] + self.image_shape2 = [3, 16, 16] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NCHW" + self.align_corners = True + self.align_mode = 1 + + self.build_np_data() + + +class TestImageResizeBilinearNHWCAlignCorner(TestImageResizeNearestNHWC): + def setUp(self): + self.image_shape1 = [32, 32, 3] + self.image_shape2 = [16, 16, 3] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NHWC" + self.align_corners = True + self.align_mode = 1 + + self.build_np_data() + + +class TestImageCropResizeNearestNCHW(unittest.TestCase): + def setUp(self): + self.image_shape1 = [3, 16, 16] + self.image_shape2 = [3, 32, 32] + self.size = (20, 30) + self.interp_method = "nearest" + self.data_format = "NCHW" + self.align_corners = False + self.align_mode = 1 + + self.out_shape = (2, 3, 20, 30) + + self.build_np_data() + + def build_np_data(self): + self.image1 = np.random.randint( + 0, 256, self.image_shape1, dtype="uint8") + self.image2 = np.random.randint( + 0, 256, self.image_shape2, dtype="uint8") + + def test_output_dynamic(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.disable_static() + + images = paddle.tensor.create_array(dtype="uint8") + images = paddle.tensor.array_write( + paddle.to_tensor(self.image1), paddle.to_tensor(0), images) + images = paddle.tensor.array_write( + paddle.to_tensor(self.image2), paddle.to_tensor(1), images) + + result = random_crop_and_resize( + images, + self.size, + interp_method=self.interp_method, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format) + result = result.numpy() + assert result.shape == self.out_shape + assert result.dtype == np.uint8 + + def test_output_static(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.enable_static() + + image1 = fluid.layers.assign(self.image1.astype('int32')) + image1 = fluid.layers.cast(image1, dtype='uint8') + + image2 = fluid.layers.assign(self.image2.astype('int32')) + image2 = fluid.layers.cast(image2, dtype='uint8') + + out = random_crop_and_resize( + [image1, image2], + self.size, + interp_method=self.interp_method, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format) + + exe = paddle.static.Executor(paddle.CUDAPlace(0)) + result, = 
exe.run(paddle.static.default_main_program(), + fetch_list=[out]) + assert result.shape == self.out_shape + assert result.dtype == np.uint8 + + paddle.disable_static() + + +class TestImageCropResizeNearestNHWC(TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [16, 16, 3] + self.image_shape2 = [32, 32, 3] + self.size = 20 + self.interp_method = "nearest" + self.data_format = "NHWC" + self.align_corners = False + self.align_mode = 1 + + self.out_shape = (2, 20, 20, 3) + + self.build_np_data() + + +class TestImageCropResizeNearestNCHWAlignCorner(TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [3, 16, 16] + self.image_shape2 = [3, 32, 32] + self.size = 20 + self.interp_method = "nearest" + self.data_format = "NCHW" + self.align_corners = True + self.align_mode = 1 + + self.out_shape = (2, 3, 20, 20) + + self.build_np_data() + + +class TestImageCropResizeNearestNHWCAlignCorner(TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [16, 16, 3] + self.image_shape2 = [32, 32, 3] + self.size = (20, 30) + self.interp_method = "nearest" + self.data_format = "NHWC" + self.align_corners = True + self.align_mode = 1 + + self.out_shape = (2, 20, 30, 3) + + self.build_np_data() + + +class TestImageCropResizeBilinearNCHW(TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [3, 16, 16] + self.image_shape2 = [3, 32, 32] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NCHW" + self.align_corners = False + self.align_mode = 1 + + self.out_shape = (2, 3, 20, 30) + + self.build_np_data() + + +class TestImageCropResizeBilinearNCHWAlignMode0(TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [3, 16, 16] + self.image_shape2 = [3, 32, 32] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NCHW" + self.align_corners = False + self.align_mode = 0 + + self.out_shape = (2, 3, 20, 30) + + self.build_np_data() + + +class TestImageCropResizeNearestNHWCAlignMode0(TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [16, 16, 3] + self.image_shape2 = [32, 32, 3] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NHWC" + self.align_corners = False + self.align_mode = 0 + + self.out_shape = (2, 20, 30, 3) + + self.build_np_data() + + +class TestImageCropResizeBilinearNCHWAlignCorner( + TestImageCropResizeNearestNCHW): + def setUp(self): + self.image_shape1 = [3, 16, 16] + self.image_shape2 = [3, 32, 32] + self.size = (20, 30) + self.interp_method = "bilinear" + self.data_format = "NCHW" + self.align_corners = True + self.align_mode = 1 + + self.out_shape = (2, 3, 20, 30) + + self.build_np_data() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_ops_decode.py b/python/paddle/tests/test_ops_decode.py new file mode 100644 index 00000000000000..4bff90f7672fbb --- /dev/null +++ b/python/paddle/tests/test_ops_decode.py @@ -0,0 +1,247 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.utils.download import get_path_from_url +from paddle.vision.datasets import DatasetFolder +from paddle.vision.ops import image_decode, image_decode_random_crop +from paddle.vision.reader import file_label_loader + +DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") +DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" +DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + + +class TestImageReaderDecodeCase1(unittest.TestCase): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.batch_size = 16 + self.num_threads = 2 + self.host_memory_padding = 1000000 + self.device_memory_padding = 1000000 + + def test_static_output(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.enable_static() + + indices = paddle.arange(self.batch_size) + image, label = file_label_loader(self.data_root, indices, + self.batch_size) + image = image_decode(image, num_threads=self.num_threads) + exe = paddle.static.Executor(paddle.CUDAPlace(0)) + rets = exe.run(paddle.static.default_main_program(), + fetch_list=image + [label]) + out_image = rets[:-1] + out_label = rets[-1] + + assert len(out_image) == self.batch_size + for i in range(self.batch_size): + img = np.array(out_image[i]) + assert img.dtype == np.uint8 + assert img.shape[2] == 3 + assert np.all(img >= 0) + assert np.all(img <= 255) + + label = np.array(out_label) + assert label.dtype == np.int64 + assert label.shape[0] == self.batch_size + assert np.all(label >= 0) + assert np.all(label <= 1) + + paddle.disable_static() + + def test_dynamic_output(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + indices = paddle.arange(self.batch_size) + image, label = file_label_loader(self.data_root, indices, + self.batch_size) + image = image_decode(image, num_threads=self.num_threads) + + assert len(image) == self.batch_size + for i in range(self.batch_size): + img = image[i].numpy() + assert img.dtype == np.uint8 + assert img.shape[2] == 3 + assert np.all(img >= 0) + assert np.all(img <= 255) + + label = label.numpy() + assert label.dtype == np.int64 + assert label.shape[0] == self.batch_size + assert np.all(label >= 0) + assert np.all(label <= 1) + + +class TestImageReaderDecodeCase2(TestImageReaderDecodeCase1): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.batch_size = 32 + self.num_threads = 4 + self.host_memory_padding = 0 + self.device_memory_padding = 0 + + +class TestImageReaderDecodeRandomCropNCHW(unittest.TestCase): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.batch_size = 16 + self.num_threads = 2 + self.host_memory_padding = 1000000 + self.device_memory_padding = 1000000 + + self.aspect_ratio_min = 3. / 4. + self.aspect_ratio_max = 4. / 3. 
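+        # crop-window sampler settings: allowed area fraction of the source
+        # image and the number of random sampling attempts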
+ self.area_min = 0.08 + self.area_max = 1.0 + self.num_attempts = 10 + + self.data_format = "NCHW" + self.channel_dim = 0 + + def test_static_output(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.enable_static() + + indices = paddle.arange(self.batch_size) + image, label = file_label_loader(self.data_root, indices, + self.batch_size) + image = image_decode_random_crop( + image, + num_threads=self.num_threads, + aspect_ratio_min=self.aspect_ratio_min, + aspect_ratio_max=self.aspect_ratio_max, + area_min=self.area_min, + area_max=self.area_max, + num_attempts=self.num_attempts, + data_format=self.data_format) + exe = paddle.static.Executor(paddle.CUDAPlace(0)) + rets = exe.run(paddle.static.default_main_program(), + fetch_list=image + [label]) + out_image = rets[:-1] + out_label = rets[-1] + + assert len(out_image) == self.batch_size + for i in range(self.batch_size): + img = np.array(out_image[i]) + assert img.dtype == np.uint8 + assert img.shape[self.channel_dim] == 3 + assert np.all(img >= 0) + assert np.all(img <= 255) + + assert len(out_label) == self.batch_size + assert label.dtype == paddle.int64 + label = np.array(out_label) + assert np.all(label >= 0) + assert np.all(label <= 1) + + paddle.disable_static() + + def test_dynamic_output(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + indices = paddle.arange(self.batch_size) + image, label = file_label_loader(self.data_root, indices, + self.batch_size) + image = image_decode_random_crop( + image, + num_threads=self.num_threads, + aspect_ratio_min=self.aspect_ratio_min, + aspect_ratio_max=self.aspect_ratio_max, + area_min=self.area_min, + area_max=self.area_max, + num_attempts=self.num_attempts, + data_format=self.data_format) + + assert len(image) == self.batch_size + for i in range(self.batch_size): + img = image[i].numpy() + assert img.dtype == np.uint8 + assert img.shape[self.channel_dim] == 3 + assert np.all(img >= 0) + assert np.all(img <= 255) + + label = label.numpy() + assert label.shape[0] == self.batch_size + assert label.dtype == np.int64 + assert np.all(label >= 0) + assert np.all(label <= 1) + + +class TestImageReaderDecodeRandomCropNHWC(TestImageReaderDecodeRandomCropNCHW): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.batch_size = 16 + self.num_threads = 4 + self.host_memory_padding = 0 + self.device_memory_padding = 0 + + self.aspect_ratio_min = 4. / 5. + self.aspect_ratio_max = 5. / 4. + self.area_min = 0.1 + self.area_max = 0.9 + self.num_attempts = 20 + + self.data_format = "NHWC" + self.channel_dim = 2 + + +class TestImageReaderDecodeRandomCropThread8( + TestImageReaderDecodeRandomCropNCHW): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + + self.batch_size = 16 + self.num_threads = 8 + self.host_memory_padding = 20000 + self.device_memory_padding = 20000 + + self.aspect_ratio_min = 1. / 2. + self.aspect_ratio_max = 3. / 2. 
+ self.area_min = 0.01 + self.area_max = 0.99 + self.num_attempts = 50 + + self.data_format = "NCHW" + self.channel_dim = 0 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_ops_file_label_loader.py b/python/paddle/tests/test_ops_file_label_loader.py new file mode 100644 index 00000000000000..e8137d190c9c08 --- /dev/null +++ b/python/paddle/tests/test_ops_file_label_loader.py @@ -0,0 +1,141 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import unittest +import numpy as np + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.utils.download import get_path_from_url +from paddle.vision.datasets import DatasetFolder +from paddle.vision.reader import _sampler_manager, file_label_loader + +DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") +DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" +DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + + +class TestFileLabelLoaderStatic(unittest.TestCase): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + self.batch_size = 16 + self.shuffle = False + self.drop_last = False + self.dynamic = False + + def build_program(self): + paddle.enable_static() + self.indices_data = paddle.static.data( + shape=[self.batch_size], dtype='int64', name='indices') + self.sample_data, self.label_data = file_label_loader( + self.data_root, self.indices_data, self.batch_size) + self.exe = paddle.static.Executor(paddle.CPUPlace()) + paddle.disable_static() + + def loader_function(self, indices): + if paddle.in_dynamic_mode(): + indices = paddle.to_tensor(indices) + return file_label_loader(self.data_root, indices, self.batch_size) + else: + paddle.enable_static() + return self.exe.run(paddle.static.default_main_program(), + feed={'indices': indices}, + fetch_list=[self.sample_data, self.label_data]) + + def test_check_output(self): + # NOTE: only support cuda kernel currently + if not core.is_compiled_with_cuda(): + return + + if not self.dynamic: + self.build_program() + + data_folder = DatasetFolder(self.data_root) + samples = [s[0] for s in data_folder.samples] + targets = [s[1] for s in data_folder.samples] + + sampler_id = fluid.layers.utils._hash_with_id( + self.data_root, self.batch_size, self.shuffle, self.drop_last, + self.dynamic) + sampler = _sampler_manager.get(sampler_id, + batch_size=self.batch_size, + num_samples=len(samples), + shuffle=self.shuffle, + drop_last=self.drop_last) + + num_iters = (len(samples) + self.batch_size - 1) // self.batch_size + for _ in range(num_iters): + indices = next(sampler) + sample, target = self.loader_function(indices) + assert np.array_equal(target, np.array(targets)[indices]) + + +class TestFileLabelLoaderDynamic(TestFileLabelLoaderStatic): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, 
DATASET_HOME, + DATASET_MD5) + self.batch_size = 16 + self.shuffle = False + self.drop_last = False + self.dynamic = True + + +class TestFileLabelLoaderStaticShuffle(TestFileLabelLoaderStatic): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + self.batch_size = 16 + self.shuffle = True + self.drop_last = False + self.dynamic = False + + +class TestFileLabelLoaderDynamicShuffle(TestFileLabelLoaderStatic): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + self.batch_size = 16 + self.shuffle = True + self.drop_last = False + self.dynamic = True + + +class TestFileLabelLoaderStaticDropLast(TestFileLabelLoaderStatic): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + self.batch_size = 16 + self.shuffle = True + self.drop_last = True + self.dynamic = False + + +class TestFileLabelLoaderDynamicDropLast(TestFileLabelLoaderStatic): + def setUp(self): + self.data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + self.batch_size = 16 + self.shuffle = True + self.drop_last = True + self.dynamic = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_ops_mirror_normalize.py b/python/paddle/tests/test_ops_mirror_normalize.py new file mode 100644 index 00000000000000..fa536eb6629ec4 --- /dev/null +++ b/python/paddle/tests/test_ops_mirror_normalize.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import os +import copy +import unittest +import numpy as np + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.vision.ops import mirror_normalize + + +def np_mirror_normalize(image, mirror, mean, std): + image = copy.deepcopy(image) + for i, m in enumerate(mirror): + if m[0]: + image[i] = image[i][:, :, -1::-1] + + mean = np.array(mean) + std = np.array(std) + if np.size(mean) == 1: + mean = np.tile(mean, (3, )) + if np.size(std) == 1: + std = np.tile(std, (3, )) + + mean = np.array(mean[:]).reshape([1, 3, 1, 1]) + std = np.array(std[:]).reshape([1, 3, 1, 1]) + + return (image - mean) / std + + +class TestMirrorNormalize(unittest.TestCase): + def setUp(self): + self.image_shape = [16, 3, 32, 32] + self.mirror_shape = [16, 1] + self.mean = [123.675, 116.28, 103.53] + self.std = [58.395, 57.120, 57.375] + + self.image = np.random.randint(0, 256, self.image_shape, + 'int32').astype("float32") + self.mirror = np.random.randint(0, 2, self.mirror_shape, + 'int32').astype("bool") + + self.result = np_mirror_normalize(self.image, self.mirror, self.mean, + self.std) + + def test_check_output_dynamic(self): + # NOTE: only supoort CUDA kernel currently + if not core.is_compiled_with_cuda(): + return + + dy_result = mirror_normalize( + paddle.to_tensor(self.image), + paddle.to_tensor(self.mirror), self.mean, self.std) + assert np.allclose(self.result, dy_result.numpy()) + + def test_check_output_static(self): + # NOTE: only supoort CUDA kernel currently + if not core.is_compiled_with_cuda(): + return + + paddle.enable_static() + + image_data = paddle.static.data( + shape=self.image_shape, dtype='float32', name="image") + mirror_data = paddle.static.data( + shape=self.mirror_shape, dtype='bool', name="mirror") + result_data = mirror_normalize(image_data, mirror_data, self.mean, + self.std) + + # NOTE: only supoort CUDA kernel currently + places = [] + if core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + + for place in places: + exe = paddle.static.Executor(place) + st_result = exe.run( + paddle.static.default_main_program(), + feed={"image": self.image, + "mirror": self.mirror}, + fetch_list=[result_data]) + + assert np.allclose(self.result, st_result) + + paddle.disable_static() + + +class TestMirrorNormalizeSingleMeanStd(TestMirrorNormalize): + def setUp(self): + self.image_shape = [16, 3, 32, 32] + self.mirror_shape = [16, 1] + self.mean = [123.675] + self.std = [58.395] + + self.image = np.random.randint(0, 256, self.image_shape, + 'int32').astype("float32") + self.mirror = np.random.randint(0, 2, self.mirror_shape, + 'int32').astype("bool") + + self.result = np_mirror_normalize(self.image, self.mirror, self.mean, + self.std) + + +class TestMirrorNormalizeFloatMeanStd(TestMirrorNormalize): + def setUp(self): + self.image_shape = [16, 3, 32, 32] + self.mirror_shape = [16, 1] + self.mean = 123.675 + self.std = 58.395 + + self.image = np.random.randint(0, 256, self.image_shape, + 'int32').astype("float32") + self.mirror = np.random.randint(0, 2, self.mirror_shape, + 'int32').astype("bool") + + self.result = np_mirror_normalize(self.image, self.mirror, self.mean, + self.std) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/vision/__init__.py b/python/paddle/vision/__init__.py index 3749e0f64fc6a8..7e04c91f3b832d 100644 --- a/python/paddle/vision/__init__.py +++ b/python/paddle/vision/__init__.py @@ -17,6 +17,7 @@ from . import transforms # noqa: F401 from . 
import datasets # noqa: F401 from . import ops # noqa: F401 +from . import reader # noqa: F401 from .image import set_image_backend # noqa: F401 from .image import get_image_backend # noqa: F401 from .image import image_load # noqa: F401 diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 7d29e4b1c9c180..2bc032b4af1f4d 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -13,12 +13,14 @@ # limitations under the License. import numpy as np -from ..fluid.layer_helper import LayerHelper +from ..fluid.layer_helper import LayerHelper, unique_name from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype -from ..fluid import core, layers +from ..fluid import core, layers, default_main_program from ..fluid.layers import nn, utils from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D from ..fluid.initializer import Normal + +import paddle from ..fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph from paddle.common_ops_import import * from paddle import _C_ops @@ -30,12 +32,17 @@ 'DeformConv2D', 'read_file', 'decode_jpeg', + 'image_decode', + 'image_decode_random_crop', + 'random_flip', 'roi_pool', 'RoIPool', 'psroi_pool', 'PSRoIPool', 'roi_align', 'RoIAlign', + 'random_crop_and_resize', + 'image_resize', 'nms', ] @@ -839,18 +846,20 @@ def read_file(filename, name=None): Examples: .. code-block:: python - - import cv2 import paddle + from paddle.utils.download import get_path_from_url - fake_img = (np.random.random( - (400, 300, 3)) * 255).astype('uint8') + DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") + DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" + DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + BATCH_SIZE = 16 - cv2.imwrite('fake.jpg', fake_img) - - img_bytes = paddle.vision.ops.read_file('fake.jpg') - - print(img_bytes.shape) + data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + indices = paddle.arange(BATCH_SIZE) + outs = paddle.vision.reader.file_label_loader(data_root, + indices, BATCH_SIZE) + print(outs[0].shape) """ @@ -868,6 +877,273 @@ def read_file(filename, name=None): return out +def image_decode(x, + num_threads=2, + host_memory_padding=0, + device_memory_padding=0, + name=None): + """ + Decodes a batch of JPEG images into a list of 3 dimensional RGB + Tensors with multi-threads and Nvjpeg. Default Nvjpeg decoding + output format is RGBI, for detail infomations, please see + https://docs.nvidia.com/cuda/nvjpeg/index.html. + + This api is only available for Paddle GPU version + + The values of the output tensors are uint8 between 0 and 255. + + Args: + x (List[Tensor]): A list of one dimensional uint8 Tensors + containing the raw bytes of the JPEG image. + num_threads (int): The parallel thread number for decoding + host_memory_padding (int): The CUDA pinned memory allocation + padding size of Nvjpeg decoding. Default 0. + host_memory_padding (int): The CUDA memory allocation padding + size of Nvjpeg decoding. Default 0. + name (str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A list of decoded image tensors with shape of + (imge_channels, image_height, image_width) + + Examples: + .. 
code-block:: python
+
+            import cv2
+            import paddle
+            import numpy as np
+
+            fake_img = (np.random.random(
+                (400, 300, 3)) * 255).astype('uint8')
+
+            cv2.imwrite('fake.jpg', fake_img)
+
+            img_bytes = paddle.vision.ops.read_file('fake.jpg')
+            imgs = paddle.vision.ops.image_decode([img_bytes])
+
+            print(imgs[0].shape)
+    """
+
+    local_rank = paddle.distributed.get_rank()
+
+    if in_dygraph_mode():
+        out = core.VarBase(core.VarDesc.VarType.UINT8, [],
+                           unique_name.generate("image_decode"),
+                           core.VarDesc.VarType.LOD_TENSOR_ARRAY, False)
+        program_id = utils._hash_with_id(num_threads, name, local_rank)
+        return _C_ops.batch_decode(
+            x, out, "num_threads", num_threads, "local_rank", local_rank,
+            "program_id", program_id, "host_memory_padding",
+            host_memory_padding, "device_memory_padding", device_memory_padding)
+
+    inputs = {'X': x}
+    attrs = {
+        "num_threads": num_threads,
+        "local_rank": local_rank,
+        "program_id": utils._hash_with_id(default_main_program()),
+        "host_memory_padding": host_memory_padding,
+        "device_memory_padding": device_memory_padding
+    }
+
+    helper = LayerHelper("batch_decode", **locals())
+    out = [
+        helper.create_variable(
+            name=unique_name.generate("image_decode"),
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            dtype='uint8') for i in range(len(x))
+    ]
+    helper.append_op(
+        type="batch_decode", inputs=inputs, attrs=attrs, outputs={"Out": out})
+
+    return out
+
+
+def image_decode_random_crop(x,
+                             num_threads=2,
+                             host_memory_padding=0,
+                             device_memory_padding=0,
+                             data_format='NCHW',
+                             aspect_ratio_min=3. / 4.,
+                             aspect_ratio_max=4. / 3.,
+                             area_min=0.08,
+                             area_max=1.,
+                             num_attempts=10,
+                             name=None):
+    """
+    Decodes and performs random cropping on a batch of JPEG images into
+    a list of 3 dimensional RGB Tensors with multiple threads and Nvjpeg.
+    The default Nvjpeg decoding output format is RGBI; for detailed
+    information, please see https://docs.nvidia.com/cuda/nvjpeg/index.html.
+
+    This API is only available in the Paddle GPU version.
+
+    The values of the output tensors are uint8 between 0 and 255.
+
+    Args:
+        x (List[Tensor]): A list of one dimensional uint8 Tensors
+            containing the raw bytes of the JPEG image.
+        num_threads (int): The parallel thread number for decoding.
+            Default 2.
+        host_memory_padding (int): The CUDA pinned memory allocation
+            padding size of Nvjpeg decoding. Default 0.
+        device_memory_padding (int): The CUDA device memory allocation
+            padding size of Nvjpeg decoding. Default 0.
+        data_format (string): The output image format, if NCHW, output
+            images will be in shape of (channels, image_height,
+            image_width), if NHWC, output images will be in shape of
+            (image_height, image_width, channels). Default NCHW.
+        aspect_ratio_min (float): The minimum aspect ratio of random
+            cropping boxes, this should be a value between 0 and
+            1. Default :attr:`3. / 4.`.
+        aspect_ratio_max (float): The maximum aspect ratio of random
+            cropping boxes, this should be a value greater than 1.
+            Default :attr:`4. / 3.`.
+        area_min (float): The minimum area ratio of random cropping boxes,
+            this should be a value between 0 and 1. Default 0.08.
+        area_max (float): The maximum area ratio of random cropping boxes,
+            this should be a value between 0 and 1. Default 1.
+        num_attempts (int): The max attempt number to find random cropping
+            boxes, this should be a positive integer. Default 10.
+        name (str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: A list of decoded image tensors with shape of
+            (image_channels, image_height, image_width) for NCHW, or
+            (image_height, image_width, image_channels) for NHWC.
+
+    Examples:
+        .. code-block:: python
+
+            import cv2
+            import paddle
+            import numpy as np
+
+            fake_img = (np.random.random(
+                (400, 300, 3)) * 255).astype('uint8')
+
+            cv2.imwrite('fake.jpg', fake_img)
+
+            img_bytes = paddle.vision.ops.read_file('fake.jpg')
+            imgs = paddle.vision.ops.image_decode_random_crop([img_bytes])
+
+            print(imgs[0].shape)
+    """
+
+    local_rank = paddle.distributed.get_rank()
+    if in_dygraph_mode():
+        out = core.VarBase(core.VarDesc.VarType.UINT8, [],
+                           unique_name.generate("image_decode_random_crop"),
+                           core.VarDesc.VarType.LOD_TENSOR_ARRAY, False)
+        program_id = utils._hash_with_id(num_threads, name, local_rank)
+        return _C_ops.batch_decode_random_crop(
+            x, out, "num_threads", num_threads, "data_format", data_format,
+            "aspect_ratio_min", aspect_ratio_min, "aspect_ratio_max",
+            aspect_ratio_max, "area_min", area_min, "area_max", area_max,
+            "num_attempts", num_attempts, "local_rank", local_rank,
+            "program_id", program_id, "host_memory_padding",
+            host_memory_padding, "device_memory_padding", device_memory_padding)
+
+    inputs = {'X': x}
+    attrs = {
+        "num_threads": num_threads,
+        "host_memory_padding": host_memory_padding,
+        "device_memory_padding": device_memory_padding,
+        "data_format": data_format,
+        "aspect_ratio_min": aspect_ratio_min,
+        "aspect_ratio_max": aspect_ratio_max,
+        "area_min": area_min,
+        "area_max": area_max,
+        "num_attempts": num_attempts,
+        "local_rank": local_rank,
+        "program_id": utils._hash_with_id(default_main_program())
+    }
+
+    helper = LayerHelper("image_decode_random_crop", **locals())
+    out = [
+        helper.create_variable(
+            name=unique_name.generate("image_decode_random_crop"),
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            dtype='uint8') for i in range(len(x))
+    ]
+    helper.append_op(
+        type="batch_decode_random_crop",
+        inputs=inputs,
+        attrs=attrs,
+        outputs={"Out": out})
+
+    return out
+
+
+def random_flip(x, prob=0.5, name=None):
+    """
+    This API generates mirror-flip flags for the input Tensor. It treats
+    the 1st dimension as the batch size and generates one bool value per
+    sample indicating whether that sample should be flipped.
+
+    Args:
+        x (Tensor): The input tensor in shape of [N, ...], N is the batch
+            size for which random flipping mirror flags are generated.
+        prob (float): The probability of flipping each input sample, this
+            should be a float value between 0 and 1. Default 0.5.
+        name (str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: A bool Tensor in shape of [N, 1], N is the size of the 1st
+            dimension of the input Tensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.rand(shape=[8, 3, 32, 32])
+            mirror = paddle.vision.ops.random_flip(x)
+
+            print(mirror)
+    """
+    if prob < 0. or prob > 1.:
+        raise ValueError("prob should be in [0, 1] in random_flip")
+
+    rand_vec = layers.uniform_random_batch_size_like(x, [1, 1], min=0., max=1.)
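+    # rand_vec has shape [N, 1] (N taken from the 1st dim of x) with values
+    # drawn uniformly from [0, 1), so comparing it against prob yields a
+    # per-sample bool flip flag, as documented above.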
+ return rand_vec < prob + + +def mirror_normalize(x, + mirror, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.120, 57.375], + name=None): + def _to_list_3(l): + if isinstance(l, (list, tuple)): + assert len(l) == 1 or len(l) == 3, \ + "input list length should be 1 or 3" + if len(l) == 1: + l = l * 3 + return l + else: + return [l] * 3 + + x = paddle.cast(x, dtype='float32') + mean = _to_list_3(mean) + std = _to_list_3(std) + + if _non_static_mode(): + return _C_ops.mirror_normalize(x, mirror, "mean", mean, "std", std) + + helper = LayerHelper("mirror_normalize", **locals()) + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="mirror_normalize", + inputs={"X": x, + "Mirror": mirror}, + outputs={"Out": out}, + attrs={"mean": mean, + "std": std}) + return out + + def decode_jpeg(x, mode='unchanged', name=None): """ Decodes a JPEG image into a 3 dimensional RGB Tensor or 1 dimensional Gray Tensor. @@ -1311,6 +1587,203 @@ def forward(self, x, boxes, boxes_num, aligned=True): aligned=aligned) +def random_crop_and_resize(x, + size, + aspect_ratio_min=3. / 4., + aspect_ratio_max=4. / 3., + area_min=0.08, + area_max=1., + num_attempts=10, + interp_method='bilinear', + align_corners=True, + align_mode=1, + data_format='NCHW', + seed=0, + name=None): + """ + This operator implements the paddle.vision.transforms.RandomResizedCrop. + Please refer to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/vision/transforms/RandomResizedCrop_cn.html#randomresizedcrop + for details. This operator has only a GPU kernel. + + Args: + x (List[Tensor]): A list of input images, 3D-Tensor with the shape + of [C,H,W] or [H,W,c]. The data type is uint8 or float32. + size (int|list|tuple): Target size of output image, with (height, + width) shape. + aspect_ratio_min (float): The minimum aspect ratio of random + cropping boxes, this should be a value between 0 and + 1. Default :attr:`3. / 4.`. + aspect_ratio_max (float): The maximum aspect ratio of random + cropping boxes, this should be a value greater than 1. + Default :attr:`4. / 3.`. + area_min (float): The minimum area ratio of random cropping boxes, + this should be a value between 0 and 1. Default 0.08. + area_max (float): The maximum area ratio of random cropping boxes, + this should be a value between 0 and 1. Default 1. + num_attempts (int): The max attempt number to find random cropping + boxes, this should be a position integer. Default 10. + data_format (string): The input image format, if NCHW, input + images will be in shape of (channels, image_height, + image_width), if NHWC, input images will be in shape of + (image_height, image_width, channels). Default NCHW. + interp_method (str, optional): Interpolation method. Default: 'bilinear'. + support method are as following: + - "nearest", + - "bilinear" + align_corners (bool, optional): If True, the centers of 4 corner pixels + of the input and output tensors are aligned, preserving the values + at the corner pixels, If False, are not aligned. Default: True + align_mode (int32, optional): Optional for bilinear interpolation, + can be 0 for src_idx = scale*(dst_indx+0.5)-0.5, can be 1 for + src_idx = scale*dst_index. Default: 1 + data_format (str, optional): Only used in an optional string + from: NHWC, NCHW. Specify that the data format of the input + and output data is channel_first or channel_last. Default: NCHW + seed (int, optional): The random seed. 
Default: 0 + name(str, optional): For detailed information, please refer to : + ref:`api_guide_Name`. Usually name is no need to set and None by + default. + + Returns: + Tensor: The output is a 4-D tensor with shape (batch_size, + channels, h, w) or (batch_random_crop_and_resize, h, w, + channels). The data type is uint8 or float32. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.randn(shape=[3, 256, 256]) + data = paddle.cast(data, dtype='uint8') + out = paddle.vision.ops.random_crop_and_resize([data], size=224) + + print(out.shape) + """ + + check_type(size, 'size', (int, tuple), 'batch_random_crop_and_resize') + assert interp_method in ['bilinear', 'nearest'] + assert data_format in ['NCHW', 'NHWC'] + if isinstance(size, int): + size = (size, size) + + if in_dygraph_mode(): + out = _C_ops.batch_random_crop_and_resize( + x, "size", size, "aspect_ratio_min", aspect_ratio_min, + "aspect_ratio_max", aspect_ratio_max, "area_max", area_max, + "area_min", area_min, "num_attempts", num_attempts, "interp_method", + interp_method, "align_corners", align_corners, "align_mode", + align_mode, "data_format", data_format, "seed", seed) + return out + + helper = LayerHelper('batch_random_crop_and_resize', **locals()) + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + inputs = {"X": x} + attrs = { + "size": size, + "aspect_ratio_min": aspect_ratio_min, + "aspect_ratio_max": aspect_ratio_max, + "area_min": area_min, + "area_max": area_max, + "num_attempts": num_attempts, + "interp_method": interp_method, + "align_corners": align_corners, + "align_mode": align_mode, + "data_format": data_format, + "seed": seed, + } + helper.append_op( + type="batch_random_crop_and_resize", + inputs=inputs, + outputs={"Out": out}, + attrs=attrs) + return out + + +def image_resize(x, + size, + interp_method='bilinear', + align_corners=True, + align_mode=1, + data_format='NCHW', + seed=0, + name=None): + """ + This operator implements the paddle.vision.transforms.Resize. + + Please refer to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/vision/transforms/Resized_cn.html#randomresizedcrop + for details. This operator has only a GPU kernel. + + Args: + x (List[Tensor]): A list of input images, 3D-Tensor with the shape + of [C, H, W] or [H, W, C]. The data type is uint8 or float32. + size (int|list|tuple): Target size of output image, with (height, + width) shape. + interp_method (str, optional): Interpolation method. Default: 'bilinear'. + support method are as following: + - "nearest", + - "bilinear" + align_corners (bool, optional): If True, the centers of 4 corner pixels + of the input and output tensors are aligned, preserving the values + at the corner pixels, If False, are not aligned. Default: True + align_mode (int32, optional): Optional for bilinear interpolation, + can be 0 for src_idx = scale*(dst_indx+0.5)-0.5, can be 1 for + src_idx = scale*dst_index. Default: 1 + data_format (str, optional): Only used in an optional string + from: NHWC, NCHW. Specify that the data format of the input + and output data is channel_first or channel_last. Default: NCHW + seed (int, optional): The random seed. Default: 0 + name(str, optional): For detailed information, please refer to : + ref:`api_guide_Name`. Usually name is no need to set and None by + default. + + Returns: + Tensor: The output of image_resizeis a 4-D tensor with shape + (batch_size, channels, h, w) or (batch_resize, h, w, channels). + The data type is uint8 or float32. 
+ + Examples: + .. code-block:: python + + import paddle + + data = paddle.randn(shape=[3, 256, 256]) + data = paddle.cast(data, dtype='uint8') + out = paddle.vision.ops.image_resize([data], size=224) + + print(out.shape) + """ + check_type(size, 'size', (int, tuple), 'image_resize') + assert interp_method in ['bilinear', 'nearest'] + assert data_format in ['NCHW', 'NHWC'] + if isinstance(size, int): + size = (size, size) + + if in_dygraph_mode(): + out = _C_ops.batch_resize(x, "size", size, "interp_method", + interp_method, "align_corners", align_corners, + "align_mode", align_mode, "data_format", + data_format, "seed", seed) + return out + + helper = LayerHelper('batch_resize', **locals()) + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + inputs = {"X": x} + attrs = { + "size": size, + "interp_method": interp_method, + "align_corners": align_corners, + "align_mode": align_mode, + "data_format": data_format, + "seed": seed, + } + helper.append_op( + type="batch_resize", inputs=inputs, outputs={"Out": out}, attrs=attrs) + return out + + class ConvNormActivation(Sequential): """ Configurable block used for Convolution-Normalzation-Activation blocks. diff --git a/python/paddle/vision/reader.py b/python/paddle/vision/reader.py new file mode 100644 index 00000000000000..400ac1bf3214ec --- /dev/null +++ b/python/paddle/vision/reader.py @@ -0,0 +1,235 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
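+
+# NOTE: this module provides a minimal in-process batch sampler
+# (_Sampler / _SamplerManager) together with the file_label_loader and
+# file_label_reader APIs defined below, which read raw JPEG bytes and
+# labels for a batch of ImageNet-style samples.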
+ +import numpy as np +from ..fluid.layer_helper import LayerHelper, unique_name +from ..fluid import core, layers +from ..fluid.layers import nn, utils +from ..fluid.framework import _non_static_mode + +import paddle +from paddle.common_ops_import import * +from paddle import _C_ops + +__all__ = [ #noqa + 'file_label_loader', + 'file_label_reader', +] + + +class _Sampler(object): + def __init__(self, batch_size, num_samples, shuffle=False, drop_last=False): + self.batch_size = batch_size + self.num_samples = num_samples + self.shuffle = shuffle + self.drop_last = drop_last + self.start_idx = 0 + + self.sample_ids = np.arange(num_samples) + if shuffle: + np.random.shuffle(self.sample_ids) + + def __next__(self): + if self.start_idx >= self.num_samples: + self.reset() + return self.__next__() + + batch_len = min(self.batch_size, self.num_samples - self.start_idx) + indices = self.sample_ids[self.start_idx:self.start_idx + batch_len] + self.start_idx += batch_len + + if self.drop_last and len(indices) < self.batch_size: + self.reset() + return self.__next__() + + return indices + + def reset(self): + self.start_idx = 0 + if self.shuffle: + np.random.shuffle(self.sample_ids) + + +class _SamplerManager(object): + def __init__(self): + self.samplers = {} + + def get(self, + sample_id, + batch_size, + num_samples, + shuffle=False, + drop_last=False): + if sample_id in self.samplers: + return self.samplers[sample_id] + + sampler = _Sampler(batch_size, num_samples, shuffle, drop_last) + self.samplers[sample_id] = sampler + return sampler + + +_sampler_manager = _SamplerManager() + + +def file_label_loader(data_root, indices, batch_size, name=None): + """ + Reads a batch of data, outputs the bytes contents of a file + as a uint8 Tensor with one dimension. + + This API can only be used in Paddle GPU version. + + Args: + data_root (str): root directory of ImageNet dataset. + indices (Tensor): A Tensor of batch indices of samples in shape of + [N], while N is the batch size. + batch_size (int): The batch size, same as shape of indices. + name (str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + A list of image Tensor holds byte streams of a batch of images and + A Tensor of label Tensor. + + Examples: + .. 
code-block:: python + + import os + import paddle + from paddle.utils.download import get_path_from_url + + DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") + DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" + DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + BATCH_SIZE = 16 + + data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + indices = paddle.arange(BATCH_SIZE) + images, labels = paddle.vision.reader.file_label_loader( + data_root, indices, BATCH_SIZE) + print(images[0].shape, labels.shape) + + """ + + if _non_static_mode(): + image = [ + core.VarBase(core.VarDesc.VarType.UINT8, [], + unique_name.generate("file_label_loader"), + core.VarDesc.VarType.LOD_TENSOR, False) + for i in range(batch_size) + ] + return _C_ops.file_label_loader(indices, image, 'data_root', data_root) + + inputs = {"Indices": indices} + attrs = {'data_root': data_root, } + + helper = LayerHelper("file_label_loader", **locals()) + image = [ + helper.create_variable( + name=unique_name.generate("file_label_loader"), + type=core.VarDesc.VarType.LOD_TENSOR, + dtype='uint8') for i in range(batch_size) + ] + + label = helper.create_variable( + name=unique_name.generate("file_label_loader"), + type=core.VarDesc.VarType.LOD_TENSOR, + dtype='int') + + helper.append_op( + type="file_label_loader", + inputs=inputs, + attrs=attrs, + outputs={"Image": image, + "Label": label}) + + return image, label + + +def file_label_reader(data_root, + batch_size=1, + shuffle=False, + drop_last=False, + seed=None): + """ + Reads batches of data iterably, outputs the bytes contents of a file + as a uint8 Tensor with one dimension. + + This API will start a C++ thread to load data with + :attr:`file_label_loader`, and yiled data iterably. + + This API can only be used in Paddle GPU version. + + Args: + data_root (str): root directory of ImageNet dataset. + batch_size (int): The batch size of a mini-batch. Default 1. + shuffle (bool): Whether to shuffle samples. Default False. + drop_last (bool): Whether to drop the last incomplete batch. Default False. + seed (int, optional): The seed for sample shuffling. Default None. + name (str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + A list of image Tensor holds byte streams of a batch of images and + A Tensor of label Tensor. + + Examples: + .. 
code-block:: python + + import os + import paddle + from paddle.utils.download import get_path_from_url + + DATASET_HOME = os.path.expanduser("~/.cache/paddle/datasets") + DATASET_URL = "https://paddlemodels.cdn.bcebos.com/ImageNet_stub.tar" + DATASET_MD5 = "c7110519124a433901cf005a4a91b607" + BATCH_SIZE = 16 + + data_root = get_path_from_url(DATASET_URL, DATASET_HOME, + DATASET_MD5) + images, labels = paddle.vision.reader.file_label_reader( + data_root, BATCH_SIZE) + print(images[0].shape, labels.shape) + + """ + + from paddle.vision.datasets import DatasetFolder + data_folder = DatasetFolder(data_root) + samples = [s[0] for s in data_folder.samples] + targets = [s[1] for s in data_folder.samples] + + if _non_static_mode(): + sample_id = utils._hash_with_id(data_root, batch_size, shuffle, + drop_last) + sampler = _sampler_manager.get(sample_id, + batch_size=batch_size, + num_samples=len(samples), + shuffle=shuffle, + drop_last=drop_last) + indices = paddle.to_tensor(next(sampler), dtype='int64') + return file_label_loader(data_root, indices, batch_size) + + def _reader(indices): + return file_label_loader(data_root, indices, batch_size) + + outs = paddle.io.data_reader( + _reader, + batch_size=batch_size, + num_samples=len(samples), + shuffle=shuffle, + drop_last=drop_last, + seed=seed) + return outs[:-1], outs[-1] diff --git a/tools/externalError/spider.py b/tools/externalError/spider.py index e07f05f561cb51..0df857a2d6806a 100644 --- a/tools/externalError/spider.py +++ b/tools/externalError/spider.py @@ -361,6 +361,52 @@ def handle_data(self, data): desc.strip()) CUFFTHTMLParser().feed(html) + #*************************************************************************************************# + + #*********************************** nvJPEG Error Message **************************************# + nvjpegStatus_t = { + "NVJPEG_STATUS_SUCCESS": 0, + "NVJPEG_STATUS_NOT_INITIALIZED": 1, + "NVJPEG_STATUS_INVALID_PARAMETER": 2, + "NVJPEG_STATUS_BAD_JPEG": 3, + "NVJPEG_STATUS_JPEG_NOT_SUPPORTED": 4, + "NVJPEG_STATUS_ALLOCATOR_FAILURE": 5, + "NVJPEG_STATUS_EXECUTION_FAILED": 6, + "NVJPEG_STATUS_ARCH_MISMATCH": 7, + "NVJPEG_STATUS_INTERNAL_ERROR": 8, + "NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED": 9, + "NVJPEG_STATUS_INCOMPLETE_BITSTREAM": 10, + } + + print("start crawling errorMessage for nvidia nvJPEG API--->") + url = 'https://docs.nvidia.com/cuda/nvjpeg/#nvjpeg-api-return-codes' + + allMessageDesc = externalErrorDesc.errors.add() + allMessageDesc.type = external_error_pb2.NVJPEG + + html = urllib.request.urlopen(url).read().decode('utf-8') + + res_div = r'Description of the returned error codes:.*?
(.*?)
' + m_div = re.findall(res_div, html, re.S | re.M)[0] + + res_dt = r'(.*?).*?colspan="1">(.*?)' + m_dt = re.findall(res_dt, m_div, re.S | re.M) + + for error in m_dt: + m_code = error[0] + m_code = m_code.split()[0].strip() + + m_message = error[1] + m_message = re.sub(r'\t', ' ', m_message) + m_message = re.sub(r'\n +', ' ', m_message) + m_message = re.sub(r'<.*?>', '', m_message) + + _Messages = allMessageDesc.messages.add() + _Messages.code = int(nvjpegStatus_t[m_code]) + _Messages.message = "'%s'. %s" % (m_code, m_message) + + print("End crawling errorMessage for nvidia NVJPEG API!\n") + #*************************************************************************************************# def main(argv):