Skip to content

Commit 78cc484

Browse files
authored
[FIX] Fix tensor concatenation to handle mixed numpy/torch arrays (#37)
### TL;DR Fix tensor concatenation to handle mixed numpy arrays ### What changed? Modified the `_concatenate` function in `backend_tensor.py` to properly handle cases where some tensors in the input list are numpy arrays. The previous implementation only checked the type of the first tensor, which could lead to errors when concatenating mixed tensor types. Also added an explicit error for unsupported tensor types. ### How to test? Test concatenating a list of tensors where some elements are numpy arrays and others are not. Verify that the function correctly identifies and handles numpy arrays regardless of their position in the list. ### Why make this change? The previous implementation had a bug where it only checked the type of the first tensor in the list, which would fail if the first tensor was not a numpy array but other tensors in the list were. This change makes the function more robust by checking if any tensor in the list is a numpy array, ensuring proper handling of mixed tensor types.
2 parents c4d66ac + 72ede38 commit 78cc484

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

gempy_engine/core/backend_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,11 @@ def _array(array_like, dtype=None):
197197
def _concatenate(tensors, axis=0, dtype=None):
198198
# Switch if tensor is numpy array or a torch tensor
199199
match type(tensors[0]):
200-
case numpy.ndarray:
200+
case _ if any(isinstance(t, numpy.ndarray) for t in tensors):
201201
return numpy.concatenate(tensors, axis=axis)
202202
case torch.Tensor:
203203
return torch.cat(tensors, dim=axis)
204+
raise TypeError("Unsupported tensor type")
204205

205206
def _transpose(tensor, axes=None):
206207
return tensor.transpose(axes[0], axes[1])

0 commit comments

Comments
 (0)