diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 4967c50a8b..c7be012083 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -24,6 +24,31 @@ namespace mlx::core {
 
 static constexpr int MAX_ACTIVE_TASKS = 10;
 
+namespace {
+
+// Create a tracer copy of a primal for use in vjp/jvp. If the primal is a
+// stale Copy from a previous transform call (not an active tracer), peel it
+// off to prevent copy-chain accumulation when containers feed tracers back.
+// Only a single Copy layer is peeled per call. The new copy is issued with
+// the stream of the primal's own primitive when it has one, and with the
+// default stream of the default device otherwise.
+array make_tracer(const array& p) {
+  auto s = p.has_primitive() ? p.primitive().stream()
+                             : default_stream(default_device());
+  auto source = p;
+  if (!p.is_tracer() && p.has_primitive() && !p.inputs().empty()) {
+    auto& prim = p.primitive();
+    if (typeid(prim) == typeid(Copy)) {
+      source = p.inputs()[0];
+    }
+  }
+  auto out = copy(source, s);
+  out.set_tracer(true);
+  return out;
+}
+
+} // namespace
+
 /* This class is only meant to be used in eval
  * for synchronizing with the main thread. */
 class Synchronizer : public Primitive {
@@ -335,10 +360,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // Make tracers from given primals
   std::vector<array> primals_;
   for (auto& p : primals) {
-    auto s = p.has_primitive() ? p.primitive().stream()
-                               : default_stream(default_device());
-    primals_.push_back(copy(p, s)); // Does not do a deep copy
-    primals_.back().set_tracer(true);
+    primals_.push_back(make_tracer(p));
   }
 
   // Pass tracer primals through the function
@@ -543,10 +565,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
 
   std::vector<array> primals_;
   for (auto& p : primals) {
-    auto s = p.has_primitive() ? p.primitive().stream()
-                               : default_stream(default_device());
-    primals_.push_back(copy(p, s)); // Does not do a deep copy
-    primals_.back().set_tracer(true);
+    primals_.push_back(make_tracer(p));
   }
 
   auto outputs = fun(primals_);
diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp
index 25c871cdf9..40d84a39dc 100644
--- a/tests/autograd_tests.cpp
+++ b/tests/autograd_tests.cpp
@@ -17,6 +17,32 @@
 
 using namespace mlx::core;
 
+namespace {
+
+// Count the non-overlapping occurrences of `node_name` in the text that
+// print_graph emits for `x`. An empty `node_name` yields 0; without that
+// guard std::string::find would keep matching at the same position and the
+// scan would never advance.
+int count_graph_nodes(const array& x, const std::string& node_name) {
+  if (node_name.empty()) {
+    return 0;
+  }
+
+  std::ostringstream oss;
+  print_graph(oss, x);
+  auto graph = oss.str();
+
+  int count = 0;
+  size_t pos = 0;
+  while ((pos = graph.find(node_name, pos)) != std::string::npos) {
+    count++;
+    pos += node_name.size();
+  }
+  return count;
+}
+
+} // namespace
+
 TEST_CASE("test stop gradient") {
   auto x = zeros({5, 5});
   auto y = stop_gradient(x);
@@ -328,6 +354,42 @@ TEST_CASE("test grad") {
   }
 }
 
+TEST_CASE("test transform container reuse does not accumulate stale wrappers") {
+  auto x = ones({128});
+
+  SUBCASE("grad reuses a single copy wrapper") {
+    std::vector<array> container = {array(1.0f)};
+    auto grad_fn = grad([&container](const std::vector<array>& inputs) {
+      container[0] = inputs[0];
+      return sum(inputs[1]);
+    });
+
+    for (int i = 0; i < 5; ++i) {
+      auto grads = grad_fn({container[0], x});
+      eval(grads);
+    }
+
+    CHECK_EQ(count_graph_nodes(container[0], "Copy "), 1);
+  }
+
+  SUBCASE("jvp reuses a single copy wrapper") {
+    std::vector<array> container = {array(1.0f)};
+    auto fun = [&container](const std::vector<array>& inputs) {
+      container[0] = inputs[0];
+      return std::vector<array>{sum(inputs[1])};
+    };
+
+    for (int i = 0; i < 5; ++i) {
+      auto [outputs, tangents] =
+          jvp(fun, {container[0], x}, {array(1.0f), ones({128})});
+      eval(outputs);
+      eval(tangents);
+    }
+
+    CHECK_EQ(count_graph_nodes(container[0], "Copy "), 1);
+  }
+}
+
 TEST_CASE("test creation grads") {
   // Test astype
   {