diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ef792cd6f4..8bc13952cc 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1766,7 +1766,7 @@ std::pair broadcast_arrays( array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -1774,7 +1774,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, @@ -1785,7 +1785,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -1796,7 +1796,7 @@ array greater_equal( StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, @@ -1807,7 +1807,7 @@ array greater_equal( array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -1815,7 +1815,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, @@ -2811,7 +2811,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) { array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Broadcast arrays to a common shape auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, @@ -2825,7 +2825,7 @@ array operator&&(const array& a, const array& b) { array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Broadcast arrays to a common shape auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, bool_, @@ -2845,7 +2845,7 @@ array add(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, out_type, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -2858,7 +2858,7 @@ array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, out_type, @@ -2874,7 +2874,7 @@ array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, out_type, @@ -2890,7 +2890,7 @@ array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(promote_types(a.dtype(), b.dtype())); auto inputs = broadcast_arrays( {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -2914,7 +2914,7 @@ array floor_divide( } auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -2923,7 +2923,7 @@ array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays( {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, dtype, @@ -2953,7 +2953,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, out_type, @@ -2965,7 +2965,7 @@ array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, out_type, @@ -3048,7 +3048,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) { array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(promote_types(a.dtype(), b.dtype())); auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); } @@ -3144,7 +3144,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = at_least_float(promote_types(a.dtype(), b.dtype())); auto inputs = broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); - auto& shape = inputs[0].shape(); + auto shape = inputs[0].shape(); return array( shape, out_type, @@ -4187,10 +4187,12 @@ array conv_transpose_general( output_padding[i]; // Adjust with output_padding } + auto ndim = stride.size(); + return conv_general( /* const array& input = */ input, /* const array& weight = */ weight, - /* std::vector stride = */ std::vector(stride.size(), 1), + /* std::vector stride = */ std::vector(ndim, 1), /* std::vector padding_lo = */ std::move(padding_lo), /* std::vector padding_hi = */ std::move(padding_hi), /* std::vector kernel_dilation = */ std::move(dilation),