Skip to content

Commit 875a398

Browse files
edwinsolisfsyurkevi
authored andcommitted
Added gemm accumulation matrix into interface and tests
1 parent a1d8054 commit 875a398

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

arrayfire/library/linear_algebra.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def gemm(
9292
rhs_opts: MatProp = MatProp.NONE,
9393
alpha: int | float = 1.0,
9494
beta: int | float = 0.0,
95+
accum: Array = None
9596
) -> Array:
9697
"""
9798
Performs BLAS general matrix multiplication (GEMM) on two Array instances.
@@ -125,6 +126,10 @@ def gemm(
125126
beta : int | float, optional
126127
Scalar multiplier for the existing matrix C in the accumulation. Default is 0.0.
127128
129+
accum: Array, optional
130+
A 2-dimensional, real or complex array representing the matrix C in the accumulation.
131+
Default is None (no accumulation).
132+
128133
Returns
129134
-------
130135
Array
@@ -135,7 +140,10 @@ def gemm(
135140
- The data types of `lhs` and `rhs` must be compatible.
136141
- Batch operations are not supported in this version.
137142
"""
138-
return cast(Array, wrapper.gemm(lhs.arr, rhs.arr, lhs_opts, rhs_opts, alpha, beta))
143+
accumulator = None
144+
if isinstance(accum, Array):
145+
accumulator = accum.arr
146+
return cast(Array, wrapper.gemm(lhs.arr, rhs.arr, lhs_opts, rhs_opts, alpha, beta, accumulator))
139147

140148

141149
@afarray_as_array

tests/test_library/test_linear_algebra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def test_gemm_basic(matrix_a: af.Array, matrix_b: af.Array) -> None:
6565
def test_gemm_alpha_beta(matrix_a: af.Array, matrix_b: af.Array) -> None:
6666
alpha = 0.5
6767
beta = 2.0
68-
result = af.gemm(matrix_a, matrix_b, alpha=alpha, beta=beta)
69-
expected = create_from_2d_nested(10.5, 12.0, 22.5, 26.0)
68+
result = af.gemm(matrix_a, matrix_b, alpha=alpha, beta=beta, accum=matrix_a)
69+
expected = create_from_2d_nested(11.5, 15.0, 27.5, 33.0)
7070
assert result == expected, f"Expected {expected}, got {result}"
7171

7272

0 commit comments

Comments
 (0)