Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 1 addition & 105 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3580,110 +3580,6 @@ array logcumsumexp(

namespace {

// Conv helpers
// Number of output positions along one spatial axis of a strided
// convolution: floor((in + pad - wt) / stride) + 1.
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  const int span = in_dim + padding - wt_dim;
  return span / stride + 1;
}

// Conv helpers
// Effective extent of a length-`dim` axis after dilation by `dil`
// (i.e. with dil - 1 implicit gaps between consecutive elements).
inline int dilate_size(int dim, int dil) {
  return (dim - 1) * dil + 1;
}

// Compute the full output shape of an N-D convolution.
//
// in_shape is (N, spatial..., C); wt_shape is (O, spatial..., C_per_group).
// strides / pads_lo / pads_hi / kernel_dilation / input_dilation each hold
// one entry per spatial axis.
//
// Throws std::invalid_argument when any per-axis parameter vector has the
// wrong length, contains an invalid value, or when the padded (dilated)
// input is smaller than the dilated kernel along some axis.
Shape conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  // Guard ranks up front: the loop below indexes wt_shape[i] for every
  // spatial axis, so a rank mismatch would otherwise read out of bounds.
  if (in_shape.size() < 2 || wt_shape.size() != in_shape.size()) {
    std::ostringstream msg;
    msg << "[conv] Invalid input shape " << in_shape << " or weight shape "
        << wt_shape << " for convolution.";
    throw std::invalid_argument(msg.str());
  }

  // Unsigned spatial count so comparisons against vector sizes are not
  // mixed signed/unsigned.
  size_t spatial_dims = in_shape.size() - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  Shape out_shape(in_shape.size());
  out_shape[0] = in_shape[0]; // batch dimension passes through
  out_shape[out_shape.size() - 1] = wt_shape[0]; // output channels = O

  for (size_t i = 1; i + 1 < in_shape.size(); i++) {
    if (kernel_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }

    if (strides[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }

    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
    int id = dilate_size(in_shape[i], input_dilation[i - 1]);
    int total_pad = pads_lo[i - 1] + pads_hi[i - 1];

    // Reject a kernel larger than the padded input BEFORE dividing:
    // integer division truncates toward zero, so e.g. a span of -1 with
    // stride 2 would otherwise produce (-1)/2 + 1 == 1, a bogus positive
    // output size that slips past a post-hoc `<= 0` check.
    if (id + total_pad < kd) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }

    out_shape[i] = conv_out_axis_size(id, kd, strides[i - 1], total_pad);
  }

  return out_shape;
}

inline void
run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
if (!issubdtype(in.dtype(), floating)) {
Expand Down Expand Up @@ -3997,7 +3893,7 @@ array conv_general(
}

// Get output shapes
auto out_shape = conv_out_shape(
auto out_shape = Convolution::conv_out_shape(
in.shape(),
wt.shape(),
stride,
Expand Down
120 changes: 120 additions & 0 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,114 @@ array conv_weight_backward_patches(
return grad;
}

namespace {

// Conv helpers: effective axis extent under dilation, and the output
// length of a strided sliding window over a (padded) axis.

inline int dilate_size(int dim, int dil) {
  return (dim - 1) * dil + 1;
}

inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  const int windowed = in_dim + padding - wt_dim;
  return windowed / stride + 1;
}

} // namespace

// Compute the full output shape of an N-D convolution.
//
// in_shape is (N, spatial..., C); wt_shape is (O, spatial..., C_per_group).
// strides / pads_lo / pads_hi / kernel_dilation / input_dilation each hold
// one entry per spatial axis.
//
// Throws std::invalid_argument when any per-axis parameter vector has the
// wrong length, contains an invalid value, or when the padded (dilated)
// input is smaller than the dilated kernel along some axis.
Shape Convolution::conv_out_shape(
    const Shape& in_shape,
    const Shape& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  // Guard ranks up front: the loop below indexes wt_shape[i] for every
  // spatial axis, so a rank mismatch would otherwise read out of bounds.
  if (in_shape.size() < 2 || wt_shape.size() != in_shape.size()) {
    std::ostringstream msg;
    msg << "[conv] Invalid input shape " << in_shape << " or weight shape "
        << wt_shape << " for convolution.";
    throw std::invalid_argument(msg.str());
  }

  // Unsigned spatial count so comparisons against vector sizes are not
  // mixed signed/unsigned.
  size_t spatial_dims = in_shape.size() - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  Shape out_shape(in_shape.size());
  out_shape[0] = in_shape[0]; // batch dimension passes through
  out_shape[out_shape.size() - 1] = wt_shape[0]; // output channels = O

  for (size_t i = 1; i + 1 < in_shape.size(); i++) {
    if (kernel_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative." << " Got padding "
          << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }

    if (strides[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Stride sizes must be positive." << " Got strides "
          << strides << ".";
      throw std::invalid_argument(msg.str());
    }

    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
    int id = dilate_size(in_shape[i], input_dilation[i - 1]);
    int total_pad = pads_lo[i - 1] + pads_hi[i - 1];

    // Reject a kernel larger than the padded input BEFORE dividing:
    // integer division truncates toward zero, so e.g. a span of -1 with
    // stride 2 would otherwise produce (-1)/2 + 1 == 1, a bogus positive
    // output size that slips past a post-hoc `<= 0` check.
    if (id + total_pad < kd) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding"
          << " cannot be smaller than weight spatial dimensions."
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }

    out_shape[i] = conv_out_axis_size(id, kd, strides[i - 1], total_pad);
  }

  return out_shape;
}

std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
Expand Down Expand Up @@ -1454,6 +1562,18 @@ bool Convolution::is_equivalent(const Primitive& other) const {
groups_ == c_other.groups_ && flip_ == c_other.flip_;
}

// Infer the convolution's output shape from the current input and weight
// shapes plus the stride/padding/dilation parameters stored on this
// primitive (used e.g. when re-running an exported graph on new shapes).
std::vector<Shape> Convolution::output_shapes(
    const std::vector<array>& inputs) {
  const auto& in_shape = inputs[0].shape();
  const auto& wt_shape = inputs[1].shape();
  auto out = conv_out_shape(
      in_shape,
      wt_shape,
      kernel_strides_,
      padding_lo_,
      padding_hi_,
      kernel_dilation_,
      input_dilation_);
  return {std::move(out)};
}

std::vector<array> Copy::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
Expand Down
10 changes: 10 additions & 0 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ class Convolution : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_NAME(Convolution)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(
kernel_strides_,
Expand All @@ -761,6 +762,15 @@ class Convolution : public UnaryPrimitive {
flip_);
}

// Compute the output shape of a convolution from the input and weight
// shapes together with the per-spatial-axis strides, low/high padding,
// and kernel/input dilations. Throws std::invalid_argument on invalid
// or mismatched parameters. Static so shape inference (e.g. shapeless
// export) can run without a Convolution instance.
static Shape conv_out_shape(
const Shape& in_shape,
const Shape& wt_shape,
const std::vector<int>& strides,
const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation);

private:
std::vector<int> padding_lo_;
std::vector<int> padding_hi_;
Expand Down
99 changes: 99 additions & 0 deletions python/tests/test_export_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,105 @@ def forward(x):
expected = forward(input_data)
self.assertTrue(mx.allclose(expected, out))

def test_export_conv_shapeless(self):
    """Shapeless-exported conv graphs must infer output shapes on import.

    Covers Conv1d (NLC), Conv2d (NHWC), Conv3d (NDHWC), and grouped
    Conv2d, each round-tripped through export_function/import_function
    with shapeless=True and evaluated on inputs of varying sizes.
    """

    def _check_roundtrip(conv, fname, base_shape, test_shapes):
        # Export `conv` applied to a random input of `base_shape`, then
        # verify the imported function matches the live module for each
        # shape in `test_shapes`. (nn.ConvNd layers are nn.Modules, so
        # they can be called directly without a wrapper module.)
        mx.eval(conv.parameters())

        def fun(x):
            return conv(x)

        path = os.path.join(self.test_dir, fname)
        x = mx.random.normal(shape=base_shape)
        mx.export_function(path, fun, x, shapeless=True)
        imported = mx.import_function(path)
        for shape in test_shapes:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(imported(xt)[0], fun(xt)))

    # Conv1d (NLC)
    _check_roundtrip(
        nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False),
        "conv1d.mlxfn",
        (4, 64, 3),
        [(4, 64, 3), (1, 33, 3), (2, 128, 3)],
    )

    # Conv2d (NHWC)
    _check_roundtrip(
        nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False),
        "conv2d.mlxfn",
        (2, 32, 32, 3),
        [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)],
    )

    # Conv3d (NDHWC)
    _check_roundtrip(
        nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False),
        "conv3d.mlxfn",
        (1, 8, 8, 8, 2),
        [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)],
    )

    # Grouped Conv2d (NHWC)
    _check_roundtrip(
        nn.Conv2d(4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False),
        "conv2d_grouped.mlxfn",
        (2, 32, 32, 4),
        [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)],
    )

def test_export_control_flow(self):

def fun(x, y):
Expand Down
Loading