Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions src/exo/master/placement_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,32 +296,48 @@ def get_shard_assignments(
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
cycle_digraph: Topology,
) -> list[list[str | None]]:
) -> list[list[list[str]]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.

The matrix element [i][j] contains the interface name on device i that connects
to device j, or None if no connection exists or no interface name is found.
Diagonal elements are always None.
The matrix element [i][j] contains the interface names on device i that
connect to device j, one per physical link (e.g. one per Thunderbolt
cable). Diagonal elements are always empty.

Element k of [i][j] and element k of [j][i] are the two endpoints of the
same physical link. The jaccl backend pairs queue pairs between ranks by
index, so both cells must enumerate links in the same order; deriving both
directions from a single enumeration guarantees this even when more than
one link connects the same pair of nodes.
"""
num_nodes = len(selected_cycle)
matrix: list[list[str | None]] = [
[None for _ in range(num_nodes)] for _ in range(num_nodes)
matrix: list[list[list[str]]] = [
[[] for _ in range(num_nodes)] for _ in range(num_nodes)
]

for i, node_i in enumerate(selected_cycle):
for j, node_j in enumerate(selected_cycle):
if i == j:
continue
for j in range(i + 1, num_nodes):
node_j = selected_cycle[j]

# Each directed edge carries the interface names of both
# endpoints, so edges from either direction describe the same
# physical link and can be merged into one set of endpoint pairs.
links: set[tuple[str, str]] = set()
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
break
else:
links.add((conn.source_rdma_iface, conn.sink_rdma_iface))
for conn in cycle_digraph.get_all_connections_between(node_j, node_i):
if isinstance(conn, RDMAConnection):
links.add((conn.sink_rdma_iface, conn.source_rdma_iface))

if not links:
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
)

ordered_links = sorted(links)
matrix[i][j] = [iface_i for iface_i, _ in ordered_links]
matrix[j][i] = [iface_j for _, iface_j in ordered_links]

return matrix


Expand Down
73 changes: 68 additions & 5 deletions src/exo/master/tests/test_placement.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_transition_events,
place_instance,
)
from exo.master.placement_utils import get_mlx_jaccl_devices_matrix
from exo.master.tests.conftest import (
create_node_memory,
create_node_network,
Expand Down Expand Up @@ -35,7 +36,7 @@
InputMessageContent,
TextGenerationTaskParams,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection, RDMAConnection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
Expand Down Expand Up @@ -496,7 +497,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
matrix = instance.jaccl_devices
assert len(matrix) == 3
for i in range(3):
assert matrix[i][i] is None
assert matrix[i][i] == []

assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
Expand All @@ -505,9 +506,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]

assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
assert matrix[idx_a][idx_b] == ["rdma_en3"]
assert matrix[idx_b][idx_c] == ["rdma_en4"]
assert matrix[idx_c][idx_a] == ["rdma_en5"]

# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
Expand All @@ -523,6 +524,68 @@ def test_tensor_rdma_backend_connectivity_matrix(
assert len(ip_part.split(".")) == 4


def test_jaccl_devices_matrix_multiple_links_between_two_nodes() -> None:
    """Check that parallel links between one node pair stay index-aligned.

    When two cables join the same two nodes, position k of ``[i][j]`` and
    position k of ``[j][i]`` have to name the two ends of a single physical
    link, whatever order the edges were added to the topology in.
    """
    topology = Topology()
    node_a = NodeId()
    node_b = NodeId()
    for node in (node_a, node_b):
        topology.add_node(node)

    # First cable: en6 on A <-> en7 on B. Second cable: en7 on A <-> en6 on B.
    # Because the names are crossed, taking "the first edge in each direction"
    # could pair two interfaces that do not share a cable.
    en6, en7 = "rdma_en6", "rdma_en7"
    a_side_cable_1 = RDMAConnection(source_rdma_iface=en6, sink_rdma_iface=en7)
    a_side_cable_2 = RDMAConnection(source_rdma_iface=en7, sink_rdma_iface=en6)
    b_side_cable_1 = RDMAConnection(source_rdma_iface=en7, sink_rdma_iface=en6)
    b_side_cable_2 = RDMAConnection(source_rdma_iface=en6, sink_rdma_iface=en7)

    # B's edges are registered in the reverse cable order from A's on purpose.
    edges_in_insertion_order = (
        (node_a, node_b, a_side_cable_1),
        (node_a, node_b, a_side_cable_2),
        (node_b, node_a, b_side_cable_2),
        (node_b, node_a, b_side_cable_1),
    )
    for source, sink, edge in edges_in_insertion_order:
        topology.add_connection(Connection(source=source, sink=sink, edge=edge))

    matrix = get_mlx_jaccl_devices_matrix([node_a, node_b], topology)

    # Diagonal cells are empty; each off-diagonal cell lists both cables.
    assert matrix[0][0] == []
    assert matrix[1][1] == []
    assert len(matrix[0][1]) == 2
    assert len(matrix[1][0]) == 2

    # Zipping the two cells by index must yield the endpoints of real cables.
    rails = sorted(zip(matrix[0][1], matrix[1][0], strict=True))
    assert rails == [("rdma_en6", "rdma_en7"), ("rdma_en7", "rdma_en6")]


def test_jaccl_devices_matrix_single_link_remains_single_entry() -> None:
    """Check that one link per direction gives one interface per matrix cell."""
    topology = Topology()
    node_a = NodeId()
    node_b = NodeId()
    topology.add_node(node_a)
    topology.add_node(node_b)

    # One RDMA edge in each direction between the two nodes.
    # NOTE(review): create_rdma_connection(6) presumably names both ends
    # "rdma_en6", as the expected matrix below implies -- confirm in conftest.
    for source, sink in ((node_a, node_b), (node_b, node_a)):
        topology.add_connection(
            Connection(source=source, sink=sink, edge=create_rdma_connection(6))
        )

    matrix = get_mlx_jaccl_devices_matrix([node_a, node_b], topology)

    expected = [
        [[], ["rdma_en6"]],
        [["rdma_en6"], []],
    ]
    assert matrix == expected


def _build_three_node_rdma_topology() -> tuple[
Topology, NodeId, NodeId, NodeId, dict[NodeId, NodeNetworkInfo]
]:
Expand Down
5 changes: 4 additions & 1 deletion src/exo/shared/types/worker/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class MlxRingInstance(BaseInstance):


class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
# jaccl_devices[i][j] lists the RDMA interface names on rank i that
# connect to rank j, one entry per physical link. Entry k of [i][j] and
# entry k of [j][i] are the two ends of the same link.
jaccl_devices: list[list[list[str]]]
jaccl_coordinators: dict[NodeId, str]


Expand Down
20 changes: 17 additions & 3 deletions src/exo/worker/engines/mlx/utils_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ def mlx_distributed_init(
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
):
assert all(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
assert all(jaccl_devices[i][i] == [] for i in range(len(jaccl_devices)))
# Use RDMA connectivity matrix
jaccl_devices_json = json.dumps(jaccl_devices)

Expand All @@ -140,6 +138,22 @@ def mlx_distributed_init(
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator

# The default jaccl mesh backend only ever uses the first link
# between two ranks; the ring backend stripes traffic across
# every link. Prefer the ring whenever any pair of ranks has
# more than one physical link so the extra bandwidth is used.
max_links = max(
(len(cell) for row in jaccl_devices for cell in row),
default=0,
)
if max_links > 1:
os.environ["MLX_JACCL_RING"] = "1"
logger.info(
f"rank {rank} MLX_JACCL_RING=1 "
f"(up to {max_links} links per peer)"
)

group = mx.distributed.init(backend="jaccl", strict=True)

logger.info(f"Rank {rank} mlx distributed initialization complete")
Expand Down