Skip to content

Commit 420223d

Browse files
committed
Support jax2tf in JaxLayer for tf backend
1 parent 19ca9c1 commit 420223d

File tree

3 files changed

+234
-45
lines changed

3 files changed

+234
-45
lines changed

keras/src/layers/layer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,10 @@ def compute_output_spec(self, *args, **kwargs):
11451145
call_spec=call_spec,
11461146
class_name=self.__class__.__name__,
11471147
)
1148-
output_shape = self.compute_output_shape(**shapes_dict)
1148+
try:
1149+
output_shape = self.compute_output_shape(**shapes_dict)
1150+
except NotImplementedError as e:
1151+
return super().compute_output_spec(*args, **kwargs)
11491152

11501153
if (
11511154
isinstance(output_shape, list)

0 commit comments

Comments
 (0)