Skip to content

Commit f36d04f

Browse files
jharmsenseanpmorgan
authored andcommitted
rename ImageProjectiveTransform (#90)
1 parent 3814ae2 commit f36d04f

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ using generator::INTERPOLATION_NEAREST;
5050
using generator::ProjectiveGenerator;
5151

5252
template <typename Device, typename T>
53-
class ImageProjectiveTransform : public OpKernel {
53+
class ImageProjectiveTransformV2 : public OpKernel {
5454
private:
5555
Interpolation interpolation_;
5656

5757
public:
58-
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
58+
explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx) : OpKernel(ctx) {
5959
string interpolation_str;
6060
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
6161
if (interpolation_str == "NEAREST") {
@@ -118,10 +118,10 @@ class ImageProjectiveTransform : public OpKernel {
118118
};
119119

120120
#define REGISTER(TYPE) \
121-
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
121+
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
122122
.Device(DEVICE_CPU) \
123123
.TypeConstraint<TYPE>("dtype"), \
124-
ImageProjectiveTransform<CPUDevice, TYPE>)
124+
ImageProjectiveTransformV2<CPUDevice, TYPE>)
125125

126126
TF_CALL_uint8(REGISTER);
127127
TF_CALL_int32(REGISTER);
@@ -157,11 +157,11 @@ TF_CALL_double(DECLARE_FUNCTOR);
157157
} // end namespace functor
158158

159159
#define REGISTER(TYPE) \
160-
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
160+
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
161161
.Device(DEVICE_GPU) \
162162
.TypeConstraint<TYPE>("dtype") \
163163
.HostMemory("output_shape"), \
164-
ImageProjectiveTransform<GPUDevice, TYPE>)
164+
ImageProjectiveTransformV2<GPUDevice, TYPE>)
165165

166166
TF_CALL_uint8(REGISTER);
167167
TF_CALL_int32(REGISTER);

tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ the `transforms` to the `images`. Satisfies the description above.
9292
} // namespace
9393

9494
// V2 op supports output_shape.
95-
REGISTER_OP("ImageProjectiveTransform")
95+
REGISTER_OP("ImageProjectiveTransformV2")
9696
.Input("images: dtype")
9797
.Input("transforms: float32")
9898
.Input("output_shape: int32")

tensorflow_addons/custom_ops/image/python/transform.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
tf.dtypes.float32, tf.dtypes.float64
3131
])
3232

33-
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
33+
ops.RegisterShape("ImageProjectiveTransformV2")(common_shapes.call_cpp_shape_fn)
3434

3535

3636
@tf.function
@@ -108,7 +108,7 @@ def transform(images,
108108
else:
109109
raise TypeError("Transforms should have rank 1 or 2.")
110110

111-
output = _image_ops_so.image_projective_transform(
111+
output = _image_ops_so.image_projective_transform_v2(
112112
images,
113113
output_shape=output_shape,
114114
transforms=transforms,
@@ -260,7 +260,7 @@ def angles_to_projective_transforms(angles,
260260
axis=1)
261261

262262

263-
@ops.RegisterGradient("ImageProjectiveTransform")
263+
@ops.RegisterGradient("ImageProjectiveTransformV2")
264264
def _image_projective_transform_grad(op, grad):
265265
"""Computes the gradient for ImageProjectiveTransform."""
266266
images = op.inputs[0]
@@ -284,7 +284,7 @@ def _image_projective_transform_grad(op, grad):
284284
transforms = flat_transforms_to_matrices(transforms=transforms)
285285
inverse = tf.linalg.inv(transforms)
286286
transforms = matrices_to_flat_transforms(inverse)
287-
output = _image_ops_so.image_projective_transform(
287+
output = _image_ops_so.image_projective_transform_v2(
288288
images=grad,
289289
transforms=transforms,
290290
output_shape=tf.shape(image_or_images)[1:3],

0 commit comments

Comments
 (0)