diff --git a/.gitignore b/.gitignore index afd700b49952..7bce46435932 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ examples/**/*.jpg .python-version .coverage *coverage.xml -.ruff_cache \ No newline at end of file +.ruff_cache +keras_venv/ \ No newline at end of file diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 6f8bcb18955c..5227a7d32533 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -36,7 +36,6 @@ NumpyDtypeTest::test_matmul_ NumpyDtypeTest::test_max NumpyDtypeTest::test_mean NumpyDtypeTest::test_median -NumpyDtypeTest::test_meshgrid NumpyDtypeTest::test_minimum_python_types NumpyDtypeTest::test_multiply NumpyDtypeTest::test_outer_ @@ -94,7 +93,7 @@ NumpyOneInputOpsCorrectnessTest::test_logaddexp NumpyOneInputOpsCorrectnessTest::test_max NumpyOneInputOpsCorrectnessTest::test_mean NumpyOneInputOpsCorrectnessTest::test_median -NumpyOneInputOpsCorrectnessTest::test_meshgrid +NumpyOneInputOpsCorrectnessTest::test_moveaxis NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2 NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2 NumpyOneInputOpsCorrectnessTest::test_pad_float64_constant_2 diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 90a27d8c0833..08da1493254b 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1046,9 +1046,63 @@ def median(x, axis=None, keepdims=False): def meshgrid(*x, indexing="xy"): - raise NotImplementedError( - "`meshgrid` is not supported with openvino backend" - ) + if not x: + return [] + + if indexing not in ["xy", "ij"]: + raise ValueError( + f"indexing parameter must be either 'xy' or 'ij', got {indexing}" + ) + + arrays = [] + for array in x: + arrays.append(get_ov_output(array)) + + ndim = len(arrays) + s0 = arrays[0].get_element_type() + + output = [] + + for i, xi in enumerate(arrays): + shape = [1] * ndim + + if indexing == "xy" and ndim > 1: + if i == 0: + shape[1] = -1 + elif i == 1: + shape[0] = -1 + else: + shape[i] = -1 + else: + shape[i] = -1 + + reshape_node = ov_opset.reshape( + xi, ov_opset.constant(shape, Type.i32), False + ).output(0) + + output_shape = [1] * ndim + for j, xj in enumerate(arrays): + xj_shape = xj.get_partial_shape().to_shape() + if xj_shape and xj_shape[0] > 0: + output_shape[ + j + if indexing == "ij" + else ( + 1 + if j == 0 and ndim > 1 + else 0 + if j == 1 and ndim > 1 + else j + ) + ] = xj_shape[0] + + broadcast_node = ov_opset.broadcast( + reshape_node, ov_opset.constant(output_shape, Type.i32) + ).output(0) + + output.append(OpenVINOKerasTensor(broadcast_node)) + + return output def min(x, axis=None, keepdims=False, initial=None): diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 865c55a3ceeb..2dbcca6e5026 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -59,14 +59,8 @@ def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs): def get_random_transformation(self, images, training=True, seed=None): if seed is None: seed = self._get_seed_generator(self.backend._backend) - # Base case: Unbatched data - batch_size = 1 - if len(images.shape) == 4: - # This is a batch of images (4D input) - batch_size = self.backend.core.shape(images)[0] - random_values = self.backend.random.uniform( - shape=(batch_size,), + shape=(self.backend.core.shape(images)[0],), minval=0, maxval=1, seed=seed, diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py index a43dfc55694a..b488c2c31f83 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py @@ -78,37 +78,17 @@ def test_tf_data_compatibility(self): def test_grayscale_with_single_color_image(self): test_cases = [ - # batched inputs (np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"), (np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"), - # unbatched inputs - (np.full((4, 4, 3), 128, dtype=np.float32), "channels_last"), - (np.full((3, 4, 4), 128, dtype=np.float32), "channels_first"), ] for xs, data_format in test_cases: layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) transformed = ops.convert_to_numpy(layer(xs)) - # Determine if the input was batched - is_batched = len(xs.shape) == 4 - - # If batched, select the first image from the batch for inspection. - # Otherwise, use the transformed image directly. - # `image_to_inspect` will always be a 3D tensor. - if is_batched: - image_to_inspect = transformed[0] - else: - image_to_inspect = transformed - if data_format == "channels_last": - # image_to_inspect has shape (H, W, C), - # get the first channel [:, :, 0] - channel_data = image_to_inspect[:, :, 0] - else: # data_format == "channels_first" - # image_to_inspect has shape (C, H, W), - # get the first channel [0, :, :] - channel_data = image_to_inspect[0, :, :] - - unique_vals = np.unique(channel_data) - self.assertEqual(len(unique_vals), 1) + unique_vals = np.unique(transformed[0, :, :, 0]) + self.assertEqual(len(unique_vals), 1) + else: + unique_vals = np.unique(transformed[0, 0, :, :]) + self.assertEqual(len(unique_vals), 1) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000000..83635a5b7b9b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +env = + KERAS_BACKEND=openvino \ No newline at end of file