From b120ba40fa46a9dcee1c8ee1ffa2a32d4462333e Mon Sep 17 00:00:00 2001 From: Jason Schulz Date: Wed, 10 Jun 2026 11:40:32 -0400 Subject: [PATCH] feat: stripe RDMA traffic across multiple Thunderbolt cables The jaccl devices matrix previously kept a single interface name per node pair, picked independently for each direction. With more than one cable between two nodes the two directions could name interfaces on different cables, and the resulting queue pairs silently deadlocked during model load. Any additional cables were also invisible to the backend. Each matrix cell now carries every physical link between a pair, with both directions derived from one enumeration so that rail k of [i][j] and rail k of [j][i] always name the two ends of the same cable (jaccl pairs queue pairs across ranks by index). When any pair has more than one link the worker sets MLX_JACCL_RING=1, selecting the jaccl ring backend, which stripes collectives across all links (the default mesh backend only uses the first). Single-link setups keep the mesh backend and existing behaviour. Co-Authored-By: Claude Fable 5 --- src/exo/master/placement_utils.py | 40 +++++++++---- src/exo/master/tests/test_placement.py | 73 ++++++++++++++++++++++-- src/exo/shared/types/worker/instances.py | 5 +- src/exo/worker/engines/mlx/utils_mlx.py | 20 ++++++- 4 files changed, 117 insertions(+), 21 deletions(-) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 0375e97e01..da7dc4d48e 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -296,32 +296,48 @@ def get_shard_assignments( def get_mlx_jaccl_devices_matrix( selected_cycle: list[NodeId], cycle_digraph: Topology, -) -> list[list[str | None]]: +) -> list[list[list[str]]]: """Build connectivity matrix mapping device i to device j via RDMA interface names. - The matrix element [i][j] contains the interface name on device i that connects - to device j, or None if no connection exists or no interface name is found. - Diagonal elements are always None. + The matrix element [i][j] contains the interface names on device i that + connect to device j, one per physical link (e.g. one per Thunderbolt + cable). Diagonal elements are always empty. + + Element k of [i][j] and element k of [j][i] are the two endpoints of the + same physical link. The jaccl backend pairs queue pairs between ranks by + index, so both cells must enumerate links in the same order; deriving both + directions from a single enumeration guarantees this even when more than + one link connects the same pair of nodes. """ num_nodes = len(selected_cycle) - matrix: list[list[str | None]] = [ - [None for _ in range(num_nodes)] for _ in range(num_nodes) + matrix: list[list[list[str]]] = [ + [[] for _ in range(num_nodes)] for _ in range(num_nodes) ] for i, node_i in enumerate(selected_cycle): - for j, node_j in enumerate(selected_cycle): - if i == j: - continue + for j in range(i + 1, num_nodes): + node_j = selected_cycle[j] + # Each directed edge carries the interface names of both + # endpoints, so edges from either direction describe the same + # physical link and can be merged into one set of endpoint pairs. + links: set[tuple[str, str]] = set() for conn in cycle_digraph.get_all_connections_between(node_i, node_j): if isinstance(conn, RDMAConnection): - matrix[i][j] = conn.source_rdma_iface - break - else: + links.add((conn.source_rdma_iface, conn.sink_rdma_iface)) + for conn in cycle_digraph.get_all_connections_between(node_j, node_i): + if isinstance(conn, RDMAConnection): + links.add((conn.sink_rdma_iface, conn.source_rdma_iface)) + + if not links: raise ValueError( "Current jaccl backend requires all-to-all RDMA connections" ) + ordered_links = sorted(links) + matrix[i][j] = [iface_i for iface_i, _ in ordered_links] + matrix[j][i] = [iface_j for _, iface_j in ordered_links] + return matrix diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index b891187266..bb24238b48 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -6,6 +6,7 @@ get_transition_events, place_instance, ) +from exo.master.placement_utils import get_mlx_jaccl_devices_matrix from exo.master.tests.conftest import ( create_node_memory, create_node_network, @@ -35,7 +36,7 @@ InputMessageContent, TextGenerationTaskParams, ) -from exo.shared.types.topology import Connection, SocketConnection +from exo.shared.types.topology import Connection, RDMAConnection, SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadFailed, @@ -496,7 +497,7 @@ def test_tensor_rdma_backend_connectivity_matrix( matrix = instance.jaccl_devices assert len(matrix) == 3 for i in range(3): - assert matrix[i][i] is None + assert matrix[i][i] == [] assigned_nodes = list(instance.shard_assignments.node_to_runner.keys()) node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)} @@ -505,9 +506,9 @@ def test_tensor_rdma_backend_connectivity_matrix( idx_b = node_to_idx[node_b] idx_c = node_to_idx[node_c] - assert matrix[idx_a][idx_b] == "rdma_en3" - assert matrix[idx_b][idx_c] == "rdma_en4" - assert matrix[idx_c][idx_a] == "rdma_en5" + assert matrix[idx_a][idx_b] == ["rdma_en3"] + assert matrix[idx_b][idx_c] == ["rdma_en4"] + assert matrix[idx_c][idx_a] == ["rdma_en5"] # Verify coordinators are set for all nodes assert len(instance.jaccl_coordinators) == 3 @@ -523,6 +524,68 @@ def test_tensor_rdma_backend_connectivity_matrix( assert len(ip_part.split(".")) == 4 +def test_jaccl_devices_matrix_multiple_links_between_two_nodes() -> None: + """Two cables between the same pair of nodes must produce index-aligned + rails: entry k of [i][j] and entry k of [j][i] are the two ends of the + same physical link, regardless of edge insertion order.""" + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + + # Cable 1: en6 on A <-> en7 on B. Cable 2: en7 on A <-> en6 on B. + # The crossed names mean a naive "first edge per direction" pick can pair + # interfaces that are not on the same cable. + cable_1_a_view = RDMAConnection( + source_rdma_iface="rdma_en6", sink_rdma_iface="rdma_en7" + ) + cable_2_a_view = RDMAConnection( + source_rdma_iface="rdma_en7", sink_rdma_iface="rdma_en6" + ) + cable_1_b_view = RDMAConnection( + source_rdma_iface="rdma_en7", sink_rdma_iface="rdma_en6" + ) + cable_2_b_view = RDMAConnection( + source_rdma_iface="rdma_en6", sink_rdma_iface="rdma_en7" + ) + + # Insert the two directions in opposite cable order on purpose. + topology.add_connection(Connection(source=node_a, sink=node_b, edge=cable_1_a_view)) + topology.add_connection(Connection(source=node_a, sink=node_b, edge=cable_2_a_view)) + topology.add_connection(Connection(source=node_b, sink=node_a, edge=cable_2_b_view)) + topology.add_connection(Connection(source=node_b, sink=node_a, edge=cable_1_b_view)) + + matrix = get_mlx_jaccl_devices_matrix([node_a, node_b], topology) + + assert matrix[0][0] == [] + assert matrix[1][1] == [] + assert len(matrix[0][1]) == 2 + assert len(matrix[1][0]) == 2 + + # Each rail index must reference the two endpoints of one physical cable. + rails = list(zip(matrix[0][1], matrix[1][0], strict=True)) + assert sorted(rails) == [("rdma_en6", "rdma_en7"), ("rdma_en7", "rdma_en6")] + + +def test_jaccl_devices_matrix_single_link_remains_single_entry() -> None: + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(6)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(6)) + ) + + matrix = get_mlx_jaccl_devices_matrix([node_a, node_b], topology) + + assert matrix == [[[], ["rdma_en6"]], [["rdma_en6"], []]] + + def _build_three_node_rdma_topology() -> tuple[ Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] ]: diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index 16233f3f05..3d5d3927c0 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -31,7 +31,10 @@ class MlxRingInstance(BaseInstance): class MlxJacclInstance(BaseInstance): - jaccl_devices: list[list[str | None]] + # jaccl_devices[i][j] lists the RDMA interface names on rank i that + # connect to rank j, one entry per physical link. Entry k of [i][j] and + # entry k of [j][i] are the two ends of the same link. + jaccl_devices: list[list[list[str]]] jaccl_coordinators: dict[NodeId, str] diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 730abf64e3..0d0f220b1c 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -122,9 +122,7 @@ def mlx_distributed_init( case MlxJacclInstance( jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators ): - assert all( - jaccl_devices[i][i] is None for i in range(len(jaccl_devices)) - ) + assert all(jaccl_devices[i][i] == [] for i in range(len(jaccl_devices))) # Use RDMA connectivity matrix jaccl_devices_json = json.dumps(jaccl_devices) @@ -140,6 +138,22 @@ def mlx_distributed_init( os.environ["MLX_IBV_DEVICES"] = coordination_file os.environ["MLX_RANK"] = str(rank) os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator + + # The default jaccl mesh backend only ever uses the first link + # between two ranks; the ring backend stripes traffic across + # every link. Prefer the ring whenever any pair of ranks has + # more than one physical link so the extra bandwidth is used. + max_links = max( + (len(cell) for row in jaccl_devices for cell in row), + default=0, + ) + if max_links > 1: + os.environ["MLX_JACCL_RING"] = "1" + logger.info( + f"rank {rank} MLX_JACCL_RING=1 " + f"(up to {max_links} links per peer)" + ) + group = mx.distributed.init(backend="jaccl", strict=True) logger.info(f"Rank {rank} mlx distributed initialization complete")