Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1766,15 +1766,15 @@ std::pair<array, array> broadcast_arrays(
// Element-wise equality comparison of `a` and `b`, broadcast to a common
// shape; returns a bool array on stream `s`.
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Promote both operands to a common dtype before comparing.
  auto dtype = promote_types(a.dtype(), b.dtype());
  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
  // Copy the shape by value: `inputs` is moved into the array constructor
  // below, and function-argument evaluation order is unspecified, so a
  // reference into inputs[0] could dangle.
  auto shape = inputs[0].shape();
  return array(
      shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
}

array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
bool_,
Expand All @@ -1785,7 +1785,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Element-wise `a > b`, broadcast to a common shape; returns a bool array
// on stream `s`.
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Promote both operands to a common dtype before comparing.
  auto dtype = promote_types(a.dtype(), b.dtype());
  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
  // Copy the shape by value: `inputs` is moved into the array constructor
  // below, and function-argument evaluation order is unspecified, so a
  // reference into inputs[0] could dangle.
  auto shape = inputs[0].shape();
  return array(
      shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
}
Expand All @@ -1796,7 +1796,7 @@ array greater_equal(
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
bool_,
Expand All @@ -1807,15 +1807,15 @@ array greater_equal(
// Element-wise `a < b`, broadcast to a common shape; returns a bool array
// on stream `s`.
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Promote both operands to a common dtype before comparing.
  auto dtype = promote_types(a.dtype(), b.dtype());
  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
  // Copy the shape by value: `inputs` is moved into the array constructor
  // below, and function-argument evaluation order is unspecified, so a
  // reference into inputs[0] could dangle.
  auto shape = inputs[0].shape();
  return array(
      shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
}

array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
bool_,
Expand Down Expand Up @@ -2811,7 +2811,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Broadcast arrays to a common shape
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
bool_,
Expand All @@ -2825,7 +2825,7 @@ array operator&&(const array& a, const array& b) {
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Broadcast arrays to a common shape
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
bool_,
Expand All @@ -2845,7 +2845,7 @@ array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
}
Expand All @@ -2858,7 +2858,7 @@ array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
out_type,
Expand All @@ -2874,7 +2874,7 @@ array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
out_type,
Expand All @@ -2890,7 +2890,7 @@ array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
}
Expand All @@ -2914,7 +2914,7 @@ array floor_divide(
}

auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
}
Expand All @@ -2923,7 +2923,7 @@ array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
dtype,
Expand Down Expand Up @@ -2953,7 +2953,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
out_type,
Expand All @@ -2965,7 +2965,7 @@ array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
out_type,
Expand Down Expand Up @@ -3048,7 +3048,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
// Element-wise arctan2(a, b) (two-argument arctangent), broadcast to a
// common shape; result is at least floating point, on stream `s`.
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // arctan2 is a real-valued transcendental, so force a float dtype.
  auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
  // Copy the shape by value: `inputs` is moved into the array constructor
  // below, and function-argument evaluation order is unspecified, so a
  // reference into inputs[0] could dangle.
  auto shape = inputs[0].shape();
  return array(
      shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
}
Expand Down Expand Up @@ -3144,7 +3144,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
auto shape = inputs[0].shape();
return array(
shape,
out_type,
Expand Down Expand Up @@ -4187,10 +4187,12 @@ array conv_transpose_general(
output_padding[i]; // Adjust with output_padding
}

auto ndim = stride.size();

return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ std::vector(stride.size(), 1),
/* std::vector<int> stride = */ std::vector(ndim, 1),
/* std::vector<int> padding_lo = */ std::move(padding_lo),
/* std::vector<int> padding_hi = */ std::move(padding_hi),
/* std::vector<int> kernel_dilation = */ std::move(dilation),
Expand Down
Loading