Skip to content

Commit ffb1b07

Browse files
committed
introduce rgbPixelBytes
1 parent bfee703 commit ffb1b07

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

tensorflow_lite_support/cc/task/vision/image_transformer.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ StatusOr<FrameBuffer> ImageTransformer::Transform(
154154

155155
StatusOr<std::unique_ptr<FrameBuffer>> ImageTransformer::Postprocess() {
156156
std::unique_ptr<FrameBuffer> postprocessed_frame_buffer;
157+
const int kRgbPixelBytes = 3;
157158
const TfLiteTensor* output_tensor =
158159
TfLiteEngine::GetOutput(GetTfLiteEngine()->interpreter(), 0);
159160

@@ -187,19 +188,19 @@ StatusOr<std::unique_ptr<FrameBuffer>> ImageTransformer::Postprocess() {
187188
const tflite::task::vision::NormalizationOptions& normalization_options =
188189
GetInputSpecs().normalization_options.value();
189190

190-
if (normalization_options.size() == 1) {
191+
if (normalization_options.num_values == 1) {
191192
float mean_value = normalization_options.mean_values[0];
192193
float std_value = normalization_options.std_values[0];
193194

194195
for (size_t i = 0; i < output_byte_size / sizeof(uint8);
195196
++i, ++denormalized_output_data, ++output_data) {
196-
denormalized_output_data = static_cast<uint8>(std::round(std::min(
197+
*denormalized_output_data = static_cast<uint8>(std::round(std::min(
197198
255.f, std::max(0.f, (*output_data) * std_value + mean_value))));
198199
}
199200
} else {
200201
for (size_t i = 0; i < output_byte_size / sizeof(uint8);
201202
++i, ++denormalized_output_data, ++output_data) {
202-
denormalized_output_data = static_cast<uint8>(std::round(std::min(
203+
*denormalized_output_data = static_cast<uint8>(std::round(std::min(
203204
255.f,
204205
std::max(0.f,
205206
(*output_data) * normalization_options.std_values[i % 3] +
@@ -210,7 +211,7 @@ StatusOr<std::unique_ptr<FrameBuffer>> ImageTransformer::Postprocess() {
210211

211212
FrameBuffer::Plane postprocessed_plane = {
212213
/*buffer=*/postprocessed_data.data(),
213-
/*stride=*/{output_tensor->dims->data[1] * kRgbPixelBytes,
214+
/*stride=*/{output_tensor->dims->data[2] * kRgbPixelBytes,
214215
kRgbPixelBytes}};
215216
postprocessed_frame_buffer = FrameBuffer::Create(
216217
{postprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB,

tensorflow_lite_support/cc/task/vision/image_transformer.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ limitations under the License.
2424
#include "tensorflow_lite_support/cc/port/statusor.h"
2525
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
2626
#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
27-
#include "tensorflow_lite_support/cc/task/vision/proto/image_transformer_options.proto"
27+
#include "tensorflow_lite_support/cc/task/vision/proto/image_transformer_options_proto_inc.h"
2828

2929
namespace tflite {
3030
namespace task {
@@ -61,7 +61,7 @@ namespace vision {
6161
// A CLI demo tool is available for easily trying out this API, and provides
6262
// example usage. See:
6363
// examples/task/vision/desktop/image_classifier_demo.cc
64-
class ImageTransformer : public BaseVisionTaskApi<tflite::task::vision::FrameBuffer> {
64+
class ImageTransformer : public BaseVisionTaskApi<::tflite::task::vision::FrameBuffer> {
6565
public:
6666
using BaseVisionTaskApi::BaseVisionTaskApi;
6767

@@ -85,7 +85,7 @@ class ImageTransformer : public BaseVisionTaskApi<tflite::task::vision::FrameBuf
8585
// only supported colorspace for now),
8686
// - rotate it according to its `Orientation` so that inference is performed
8787
// on an "upright" image.
88-
tflite::support::StatusOr<tflite::task::vision::FrameBuffer> Transform(
88+
tflite::support::StatusOr<::tflite::task::vision::FrameBuffer> Transform(
8989
const FrameBuffer& frame_buffer);
9090

9191
// Same as above, except that the transformation is performed based on the
@@ -99,15 +99,15 @@ class ImageTransformer : public BaseVisionTaskApi<tflite::task::vision::FrameBuf
9999
// `frame_buffer` data before any `Orientation` flag gets applied. Also, the
100100
// region of interest is not clamped, so this method will return a non-ok
101101
// status if the region is out of these bounds.
102-
tflite::support::StatusOr<tflite::task::vision::FrameBuffer> Transform(
102+
tflite::support::StatusOr<::tflite::task::vision::FrameBuffer> Transform(
103103
const FrameBuffer& frame_buffer, const BoundingBox& roi);
104104

105105
protected:
106106
// The options used to build this ImageTransformer.
107107
std::unique_ptr<ImageTransformerOptions> options_;
108108

109109
// Post-processing to transform the raw model outputs into image results.
110-
tflite::support::StatusOr<std::unique_ptr<tflite::task::vision::FrameBuffer>> Postprocess();
110+
tflite::support::StatusOr<std::unique_ptr<::tflite::task::vision::FrameBuffer>> Postprocess();
111111

112112
// Performs sanity checks on the provided ImageTransformerOptions.
113113
static absl::Status SanityCheckOptions(const ImageTransformerOptions& options);

0 commit comments

Comments
 (0)