diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 0375e97e01..da7dc4d48e 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -296,32 +296,48 @@ def get_shard_assignments( def get_mlx_jaccl_devices_matrix( selected_cycle: list[NodeId], cycle_digraph: Topology, -) -> list[list[str | None]]: +) -> list[list[list[str]]]: """Build connectivity matrix mapping device i to device j via RDMA interface names. - The matrix element [i][j] contains the interface name on device i that connects - to device j, or None if no connection exists or no interface name is found. - Diagonal elements are always None. + The matrix element [i][j] contains the interface names on device i that + connect to device j, one per physical link (e.g. one per Thunderbolt + cable). Diagonal elements are always empty. + + Element k of [i][j] and element k of [j][i] are the two endpoints of the + same physical link. The jaccl backend pairs queue pairs between ranks by + index, so both cells must enumerate links in the same order; deriving both + directions from a single enumeration guarantees this even when more than + one link connects the same pair of nodes. """ num_nodes = len(selected_cycle) - matrix: list[list[str | None]] = [ - [None for _ in range(num_nodes)] for _ in range(num_nodes) + matrix: list[list[list[str]]] = [ + [[] for _ in range(num_nodes)] for _ in range(num_nodes) ] for i, node_i in enumerate(selected_cycle): - for j, node_j in enumerate(selected_cycle): - if i == j: - continue + for j in range(i + 1, num_nodes): + node_j = selected_cycle[j] + # Each directed edge carries the interface names of both + # endpoints, so edges from either direction describe the same + # physical link and can be merged into one set of endpoint pairs. + links: set[tuple[str, str]] = set() for conn in cycle_digraph.get_all_connections_between(node_i, node_j): if isinstance(conn, RDMAConnection): - matrix[i][j] = conn.source_rdma_iface - break - else: + links.add((conn.source_rdma_iface, conn.sink_rdma_iface)) + for conn in cycle_digraph.get_all_connections_between(node_j, node_i): + if isinstance(conn, RDMAConnection): + links.add((conn.sink_rdma_iface, conn.source_rdma_iface)) + + if not links: raise ValueError( "Current jaccl backend requires all-to-all RDMA connections" ) + ordered_links = sorted(links) + matrix[i][j] = [iface_i for iface_i, _ in ordered_links] + matrix[j][i] = [iface_j for _, iface_j in ordered_links] + return matrix diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index b891187266..bb24238b48 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -6,6 +6,7 @@ get_transition_events, place_instance, ) +from exo.master.placement_utils import get_mlx_jaccl_devices_matrix from exo.master.tests.conftest import ( create_node_memory, create_node_network, @@ -35,7 +36,7 @@ InputMessageContent, TextGenerationTaskParams, ) -from exo.shared.types.topology import Connection, SocketConnection +from exo.shared.types.topology import Connection, RDMAConnection, SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, DownloadFailed, @@ -496,7 +497,7 @@ def test_tensor_rdma_backend_connectivity_matrix( matrix = instance.jaccl_devices assert len(matrix) == 3 for i in range(3): - assert matrix[i][i] is None + assert matrix[i][i] == [] assigned_nodes = list(instance.shard_assignments.node_to_runner.keys()) node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)} @@ -505,9 +506,9 @@ def test_tensor_rdma_backend_connectivity_matrix( idx_b = node_to_idx[node_b] idx_c = node_to_idx[node_c] - assert matrix[idx_a][idx_b] == "rdma_en3" - assert matrix[idx_b][idx_c] == "rdma_en4" - assert matrix[idx_c][idx_a] == "rdma_en5" + assert matrix[idx_a][idx_b] == ["rdma_en3"] + assert matrix[idx_b][idx_c] == ["rdma_en4"] + assert matrix[idx_c][idx_a] == ["rdma_en5"] # Verify coordinators are set for all nodes assert len(instance.jaccl_coordinators) == 3 @@ -523,6 +524,68 @@ def test_tensor_rdma_backend_connectivity_matrix( assert len(ip_part.split(".")) == 4 +def test_jaccl_devices_matrix_multiple_links_between_two_nodes() -> None: + """Two cables between the same pair of nodes must produce index-aligned + rails: entry k of [i][j] and entry k of [j][i] are the two ends of the + same physical link, regardless of edge insertion order.""" + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + + # Cable 1: en6 on A <-> en7 on B. Cable 2: en7 on A <-> en6 on B. + # The crossed names mean a naive "first edge per direction" pick can pair + # interfaces that are not on the same cable. + cable_1_a_view = RDMAConnection( + source_rdma_iface="rdma_en6", sink_rdma_iface="rdma_en7" + ) + cable_2_a_view = RDMAConnection( + source_rdma_iface="rdma_en7", sink_rdma_iface="rdma_en6" + ) + cable_1_b_view = RDMAConnection( + source_rdma_iface="rdma_en7", sink_rdma_iface="rdma_en6" + ) + cable_2_b_view = RDMAConnection( + source_rdma_iface="rdma_en6", sink_rdma_iface="rdma_en7" + ) + + # Insert the two directions in opposite cable order on purpose. + topology.add_connection(Connection(source=node_a, sink=node_b, edge=cable_1_a_view)) + topology.add_connection(Connection(source=node_a, sink=node_b, edge=cable_2_a_view)) + topology.add_connection(Connection(source=node_b, sink=node_a, edge=cable_2_b_view)) + topology.add_connection(Connection(source=node_b, sink=node_a, edge=cable_1_b_view)) + + matrix = get_mlx_jaccl_devices_matrix([node_a, node_b], topology) + + assert matrix[0][0] == [] + assert matrix[1][1] == [] + assert len(matrix[0][1]) == 2 + assert len(matrix[1][0]) == 2 + + # Each rail index must reference the two endpoints of one physical cable. + rails = list(zip(matrix[0][1], matrix[1][0], strict=True)) + assert sorted(rails) == [("rdma_en6", "rdma_en7"), ("rdma_en7", "rdma_en6")] + + +def test_jaccl_devices_matrix_single_link_remains_single_entry() -> None: + topology = Topology() + node_a = NodeId() + node_b = NodeId() + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_connection( + Connection(source=node_a, sink=node_b, edge=create_rdma_connection(6)) + ) + topology.add_connection( + Connection(source=node_b, sink=node_a, edge=create_rdma_connection(6)) + ) + + matrix = get_mlx_jaccl_devices_matrix([node_a, node_b], topology) + + assert matrix == [[[], ["rdma_en6"]], [["rdma_en6"], []]] + + def _build_three_node_rdma_topology() -> tuple[ Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo] ]: diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index 16233f3f05..3d5d3927c0 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -31,7 +31,10 @@ class MlxRingInstance(BaseInstance): class MlxJacclInstance(BaseInstance): - jaccl_devices: list[list[str | None]] + # jaccl_devices[i][j] lists the RDMA interface names on rank i that + # connect to rank j, one entry per physical link. Entry k of [i][j] and + # entry k of [j][i] are the two ends of the same link. + jaccl_devices: list[list[list[str]]] jaccl_coordinators: dict[NodeId, str] diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 730abf64e3..0d0f220b1c 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -122,9 +122,7 @@ def mlx_distributed_init( case MlxJacclInstance( jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators ): - assert all( - jaccl_devices[i][i] is None for i in range(len(jaccl_devices)) - ) + assert all(jaccl_devices[i][i] == [] for i in range(len(jaccl_devices))) # Use RDMA connectivity matrix jaccl_devices_json = json.dumps(jaccl_devices) @@ -140,6 +138,22 @@ def mlx_distributed_init( os.environ["MLX_IBV_DEVICES"] = coordination_file os.environ["MLX_RANK"] = str(rank) os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator + + # The default jaccl mesh backend only ever uses the first link + # between two ranks; the ring backend stripes traffic across + # every link. Prefer the ring whenever any pair of ranks has + # more than one physical link so the extra bandwidth is used. + max_links = max( + (len(cell) for row in jaccl_devices for cell in row), + default=0, + ) + if max_links > 1: + os.environ["MLX_JACCL_RING"] = "1" + logger.info( + f"rank {rank} MLX_JACCL_RING=1 " + f"(up to {max_links} links per peer)" + ) + group = mx.distributed.init(backend="jaccl", strict=True) logger.info(f"Rank {rank} mlx distributed initialization complete")