Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ namespace mlx::core {

static constexpr int MAX_ACTIVE_TASKS = 10;

namespace {

// Create a tracer copy of a primal for use in vjp/jvp. If the primal is a
// stale Copy from a previous transform call (not an active tracer), peel it
// off to prevent copy-chain accumulation when containers feed tracers back.
array make_tracer(const array& p) {
auto s = p.has_primitive() ? p.primitive().stream()
: default_stream(default_device());
auto source = p;
if (!p.is_tracer() && p.has_primitive() && !p.inputs().empty()) {
auto& prim = p.primitive();
if (typeid(prim) == typeid(Copy)) {
source = p.inputs()[0];
}
}
auto out = copy(source, s);
out.set_tracer(true);
return out;
}

} // namespace

/* This class is only meant to be used in eval
* for synchronizing with the main thread. */
class Synchronizer : public Primitive {
Expand Down Expand Up @@ -335,10 +357,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// Make tracers from given primals
std::vector<array> primals_;
for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream()
: default_stream(default_device());
primals_.push_back(copy(p, s)); // Does not do a deep copy
primals_.back().set_tracer(true);
primals_.push_back(make_tracer(p));
}

// Pass tracer primals through the function
Expand Down Expand Up @@ -543,10 +562,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(

std::vector<array> primals_;
for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream()
: default_stream(default_device());
primals_.push_back(copy(p, s)); // Does not do a deep copy
primals_.back().set_tracer(true);
primals_.push_back(make_tracer(p));
}
auto outputs = fun(primals_);

Expand Down
54 changes: 54 additions & 0 deletions tests/autograd_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@

using namespace mlx::core;

namespace {

int count_graph_nodes(const array& x, const std::string& node_name) {
std::ostringstream oss;
print_graph(oss, x);
auto graph = oss.str();

int count = 0;
size_t pos = 0;
while ((pos = graph.find(node_name, pos)) != std::string::npos) {
count++;
pos += node_name.size();
}
return count;
}

} // namespace

TEST_CASE("test stop gradient") {
auto x = zeros({5, 5});
auto y = stop_gradient(x);
Expand Down Expand Up @@ -328,6 +346,42 @@ TEST_CASE("test grad") {
}
}

TEST_CASE("test transform container reuse does not accumulate stale wrappers") {
auto x = ones({128});

SUBCASE("grad reuses a single copy wrapper") {
std::vector<array> container = {array(1.0f)};
auto grad_fn = grad([&container](const std::vector<array>& inputs) {
container[0] = inputs[0];
return sum(inputs[1]);
});

for (int i = 0; i < 5; ++i) {
auto grads = grad_fn({container[0], x});
eval(grads);
}

CHECK_EQ(count_graph_nodes(container[0], "Copy "), 1);
}

SUBCASE("jvp reuses a single copy wrapper") {
std::vector<array> container = {array(1.0f)};
auto fun = [&container](const std::vector<array>& inputs) {
container[0] = inputs[0];
return std::vector<array>{sum(inputs[1])};
};

for (int i = 0; i < 5; ++i) {
auto [outputs, tangents] =
jvp(fun, {container[0], x}, {array(1.0f), ones({128})});
eval(outputs);
eval(tangents);
}

CHECK_EQ(count_graph_nodes(container[0], "Copy "), 1);
}
}

TEST_CASE("test creation grads") {
// Test astype
{
Expand Down
Loading