From b9fa30766e500eeadc4fb3886bac544353c680f2 Mon Sep 17 00:00:00 2001 From: Yuze Ma Date: Sun, 25 May 2025 17:24:12 +0800 Subject: [PATCH] fix(torch): fix multi card training example --- advanced/pytorch-example/main.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py index 6e7f462..8a08549 100644 --- a/advanced/pytorch-example/main.py +++ b/advanced/pytorch-example/main.py @@ -3,6 +3,7 @@ import torch.optim as optim import torch.nn.functional as F import torch.distributed as dist +import os from torchvision import datasets, transforms from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler @@ -31,9 +32,20 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x, dim=1) -def train(rank, world_size): - print(f"Running on rank {rank}.") - dist.init_process_group("nccl", rank=rank, world_size=world_size) +def train(): + # Initialize process group + dist.init_process_group(backend="nccl") + + # Get local rank from environment variable + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Set device + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + + print(f"Running on rank {rank} (local_rank: {local_rank})") transform = transforms.Compose([ transforms.ToTensor(), @@ -43,15 +55,15 @@ def train(rank, world_size): sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) train_loader = DataLoader(dataset, batch_size=64, sampler=sampler) - model = MNISTModel().to(rank) - model = DDP(model, device_ids=[rank]) + model = MNISTModel().to(device) + model = DDP(model, device_ids=[local_rank]) optimizer = optim.Adam(model.parameters(), lr=0.001) model.train() for epoch in range(1, 11): sampler.set_epoch(epoch) for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(rank), target.to(rank) + data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) @@ -67,9 +79,5 @@ def train(rank, world_size): dist.destroy_process_group() -def main(): - world_size = torch.cuda.device_count() - torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True) - if __name__ == "__main__": - main() \ No newline at end of file + train() \ No newline at end of file