3 changes: 3 additions & 0 deletions include/command_graph_generator.h
@@ -93,6 +93,8 @@ class command_graph_generator {
std::optional<reduction_info> pending_reduction;

std::string debug_name;

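// The generator kernel (see is_generator_kernel) associated with this buffer, if any.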
const task* generator_task = nullptr;
};

struct host_object_state {
@@ -111,6 +113,7 @@
};

public:
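/// Returns true if tsk is a "generator kernel": a device-compute task whose single buffer access discard-writes the entire buffer through a one_to_one range mapper and that carries no split hints.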
bool is_generator_kernel(const task& tsk) const;
struct policy_set {
error_policy uninitialized_read_error = error_policy::panic;
error_policy overlapping_write_error = error_policy::panic;
15 changes: 15 additions & 0 deletions include/range_mapper.h
@@ -13,6 +13,12 @@


namespace celerity {

// Forward declaration so the range mapper can detect whether its functor is celerity::access::one_to_one
namespace access {
struct one_to_one;
}

namespace detail {

template <typename Functor, int BufferDims, int KernelDims>
@@ -85,6 +91,8 @@ namespace detail {
virtual region<3> map_3(const chunk<3>& chnk) const = 0;

virtual ~range_mapper_base() = default;

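/// Returns whether this range mapper's functor is celerity::access::one_to_one; defaults to false, overridden in range_mapper below.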
virtual bool is_one_to_one() const { return false; }
};

template <int BufferDims, typename Functor>
@@ -107,6 +115,13 @@
region<3> map_3(const chunk<2>& chnk) const override { return map<3>(chnk); }
region<3> map_3(const chunk<3>& chnk) const override { return map<3>(chnk); }

// Override is_one_to_one() to detect whether the functor is specifically celerity::access::one_to_one:
bool is_one_to_one() const override {
// If the Functor is celerity::access::one_to_one, return true
if constexpr(std::is_same_v<Functor, celerity::access::one_to_one>) { return true; }
return false;
}

private:
Functor m_rmfn;
range<BufferDims> m_buffer_size;
9 changes: 9 additions & 0 deletions include/task.h
@@ -64,6 +64,15 @@ namespace detail {
/// Returns a set of bounding boxes, one for each accessed region, that must be allocated contiguously.
box_vector<3> compute_required_contiguous_boxes(const buffer_id bid, const box<3>& execution_range) const;

/// Returns the range mapper of the first access to the given buffer id, or nullptr if this task does not access that buffer.
const range_mapper_base* get_range_mapper(buffer_id search_bid) const {
for(const auto& ba : m_accesses) {
if(ba.bid == search_bid) { return ba.range_mapper.get(); }
}
return nullptr; // Not found
}

private:
std::vector<buffer_access> m_accesses;
std::unordered_set<buffer_id> m_accessed_buffers; ///< Cached set of buffer ids found in m_accesses
50 changes: 48 additions & 2 deletions src/command_graph_generator.cc
@@ -73,6 +73,31 @@ bool is_topologically_sorted(Iterator begin, Iterator end) {
return true;
}

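// A "generator kernel" produces its entire buffer from scratch: it has exactly one buffer access, which is a
// discard_write through a one_to_one range mapper covering the whole buffer, and it carries no split hints.
// Each node can therefore compute the chunk it is assigned locally, and the kernel's results never need to be
// communicated between nodes ("owner computes").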
bool command_graph_generator::is_generator_kernel(const task& tsk) const {
if(tsk.get_type() != task_type::device_compute) return false;

// Must not have a hint that modifies splitting:
if(tsk.get_hint<experimental::hints::split_1d>() != nullptr || tsk.get_hint<experimental::hints::split_2d>() != nullptr) { return false; }

// Must have exactly one buffer access
const auto& bam = tsk.get_buffer_access_map();
if(bam.get_num_accesses() != 1) return false;

// That single access must be discard_write
const auto [bid, mode] = bam.get_nth_access(0);
if(mode != access_mode::discard_write) return false;

// Must produce exactly the entire buffer
const auto full_box = box(subrange({}, tsk.get_global_size()));
if(bam.get_task_produced_region(bid) != full_box) return false;

// Confirm the *range mapper* is truly one_to_one:
const auto rm = bam.get_range_mapper(bid);
if(rm == nullptr || !rm->is_one_to_one()) { return false; }

return true;
}

std::vector<const command*> command_graph_generator::build_task(const task& tsk) {
const auto epoch_to_prune_before = m_epoch_for_new_commands;
batch current_batch;
@@ -455,6 +480,29 @@ void command_graph_generator::update_local_buffer_fresh_regions(const task& tsk,
}

void command_graph_generator::generate_distributed_commands(batch& current_batch, const task& tsk) {
// If this is a generator kernel, we generate its execution commands right away and skip the regular per-chunk requirement and data-transfer handling below.
if(is_generator_kernel(tsk)) {
// Identify which buffer is discard-written
const auto [gen_bid, _] = tsk.get_buffer_access_map().get_nth_access(0);
auto& bstate = m_buffers.at(gen_bid);
const auto chunks = split_task_and_assign_chunks(tsk);

// Create a command for each chunk that belongs to our local node.
for(const auto& a_chunk : chunks) {
if(a_chunk.executed_on != m_local_nid) continue;
auto* cmd = create_command<execution_command>(current_batch, &tsk, subrange<3>{a_chunk.chnk}, false,
[&](const auto& record_debug_info) { record_debug_info(tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }); });

// Mark that subrange as freshly written, so there’s no “uninitialized read” later.
box<3> write_box(a_chunk.chnk.offset, a_chunk.chnk.offset + a_chunk.chnk.range);
region<3> written_region{write_box};
bstate.local_last_writer.update_region(written_region, cmd);
bstate.initialized_region = region_union(bstate.initialized_region, written_region);
}
// Return here so we skip the normal command-generation logic below.
return;
}

const auto chunks = split_task_and_assign_chunks(tsk);
const auto chunks_with_requirements = compute_per_chunk_requirements(tsk, chunks);

@@ -521,8 +569,6 @@ void command_graph_generator::generate_distributed_commands(batch& current_batch

if(!produced.empty()) {
generate_anti_dependencies(tsk, bid, buffer.local_last_writer, produced, cmd);

// Update last writer
buffer.local_last_writer.update_region(produced, cmd);
buffer.replicated_regions.update_region(produced, node_bitset{});

59 changes: 59 additions & 0 deletions test/command_graph_general_tests.cc
@@ -437,3 +437,62 @@ TEST_CASE("command_graph_generator throws in tests if it detects overlapping wri
"range mapper for this write access or constrain the split via experimental::constrain_split to make the access non-overlapping.");
}
}

TEST_CASE("results from generator kernels are never communicated between nodes", "[command_graph_generator][owner-computes]") {
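// Run once with the default 1D split and once with an explicit 2D split hint on the timestep kernels.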
const bool split_2d = GENERATE(values({0, 1}));
CAPTURE(split_2d);

const size_t num_nodes = 4;
cdag_test_context cctx(num_nodes); // 4 nodes, so we can get a true 2D work assignment for the timestep kernel
auto buf = cctx.create_buffer<2>({256, 256}); // a 256x256 buffer

const auto tid_init = cctx.device_compute(buf.get_range()) //
.discard_write(buf, celerity::access::one_to_one())
.name("init")
.submit();
const auto tid_ts0 = cctx.device_compute(buf.get_range()) //
.hint_if(split_2d, experimental::hints::split_2d())
.read_write(buf, celerity::access::one_to_one())
.name("timestep 0")
.submit();
const auto tid_ts1 = cctx.device_compute(buf.get_range()) //
.hint_if(split_2d, experimental::hints::split_2d())
.read_write(buf, celerity::access::one_to_one())
.name("timestep 1")
.submit();

CHECK(cctx.query<execution_command_record>().count_per_node() == 3); // one for each task above
CHECK(cctx.query<push_command_record>().total_count() == 0);
CHECK(cctx.query<await_push_command_record>().total_count() == 0);

const auto inits = cctx.query<execution_command_record>(tid_init);
const auto ts0s = cctx.query<execution_command_record>(tid_ts0);
const auto ts1s = cctx.query<execution_command_record>(tid_ts1);
CHECK(inits.count_per_node() == 1);
CHECK(ts0s.count_per_node() == 1);
CHECK(ts1s.count_per_node() == 1);

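// Per node: "init" must be a direct predecessor of "timestep 0", and both must access exactly the same
// buffer subrange, i.e. every node consumes only what it generated itself.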
for(node_id nid = 0; nid < num_nodes; ++nid) {
const auto n_init = inits.on(nid);
REQUIRE(n_init->accesses.size() == 1);

const auto generate = n_init->accesses.front();
CHECK(generate.bid == buf.get_id());
CHECK(generate.mode == access_mode::discard_write);

const auto n_ts0 = ts0s.on(nid);
CHECK(n_ts0.predecessors().contains(n_init));
REQUIRE(n_ts0->accesses.size() == 1);

const auto consume = n_ts0->accesses.front();
CHECK(consume.bid == buf.get_id());
CHECK(consume.mode == access_mode::read_write);

// generator kernel "init" has generated exactly the buffer subrange that is consumed by "timestep 0"
CHECK(consume.req == generate.req);

const auto n_ts1 = ts1s.on(nid);
CHECK(n_ts1.predecessors().contains(n_ts0));
CHECK_FALSE(n_ts1.predecessors().contains(n_init));
}
}