Skip to content

Commit c89c89a

Browse files
authored
Add emulated multi-host T3K testing via single-host, multi-process (tenstorrent#25285)
### Ticket Link to Github Issue ### Problem description Add basic sanity regression testing for emulating distributed multi-host `MeshDevice` ### What's changed - Added some basic sanity tests to T3K unit tests for big-mesh - Remove some hardcoding in ControlPlane for `t3k_dual_host_mesh_graph_descriptor.yaml` handling ### Checklist - [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/16360533162 (Pending) - [x] [T3K Unit Tests]: https://github.com/tenstorrent/tt-metal/actions/runs/16361825747
1 parent e97b310 commit c89c89a

File tree

9 files changed

+100
-34
lines changed

9 files changed

+100
-34
lines changed

.github/workflows/t3000-unit-tests-impl.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
test-group: [
2626
{ name: "t3k ttmetal tests", arch: wormhole_b0, cmd: run_t3000_ttmetal_tests, timeout: 35, owner_id: ULMEPM2MA}, #Sean Nijjar
2727
{ name: "t3k ttnn tests", arch: wormhole_b0, cmd: run_t3000_ttnn_tests, timeout: 45, owner_id: UBHPP2NDP}, #Joseph Chu
28+
{ name: "t3k big-mesh multiprocess tests", arch: wormhole_b0, cmd: run_t3000_dual_rank_big_mesh_tests, timeout: 20, owner_id: UBHPP2NDP}, #Joseph Chu
2829
{ name: "t3k tt_metal multiprocess tests", arch: wormhole_b0, cmd: run_t3000_tt_metal_multiprocess_tests, timeout: 5, owner_id: U03NG0A5ND7}, #Aditya Saigal
2930
{ name: "t3k falcon7b tests", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 30, owner_id: UBHPP2NDP}, #Joseph Chu
3031
{ name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 30, owner_id: U053W15B6JF}, #Djordje Ivanovic

tests/scripts/t3000/run_t3000_unit_tests.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ run_t3000_ttmetal_tests() {
4343
fi
4444
}
4545

46+
run_t3000_dual_rank_big_mesh_tests() {
47+
tt-run --rank-binding tests/tt_metal/distributed/config/2x4_multiprocess_rank_bindings.yaml --mpi-args "--allow-run-as-root --tag-output" build/test/tt_metal/distributed/multiprocess/distributed_multiprocess_tests --gtest_filter="*BigMeshDualRankTestT3K*"
48+
}
49+
4650
run_t3000_ttfabric_tests() {
4751
# Record the start time
4852
fail=0

tests/tt_metal/distributed/multiprocess/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ target_sources(
44
PRIVATE
55
main.cpp
66
test_visible_devices_mp.cpp
7+
test_sanity.cpp
78
)
89
set_target_properties(
910
distributed_multiprocess_tests

tests/tt_metal/distributed/multiprocess/run_visible_devices_mp_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ for config in "${DEVICE_CONFIGS[@]}"; do
1717
echo "------------------------------------------------"
1818

1919
# Run with mpirun, setting the environment variable
20-
TT_METAL_VISIBLE_DEVICES="$config" mpirun --allow-run-as-root -np 1 ./build/test/tt_metal/distributed/multiprocess/distributed_multiprocess_tests
20+
TT_METAL_VISIBLE_DEVICES="$config" mpirun --allow-run-as-root -np 1 ./build/test/tt_metal/distributed/multiprocess/distributed_multiprocess_tests --gtest_filter="*VisibleDevicesMPTest*"
2121

2222
if [ $? -eq 0 ]; then
2323
echo "✓ [distributed tests] Test passed for configuration: $config"
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include <fmt/base.h>
6+
#include <gtest/gtest.h>
7+
#include <gmock/gmock.h>
8+
9+
#include <impl/context/metal_context.hpp>
10+
#include <tt-metalium/control_plane.hpp>
11+
#include <tt-metalium/distributed_context.hpp>
12+
#include <tt-metalium/system_mesh.hpp>
13+
#include <tt-metalium/mesh_coord.hpp>
14+
#include <tt-metalium/distributed_host_buffer.hpp>
15+
#include <tt-metalium/fabric_types.hpp>
16+
#include <tt-metalium/host_buffer.hpp>
17+
18+
#include <tt-metalium/tt_metal.hpp>
19+
20+
namespace tt::tt_metal::distributed {
21+
22+
using tt_fabric::HostRankId;
23+
using tt_fabric::MeshId;
24+
using tt_fabric::MeshScope;
25+
26+
TEST(BigMeshDualRankTestT3K, DistributedContext) {
27+
auto& dctx = MetalContext::instance().get_distributed_context();
28+
auto world_size = dctx.size();
29+
EXPECT_EQ(*world_size, 2);
30+
}
31+
32+
TEST(BigMeshDualRankTestT3K, LocalRankBinding) {
33+
auto& dctx = MetalContext::instance().get_distributed_context();
34+
auto& control_plane = MetalContext::instance().get_control_plane();
35+
36+
tt_fabric::HostRankId local_rank_binding = control_plane.get_local_host_rank_id_binding();
37+
if (*dctx.rank() == 0) {
38+
EXPECT_EQ(*local_rank_binding, 0);
39+
} else {
40+
EXPECT_EQ(*local_rank_binding, 1);
41+
}
42+
}
43+
44+
TEST(BigMeshDualRankTestT3K, SystemMeshValidation) {
45+
EXPECT_NO_THROW({
46+
const auto& system_mesh = SystemMesh::instance();
47+
EXPECT_EQ(system_mesh.shape(), MeshShape(2,4));
48+
EXPECT_EQ(system_mesh.local_shape(), MeshShape(2,2));
49+
});
50+
}
51+
52+
TEST(BigMeshDualRankTestT3K, MeshDevice2x4Validation) {
53+
auto mesh_device = MeshDevice::create(MeshDeviceConfig(MeshShape(2,4)), DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
54+
EXPECT_EQ(mesh_device->shape(), MeshShape(2,4));
55+
}
56+
57+
TEST(BigMeshDualRankTestT3K, SystemMeshShape) {
58+
const auto& system_mesh = SystemMesh::instance();
59+
EXPECT_EQ(system_mesh.local_shape(), MeshShape(2, 2));
60+
61+
auto& control_plane = MetalContext::instance().get_control_plane();
62+
auto rank = control_plane.get_local_host_rank_id_binding();
63+
64+
if (rank == HostRankId{0}) {
65+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(0, 0)));
66+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(0, 1)));
67+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(1, 0)));
68+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(1, 1)));
69+
} else {
70+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(0, 2)));
71+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(0, 3)));
72+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(1, 2)));
73+
EXPECT_NO_THROW(system_mesh.get_physical_device_id(MeshCoordinate(1, 3)));
74+
}
75+
}
76+
77+
} // namespace tt::tt_metal::distributed

tt_metal/api/tt-metalium/mesh_coord.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,13 +408,18 @@ class DistributedMeshContainer : public MeshContainer<MaybeRemote<T>> {
408408
template <typename CoordSystem>
409409
void populate_local_region(const CoordSystem& coord_sys, const std::vector<T>& local_values) {
410410
TT_FATAL(
411-
local_values.size() == this->shape().mesh_size(),
411+
local_values.size() <= this->shape().mesh_size(),
412412
"Number of local values {} does not match mesh size {}",
413413
local_values.size(),
414414
this->shape().mesh_size());
415415
size_t idx = 0;
416-
for (const auto& local_coord : this->coord_range()) {
417-
this->at(local_coord) = MaybeRemote<T>::local(local_values[idx++]);
416+
for (const auto& coord : this->coord_range()) {
417+
// If local_values size equals mesh size, treat all coordinates as local
418+
// Otherwise, only populate coordinates that are actually local
419+
// TODO: Does not support reshaping. Implementation generalized to support this in PR #25190.
420+
if (local_values.size() == this->shape().mesh_size() || coord_sys.is_local(coord)) {
421+
this->at(coord) = MaybeRemote<T>::local(local_values[idx++]);
422+
}
418423
}
419424
}
420425

tt_metal/distributed/coordinate_translation.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <mesh_coord.hpp>
88
#include <tt-metalium/fabric_types.hpp>
99
#include <stdint.h>
10+
#include <tuple>
1011

1112
namespace tt {
1213
namespace tt_metal {
@@ -31,6 +32,10 @@ class PhysicalMeshCoordinate {
3132
MeshId mesh_id() const { return mesh_id_; }
3233
chip_id_t chip_id() const { return chip_id_; }
3334

35+
// Needed for reflect / fmt
36+
static constexpr auto attribute_names = std::forward_as_tuple("mesh_id", "chip_id");
37+
auto attribute_values() const { return std::forward_as_tuple(mesh_id_, chip_id_); }
38+
3439
private:
3540
MeshId mesh_id_{0};
3641
chip_id_t chip_id_{0};

tt_metal/distributed/system_mesh.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ MaybeRemoteDeviceId SystemMesh::Impl::get_maybe_remote_device_id(const MeshCoord
5252
return physical_coordinates_.at(coord).when(
5353
[&](const auto& physical_coord) {
5454
auto physical_device_id = get_physical_device_id(coord);
55-
log_debug(LogDistributed, "Mesh coordinate: {}, Physical device ID: {}", coord, physical_device_id);
55+
log_debug(LogDistributed, "Mesh coordinate: {} is local, Physical device ID: {}", coord, physical_device_id);
5656
return MaybeRemoteDeviceId::local(physical_device_id);
5757
},
5858
[&]() {
@@ -80,6 +80,8 @@ SystemMesh::Impl::Impl() :
8080
}
8181

8282
// Use populate_local_region to set up the distributed container
83+
log_debug(LogDistributed, "SystemMesh: Global shape: {}, Local shape: {}, Local offset: {}", global_shape_, local_coordinates.shape(), local_offset_);
84+
log_debug(LogDistributed, "SystemMesh: Populating local region with physical coordinates: {}", ordered_physical_coords);
8385
physical_coordinates_.populate_local_region(coord_system, ordered_physical_coords);
8486
}
8587

tt_metal/fabric/control_plane.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -509,35 +509,6 @@ std::map<FabricNodeId, chip_id_t> ControlPlane::get_logical_chip_to_physical_chi
509509
logical_mesh_chip_id_to_physical_chip_id_mapping.insert({FabricNodeId(MeshId{4}, i), physical_chip_ids[i]});
510510
}
511511
// This case can be depreciated once we have multi-host testing and validate it working
512-
} else if (mesh_graph_desc_filename == "t3k_dual_host_mesh_graph_descriptor.yaml") {
513-
// TODO(#24230): This path will soon be deprecated once we generalize logical mesh_chip_id to physical chip_id
514-
// mapping
515-
auto& cluster = tt::tt_metal::MetalContext::instance().get_cluster();
516-
auto chip_eth_coords = cluster.get_user_chip_ethernet_coordinates();
517-
std::vector<eth_coord_t> eth_coords;
518-
eth_coords.reserve(chip_eth_coords.size());
519-
for (const auto& [_, eth_coord] : chip_eth_coords) {
520-
eth_coords.push_back(eth_coord);
521-
}
522-
std::sort(eth_coords.begin(), eth_coords.end(), EthCoordComparator());
523-
524-
auto mesh_ids = this->get_local_mesh_id_bindings();
525-
auto mesh_id = mesh_ids.at(0); // Use the first mesh ID
526-
auto host_rank_id = this->get_local_host_rank_id_binding();
527-
auto fabric_chip_ids = this->routing_table_generator_->mesh_graph->get_chip_ids(mesh_id, host_rank_id).values();
528-
529-
TT_FATAL(
530-
fabric_chip_ids.size() == eth_coords.size(),
531-
"Number of fabric chip ids {} does not match number of eth coords {}",
532-
fabric_chip_ids.size(),
533-
eth_coords.size());
534-
for (std::uint32_t idx = 0; idx < fabric_chip_ids.size(); idx++) {
535-
auto fabric_chip_id = fabric_chip_ids.at(idx);
536-
auto eth_coord = eth_coords.at(idx);
537-
logical_mesh_chip_id_to_physical_chip_id_mapping.insert(
538-
{tt_fabric::FabricNodeId(mesh_id, fabric_chip_id),
539-
cluster.get_physical_chip_id_from_eth_coord(eth_coord)});
540-
}
541512
} else {
542513
// Iterate over every mesh defined in the mesh-graph descriptor and embed it on top of
543514
// the physical cluster using the generic helper.

0 commit comments

Comments
 (0)