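The line `assert id(state_dict["optim_sates"]) != id(state_dict_to_recv["optim_sates"])` in `test_send_recv_state_dict()` is failing, which probably means that `recv_state_dict` is not creating new tensor objects but is instead reusing the same tensors. As currently written, it copies the received data into the existing tensors from `og_state_dict` with `tensor.copy_(data)` and then passes those same tensor objects to `_load_sendable_state_dict`.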
One possible fix, then, would be to create new tensors in `recv_state_dict()`. Something like the following could work — what do you think?
def recv_state_dict(pg: ProcessGroup, src_rank: int, og_state_dict: dict) -> dict:
    """Receive a state dict sent by ``src_rank`` over ``pg``.

    The serialized object is received first, as a size message followed by
    the bytes. Then one message per tensor is received into a freshly
    allocated CPU buffer. The received data is NOT copied into the tensors
    of ``og_state_dict``. The returned state dict is built from the new
    receive buffers instead, so its tensors are distinct objects from those
    in ``og_state_dict``.

    Args:
        pg: Process group used for the point-to-point receives.
        src_rank: Rank that sends the state dict.
        og_state_dict: Local state dict whose tensors determine the shapes
            and dtypes of the receive buffers (via ``torch.empty_like``).

    Returns:
        The state dict produced by ``_load_sendable_state_dict`` from the
        received object and the received tensors.
    """
    # Receive the size of the serialized object.
    size = torch.LongTensor(1)
    pg.recv([size], src_rank, 0).wait()

    # Receive the serialized object itself and deserialize it.
    object_tensor = torch.empty(size.item(), dtype=torch.uint8)
    pg.recv([object_tensor], src_rank, 0).wait()
    state_dict = _tensor_to_object(object_tensor, size)

    _, tensors = _get_sendable_state_dict(og_state_dict)

    # Post every tensor receive before waiting on any of them. Tag ``i``
    # presumably matches the tag the sender uses for the i-th tensor.
    jobs = []
    received = []
    for i, tensor in enumerate(tensors):
        buffer = tensor.to_local() if isinstance(tensor, DTensor) else tensor
        # empty_like allocates a brand-new tensor that nothing else
        # references, so it is already distinct from the one in
        # og_state_dict; no extra clone() is needed.
        data = torch.empty_like(buffer, device="cpu")
        jobs.append(pg.recv([data], src_rank, i))
        received.append(data)

    for job in jobs:
        job.wait()

    # NOTE(review): the tensors handed over here are plain CPU tensors. For
    # DTensor inputs, this is the received local tensor, not re-wrapped as a
    # DTensor. The previous copy_()-based version wrote into the original
    # tensors and so kept their device/type. Confirm that callers and
    # _load_sendable_state_dict accept this.
    return _load_sendable_state_dict(received, state_dict)
The line `assert id(state_dict["optim_sates"]) != id(state_dict_to_recv["optim_sates"])` in `test_send_recv_state_dict()` is failing, which probably means that `recv_state_dict` is not creating new tensor objects but is instead reusing the same tensors. One possible fix, then, would be to create new tensors in `recv_state_dict()`. Something like the following could work — what do you think?

```diff
 def recv_state_dict(pg: ProcessGroup, src_rank: int, og_state_dict: dict) -> dict:
     size = torch.LongTensor(1)
     # Receive object sizes
     pg.recv([size], src_rank, 0).wait()
     # Tensor to receive serialized objects into.
     object_tensor = torch.empty(size.item(), dtype=torch.uint8)
     pg.recv([object_tensor], src_rank, 0).wait()
     state_dict = _tensor_to_object(object_tensor, size)
     _, tensors = _get_sendable_state_dict(og_state_dict)
     jobs = []
     datas = []
     for i, tensor in enumerate(tensors):
         buffer = tensor
         if isinstance(tensor, DTensor):
             buffer = tensor.to_local()
         data = torch.empty_like(buffer, device="cpu")
         jobs.append(pg.recv([data], src_rank, i))
         datas.append(data)
     for job in jobs:
         job.wait()
+    new_tensors = []
     for tensor, data in zip(tensors, datas):
         if isinstance(tensor, DTensor):
             tensor = tensor.to_local()
-        tensor.copy_(data)
+        new_tensor = data.clone()
+        new_tensors.append(new_tensor)
-    state_dict = _load_sendable_state_dict(tensors, state_dict)
+    state_dict = _load_sendable_state_dict(new_tensors, state_dict)
     return state_dict
```