diff --git a/include/command_graph_generator.h b/include/command_graph_generator.h
index d5afa3f37..f17d31092 100644
--- a/include/command_graph_generator.h
+++ b/include/command_graph_generator.h
@@ -93,6 +93,8 @@ class command_graph_generator {
 		std::optional<reduction_info> pending_reduction;
 		std::string debug_name;
+
+		const task* generator_task = nullptr;
 	};
 
 	struct host_object_state {
@@ -111,6 +113,7 @@ class command_graph_generator {
 	};
 
   public:
+	bool is_generator_kernel(const task& tsk) const;
 	struct policy_set {
 		error_policy uninitialized_read_error = error_policy::panic;
 		error_policy overlapping_write_error = error_policy::panic;
diff --git a/include/range_mapper.h b/include/range_mapper.h
index 1faa301cc..8b3959ff0 100644
--- a/include/range_mapper.h
+++ b/include/range_mapper.h
@@ -13,6 +13,12 @@
 namespace celerity {
+
+// Forward-declaration so we can detect whether the functor is one_to_one
+namespace access {
+	struct one_to_one;
+}
+
 namespace detail {
 
@@ -85,6 +91,8 @@ namespace detail {
 		virtual region<3> map_3(const chunk<3>& chnk) const = 0;
 
 		virtual ~range_mapper_base() = default;
+
+		virtual bool is_one_to_one() const { return false; }
 	};
 
 	template <typename Functor, int BufferDims>
@@ -107,6 +115,13 @@ namespace detail {
 		region<3> map_3(const chunk<2>& chnk) const override { return map<3>(chnk); }
 		region<3> map_3(const chunk<3>& chnk) const override { return map<3>(chnk); }
 
+		// Override is_one_to_one() to detect whether the functor is specifically celerity::access::one_to_one:
+		bool is_one_to_one() const override {
+			// If the Functor is celerity::access::one_to_one, return true
+			if constexpr(std::is_same_v<Functor, celerity::access::one_to_one>) { return true; }
+			return false;
+		}
+
 	  private:
 		Functor m_rmfn;
 		range<BufferDims> m_buffer_size;
diff --git a/include/task.h b/include/task.h
index ef84a6cbd..046560ddd 100644
--- a/include/task.h
+++ b/include/task.h
@@ -64,6 +64,15 @@ namespace detail {
 	/// Returns a set of bounding boxes, one for each accessed region, that must be allocated contiguously.
 	box_vector<3> compute_required_contiguous_boxes(const buffer_id bid, const box<3>& execution_range) const;
+
+	// Retrieves the range mapper associated with a specific buffer ID, or nullptr if not found.
+	const range_mapper_base* get_range_mapper(buffer_id search_bid) const {
+		for(const auto& ba : m_accesses) {
+			if(ba.bid == search_bid) { return ba.range_mapper.get(); }
+		}
+		return nullptr; // Not found
+	}
+
   private:
 	std::vector<buffer_access> m_accesses;
 	std::unordered_set<buffer_id> m_accessed_buffers; ///< Cached set of buffer ids found in m_accesses
diff --git a/src/command_graph_generator.cc b/src/command_graph_generator.cc
index c69618421..99257f33b 100644
--- a/src/command_graph_generator.cc
+++ b/src/command_graph_generator.cc
@@ -73,6 +73,31 @@ bool is_topologically_sorted(Iterator begin, Iterator end) {
 	return true;
 }
 
+bool command_graph_generator::is_generator_kernel(const task& tsk) const {
+	if(tsk.get_type() != task_type::device_compute) return false;
+
+	// Must not have a hint that modifies splitting:
+	if(tsk.get_hint<experimental::hints::split_1d>() != nullptr || tsk.get_hint<experimental::hints::split_2d>() != nullptr) { return false; }
+
+	// Must have exactly one buffer access
+	const auto& bam = tsk.get_buffer_access_map();
+	if(bam.get_num_accesses() != 1) return false;
+
+	// That single access must be discard_write
+	const auto [bid, mode] = bam.get_nth_access(0);
+	if(mode != access_mode::discard_write) return false;
+
+	// Must produce exactly the entire buffer
+	const auto full_box = box<3>(subrange<3>({}, tsk.get_global_size()));
+	if(bam.get_task_produced_region(bid) != full_box) return false;
+
+	// Confirm the *range mapper* is truly one_to_one:
+	const auto rm = bam.get_range_mapper(bid);
+	if(rm == nullptr || !rm->is_one_to_one()) { return false; }
+
+	return true;
+}
+
 std::vector<const command*> command_graph_generator::build_task(const task& tsk) {
 	const auto epoch_to_prune_before = m_epoch_for_new_commands;
 	batch current_batch;
@@ -455,6 +480,29 @@ void command_graph_generator::update_local_buffer_fresh_regions(const task& tsk,
 }
 
 void command_graph_generator::generate_distributed_commands(batch& current_batch, const task& tsk) {
+	// If this is a generator kernel, we generate its commands immediately and skip the regular requirement analysis and data-transfer generation below.
+	if(is_generator_kernel(tsk)) {
+		// Identify which buffer is discard-written
+		const auto [gen_bid, _] = tsk.get_buffer_access_map().get_nth_access(0);
+		auto& bstate = m_buffers.at(gen_bid);
+		const auto chunks = split_task_and_assign_chunks(tsk);
+
+		// Create a command for each chunk that belongs to our local node.
+		for(const auto& a_chunk : chunks) {
+			if(a_chunk.executed_on != m_local_nid) continue;
+			auto* cmd = create_command<execution_command>(current_batch, &tsk, subrange<3>{a_chunk.chnk}, false,
+			    [&](const auto& record_debug_info) { record_debug_info(tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }); });
+
+			// Mark that subrange as freshly written, so there is no "uninitialized read" later.
+			box<3> write_box(a_chunk.chnk.offset, a_chunk.chnk.offset + a_chunk.chnk.range);
+			region<3> written_region{write_box};
+			bstate.local_last_writer.update_region(written_region, cmd);
+			bstate.initialized_region = region_union(bstate.initialized_region, written_region);
+		}
+		// Return here so we skip the normal command generation logic below.
+		return;
+	}
+
 	const auto chunks = split_task_and_assign_chunks(tsk);
 	const auto chunks_with_requirements = compute_per_chunk_requirements(tsk, chunks);
 
@@ -521,8 +569,6 @@ void command_graph_generator::generate_distributed_commands(batch& current_batch
 		if(!produced.empty()) {
 			generate_anti_dependencies(tsk, bid, buffer.local_last_writer, produced, cmd);
-
-			// Update last writer
 			buffer.local_last_writer.update_region(produced, cmd);
 			buffer.replicated_regions.update_region(produced, node_bitset{});
 
diff --git a/test/command_graph_general_tests.cc b/test/command_graph_general_tests.cc
index b51ac7a77..d51491119 100644
--- a/test/command_graph_general_tests.cc
+++ b/test/command_graph_general_tests.cc
@@ -437,3 +437,62 @@ TEST_CASE("command_graph_generator throws in tests if it detects overlapping wri
 	        "range mapper for this write access or constrain the split via experimental::constrain_split to make the access non-overlapping.");
 	}
 }
+
+TEST_CASE("results from generator kernels are never communicated between nodes", "[command_graph_generator][owner-computes]") {
+	const bool split_2d = GENERATE(values({0, 1}));
+	CAPTURE(split_2d);
+
+	const size_t num_nodes = 4;
+	cdag_test_context cctx(num_nodes); // 4 nodes, so we can get a true 2D work assignment for the timestep kernels
+	auto buf = cctx.create_buffer<2>({256, 256}); // a 256x256 buffer
+
+	const auto tid_init = cctx.device_compute(buf.get_range()) //
+	                          .discard_write(buf, celerity::access::one_to_one())
+	                          .name("init")
+	                          .submit();
+	const auto tid_ts0 = cctx.device_compute(buf.get_range()) //
+	                         .hint_if(split_2d, experimental::hints::split_2d())
+	                         .read_write(buf, celerity::access::one_to_one())
+	                         .name("timestep 0")
+	                         .submit();
+	const auto tid_ts1 = cctx.device_compute(buf.get_range()) //
+	                         .hint_if(split_2d, experimental::hints::split_2d())
+	                         .read_write(buf, celerity::access::one_to_one())
+	                         .name("timestep 1")
+	                         .submit();
+
+	CHECK(cctx.query<execution_command_record>().count_per_node() == 3); // one for each task above
+	CHECK(cctx.query<push_command_record>().total_count() == 0);
+	CHECK(cctx.query<await_push_command_record>().total_count() == 0);
+
+	const auto inits = cctx.query(tid_init);
+	const auto ts0s = cctx.query(tid_ts0);
+	const auto ts1s = cctx.query(tid_ts1);
+	CHECK(inits.count_per_node() == 1);
+	CHECK(ts0s.count_per_node() == 1);
+	CHECK(ts1s.count_per_node() == 1);
+
+	for(node_id nid = 0; nid < num_nodes; ++nid) {
+		const auto n_init = inits.on(nid);
+		REQUIRE(n_init->accesses.size() == 1);
+
+		const auto generate = n_init->accesses.front();
+		CHECK(generate.bid == buf.get_id());
+		CHECK(generate.mode == access_mode::discard_write);
+
+		const auto n_ts0 = ts0s.on(nid);
+		CHECK(n_ts0.predecessors().contains(n_init));
+		REQUIRE(n_ts0->accesses.size() == 1);
+
+		const auto consume = n_ts0->accesses.front();
+		CHECK(consume.bid == buf.get_id());
+		CHECK(consume.mode == access_mode::read_write);
+
+		// generator kernel "init" has generated exactly the buffer subrange that is consumed by "timestep 0"
+		CHECK(consume.req == generate.req);
+
+		const auto n_ts1 = ts1s.on(nid);
+		CHECK(n_ts1.predecessors().contains(n_ts0));
+		CHECK_FALSE(n_ts1.predecessors().contains(n_init));
+	}
+}
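
Note (not part of the patch): a minimal user-side sketch of the submission pattern that is_generator_kernel() is meant to recognize, i.e. a device_compute task whose only buffer access is a full-range discard_write through celerity::access::one_to_one. Function and kernel names are illustrative; the standard Celerity queue/accessor API is assumed.

    #include <celerity.h>

    // Illustrative only: fills the whole buffer with a single one_to_one discard_write,
    // so each node can produce its own chunk locally and no push/await-push commands arise.
    void submit_generator(celerity::queue& q, celerity::buffer<float, 2>& buf) {
        q.submit([&](celerity::handler& cgh) {
            // Exactly one access: write_only + no_init (discard_write) over the full buffer range via one_to_one.
            celerity::accessor acc{buf, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init};
            cgh.parallel_for<class generator_kernel_example>(buf.get_range(), [=](celerity::item<2> it) { acc[it] = 0.f; });
        });
    }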