send/recv state dict reuse the same tensor #245

Description

@3outeille

The line assert id(state_dict["optim_sates"]) != id(state_dict_to_recv["optim_sates"]) in test_send_recv_state_dict() is failing which probably means that recv_state_dict is not creating new tensor objects but is instead reusing the same tensors.

One possible fix would be to create a new tensor in recv_state_dict(). Something like this could work. What do you think?

def recv_state_dict(pg: ProcessGroup, src_rank: int, og_state_dict: dict) -> dict:
    size = torch.LongTensor(1)

    # Receive the size of the serialized objects.
    pg.recv([size], src_rank, 0).wait()
    # Tensor to receive the serialized objects into.
    object_tensor = torch.empty(size.item(), dtype=torch.uint8)

    pg.recv([object_tensor], src_rank, 0).wait()
    state_dict = _tensor_to_object(object_tensor, size)

    _, tensors = _get_sendable_state_dict(og_state_dict)

    jobs = []
    datas = []
    for i, tensor in enumerate(tensors):
        buffer = tensor
        if isinstance(tensor, DTensor):
            buffer = tensor.to_local()

        data = torch.empty_like(buffer, device="cpu")
        jobs.append(pg.recv([data], src_rank, i))
        datas.append(data)

    for job in jobs:
        job.wait()

    # Clone the received buffers instead of copy_()-ing them into the
    # original tensors, so the returned state dict holds new tensor objects
    # rather than aliasing og_state_dict.
    new_tensors = [data.clone() for data in datas]
    state_dict = _load_sendable_state_dict(new_tensors, state_dict)
    return state_dict
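The identity failure can be reproduced in isolation, without any process groups: `copy_()` writes into the existing tensor object (so the received state dict aliases the original), while `clone()` allocates a fresh tensor with the same values. A minimal sketch:

```python
import torch

# Stand-ins for an original state-dict tensor and a received buffer.
og = torch.zeros(3)
data = torch.tensor([1.0, 2.0, 3.0])

# copy_() mutates og in place and returns the same object, so any
# state dict built from it shares identity with the original one --
# this is what makes the id(...) != id(...) assertion fail.
received = og.copy_(data)
assert id(received) == id(og)

# clone() allocates a new tensor holding the same values, so the
# identities differ while the contents match.
fresh = data.clone()
assert id(fresh) != id(og)
assert torch.equal(fresh, data)
```
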
