Skip to content

Commit 856b1f2

Browse files
committed
Q3
1 parent 430c94e commit 856b1f2

File tree

3 files changed

+6
-13
lines changed

3 files changed

+6
-13
lines changed

questions/3_reshape-matrix/tinygrad/solution.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
def reshape_matrix_tg(a, new_shape) -> Tensor:
44
"""
55
Reshape a 2D matrix `a` to shape `new_shape` using tinygrad.
6-
Inputs can be Python lists, NumPy arrays, or tinygrad Tensors.
6+
Inputs are tinygrad Tensors.
77
Returns a Tensor of shape `new_shape`, or an empty Tensor on mismatch.
88
"""
99
# Dimension check
1010
if len(a) * len(a[0]) != new_shape[0] * new_shape[1]:
1111
return Tensor([])
12-
a_t = Tensor(a)
13-
return a_t.reshape(new_shape)
12+
return a.reshape(new_shape)
Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
from tinygrad.tensor import Tensor
22

3-
def reshape_matrix_tg(a, new_shape) -> Tensor:
3+
def reshape_matrix_tg(a:Tensor, new_shape:tuple) -> Tensor:
44
"""
55
Reshape a 2D matrix `a` to shape `new_shape` using tinygrad.
6-
Inputs can be Python lists, NumPy arrays, or tinygrad Tensors.
6+
Inputs are tinygrad Tensors.
77
Returns a Tensor of shape `new_shape`, or an empty Tensor on mismatch.
88
"""
9-
# Dimension check
10-
if len(a) * len(a[0]) != new_shape[0] * new_shape[1]:
11-
return Tensor([])
12-
# Convert to Tensor and reshape
13-
a_t = Tensor(a)
14-
# Your implementation here
159
pass
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
[
22
{
3-
"test": "from tinygrad.tensor import Tensor\nres = reshape_matrix_tg(\n [[1,2,3],[4,5,6]],\n (3, 2)\n)\nprint(res.numpy().tolist())",
3+
"test": "from tinygrad.tensor import Tensor\nres = reshape_matrix_tg(\n Tensor([[1,2,3],[4,5,6]]),\n (3, 2)\n)\nprint(res.numpy().tolist())",
44
"expected_output": "[[1, 2], [3, 4], [5, 6]]"
55
},
66
{
7-
"test": "from tinygrad.tensor import Tensor\nres = reshape_matrix_tg(\n [[1,2],[3,4]],\n (3, 2)\n)\nprint(res.numpy().tolist())",
7+
"test": "from tinygrad.tensor import Tensor\nres = reshape_matrix_tg(\n Tensor([[1,2],[3,4]]),\n (3, 2)\n)\nprint(res.numpy().tolist())",
88
"expected_output": "[]"
99
}
1010
]

0 commit comments

Comments
 (0)