Skip to content

Commit 2d35e25

Browse files
committed
Added Numpy.meshgrid operation for openvino backend
1 parent 693764a commit 2d35e25

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ NumpyOneInputOpsCorrectnessTest::test_logaddexp
100100
NumpyOneInputOpsCorrectnessTest::test_max
101101
NumpyOneInputOpsCorrectnessTest::test_mean
102102
NumpyOneInputOpsCorrectnessTest::test_median
103-
NumpyOneInputOpsCorrectnessTest::test_meshgrid
104103
NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2
105104
NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2
106105
NumpyOneInputOpsCorrectnessTest::test_pad_float64_constant_2

keras/src/backend/openvino/numpy.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,9 +1097,52 @@ def median(x, axis=None, keepdims=False):
10971097

10981098

10991099
def meshgrid(*x, indexing="xy"):
1100-
raise NotImplementedError(
1101-
"`meshgrid` is not supported with openvino backend"
1102-
)
1100+
if len(x) < 2:
1101+
raise ValueError("meshgrid requires at least 2 input arrays")
1102+
if indexing not in ("xy", "ij"):
1103+
raise ValueError("indexing must be either 'xy' or 'ij'")
1104+
1105+
tensors = [get_ov_output(xi) for xi in x]
1106+
n = len(tensors)
1107+
1108+
shapes = [ov_opset.shape_of(t, Type.i64).output(0) for t in tensors] # each is [Ni]
1109+
one = ov_opset.constant([1], Type.i64).output(0)
1110+
1111+
if indexing == "xy" and n >= 2:
1112+
out_shape = ov_opset.concat([shapes[1], shapes[0]] + shapes[2:], axis=0).output(0)
1113+
else:
1114+
out_shape = ov_opset.concat(shapes, axis=0).output(0)
1115+
1116+
outputs = []
1117+
for i, t in enumerate(tensors):
1118+
parts = []
1119+
for axis in range(n):
1120+
if indexing == "xy" and n >= 2:
1121+
if i == 0:
1122+
if axis == 0:
1123+
parts.append(one)
1124+
elif axis == 1:
1125+
parts.append(shapes[0])
1126+
else:
1127+
parts.append(one if axis != i else shapes[i])
1128+
elif i == 1:
1129+
if axis == 0:
1130+
parts.append(shapes[1])
1131+
elif axis == 1:
1132+
parts.append(one)
1133+
else:
1134+
parts.append(one if axis != i else shapes[i])
1135+
else:
1136+
parts.append(shapes[i] if axis == i else one)
1137+
else:
1138+
parts.append(shapes[i] if axis == i else one)
1139+
1140+
reshape_shape = ov_opset.concat(parts, axis=0).output(0)
1141+
reshaped = ov_opset.reshape(t, reshape_shape, False).output(0)
1142+
broadcasted = ov_opset.broadcast(reshaped, out_shape).output(0)
1143+
outputs.append(OpenVINOKerasTensor(broadcasted))
1144+
1145+
return outputs
11031146

11041147

11051148
def min(x, axis=None, keepdims=False, initial=None):

0 commit comments

Comments
 (0)