nccl_api_->winPut Error Using RMA Windows #523

@BenBrock

Should I expect the RMA window interface to work with the NCCLX backend (or any of the other backends), or is it still a work in progress?

I get the following error when calling win.put(...):

F0202 09:40:35.775741 120495 TorchCommWindowNCCLX.cpp:130] Check failed: nccl_api_->winPut( tensor.data_ptr(), tensor.numel(), torch_comm_->getNcclDataType(tensor), dstRank, targetOffsetNelems, win_, stream) == ncclSuccess (5 vs. 0)
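Each line in these logs starts with a glog-style prefix: a severity letter, the date as MMDD, the time, the thread ID, and the source file:line. The sketch below parses that prefix using the shape visible in this log. `parse_glog_line` and `GlogLine` are illustrative names and are not part of TorchComms or glog. The F severity marks a FATAL message. A failed glog CHECK aborts the process, which is consistent with rank 1 exiting with SIGABRT (exit code -6) in the full output.

```python
import re
from typing import NamedTuple, Optional

# Prefix shape as seen in this log:
#   <severity><MMDD> <time> <thread id> <file:line>] <message>
GLOG_PREFIX = re.compile(
    r"^(?P<severity>[IWEF])(?P<month>\d{2})(?P<day>\d{2}) "
    r"(?P<time>\S+) +(?P<thread>\d+) (?P<location>[^\]]+)\]"
    r"(?: (?P<message>.*))?$"
)

SEVERITY_NAMES = {"I": "INFO", "W": "WARNING", "E": "ERROR", "F": "FATAL"}


class GlogLine(NamedTuple):
    severity: str
    month: int
    day: int
    time: str      # kept as text: NCCLX-routed lines use a different time layout
    thread: int
    location: str  # "file:line" or "function:line"
    message: str


def parse_glog_line(line: str) -> Optional[GlogLine]:
    """Split a glog-style log line into fields; return None if it has no such prefix."""
    m = GLOG_PREFIX.match(line)
    if m is None:
        return None
    return GlogLine(
        severity=SEVERITY_NAMES[m["severity"]],
        month=int(m["month"]),
        day=int(m["day"]),
        time=m["time"],
        thread=int(m["thread"]),
        location=m["location"],
        message=m["message"] or "",
    )


line = ("F0202 09:40:35.775741 120495 TorchCommWindowNCCLX.cpp:130] "
        "Check failed: ... == ncclSuccess (5 vs. 0)")
print(parse_glog_line(line).severity)  # → FATAL
```

Filtering a full log for the E and F severities is a quick way to locate the failing call.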

Setting NCCL_DEBUG=WARN surfaces the underlying NCCL error: NCCL ERROR Invalid window handle in ncclPut.
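The `(5 vs. 0)` at the end of the CHECK failure compares the returned status with `ncclSuccess`. Decoding it against upstream NCCL's `ncclResult_t` enum (declared in `nccl.h`) gives `ncclInvalidUsage`, which is consistent with the "Invalid window handle" message and the `rma.cc:81 -> 5` warning. The mapping below assumes NCCLX keeps upstream's numbering. `NcclResult` and `describe_check_failure` are illustrative names.

```python
from enum import IntEnum


class NcclResult(IntEnum):
    """Mirror of upstream NCCL's ncclResult_t values (assumed unchanged in NCCLX)."""
    ncclSuccess = 0
    ncclUnhandledCudaError = 1
    ncclSystemError = 2
    ncclInternalError = 3
    ncclInvalidArgument = 4
    ncclInvalidUsage = 5
    ncclRemoteError = 6
    ncclInProgress = 7


def describe_check_failure(actual: int, expected: int = 0) -> str:
    """Render the '(A vs. B)' pair from a failed '== ncclSuccess' CHECK by name."""
    return f"{NcclResult(actual).name} vs. {NcclResult(expected).name}"


print(describe_check_failure(5))  # → ncclInvalidUsage vs. ncclSuccess
```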

I installed TorchComms from the pip nightly wheels with CUDA 12.9; building from source gives the same result. I'm running on a system with four H200s fully connected by NVLink.

Note: I opened a separate issue, #410, about which interface should be used to create an RMA window.

Here's my minimal program:

#!/usr/bin/env python3
import os
import torch
import torchcomms

def main():
    # Initialize TorchComm with NCCLX backend
    device = torch.device("cuda")
    comm = torchcomms.new_comm("ncclx", device, name="main_comm")

    # Get rank and world size
    rank = comm.get_rank()
    world_size = comm.get_size()

    # Calculate device ID
    num_devices = torch.cuda.device_count()
    device_id = rank % num_devices
    target_device = torch.device(f"cuda:{device_id}")

    print(f"Rank {rank}/{world_size}: Running on device {device_id}")

    size = 1024
    dtype = torch.float32

    allocator = torchcomms.get_mem_allocator(comm.get_backend())

    pool = torch.cuda.MemPool(allocator)

    with torch.cuda.use_mem_pool(pool):
        win_buf = torch.ones(
            [size], dtype=dtype, device=device
        )

    print(comm.get_backend())
    print(allocator)

    print(win_buf)

    comm.barrier(False)

    win = comm.new_window()
    win.tensor_register(win_buf)
    comm.barrier(False)

    # Program

    if rank == 1:
        send_data = torch.zeros((10,), dtype=dtype, device=device)
        win.put(send_data, dst_rank=0, target_offset_nelems=0, async_op=False)

    comm.barrier(False)

    if rank == 0:
        print(win_buf)

    comm.barrier(False)
    win.tensor_deregister()

    # Cleanup
    comm.finalize()

if __name__ == "__main__":
    main()
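For reference, if the put succeeded, rank 0's final print should show the first 10 elements of `win_buf` overwritten with zeros and the remaining 1014 still equal to one. The pure-Python model below spells that out without a GPU. `simulate_put` is a hypothetical helper that models only the data movement, not the synchronization provided by the barriers. It assumes `target_offset_nelems` is an element offset rather than a byte offset, as the parameter name suggests.

```python
def simulate_put(target: list[float], src: list[float],
                 target_offset_nelems: int) -> list[float]:
    """CPU-only model of a one-sided put: copy src into target at an element offset."""
    end = target_offset_nelems + len(src)
    if target_offset_nelems < 0 or end > len(target):
        raise IndexError("put would write outside the registered window")
    result = list(target)  # leave the caller's buffer untouched
    result[target_offset_nelems:end] = src
    return result


# Mirror the repro: rank 0 registers 1024 ones; rank 1 puts 10 zeros at offset 0.
window_on_rank0 = [1.0] * 1024
after = simulate_put(window_on_rank0, [0.0] * 10, target_offset_nelems=0)
print(after[:12])
# → [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]
```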

Here is the full NCCL_DEBUG=WARN output:

(torchcomms-pip-2025-02-02) bash-5.2$ NCCL_DEBUG=WARN torchrun --nproc_per_node=2 ./put_01.py
W0202 09:44:44.157000 120687 site-packages/torch/distributed/run.py:851]
W0202 09:44:44.157000 120687 site-packages/torch/distributed/run.py:851] *****************************************
W0202 09:44:44.157000 120687 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0202 09:44:44.157000 120687 site-packages/torch/distributed/run.py:851] *****************************************
WARNING: Logging before InitGoogleLogging() is written to STDERR
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0202 09:44:46.515873 120755 TorchCommFactory.cpp:164] [TC] Backend ncclx is registered
I0202 09:44:46.515873 120756 TorchCommFactory.cpp:164] [TC] Backend ncclx is registered
I0202 09:44:46.515902 120755 TorchCommFactory.cpp:164] [TC] Backend dummy is registered
I0202 09:44:46.515903 120756 TorchCommFactory.cpp:164] [TC] Backend dummy is registered
I0202 09:44:46.515919 120756 TorchCommNCCLXBootstrap.cpp:60] [TC] TORCHCOMM_NCCLX_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD not set, defaulting to auto
I0202 09:44:46.515923 120755 TorchCommNCCLXBootstrap.cpp:60] [TC] TORCHCOMM_NCCLX_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD not set, defaulting to auto
I0202 09:44:46.782441 120756 TorchCommNCCLXBootstrap.cpp:83] [TC] Found 2 CUDA devices
I0202 09:44:46.782469 120756 TorchCommNCCLXBootstrap.cpp:86] [TC] User did not provide device ID; using device cuda:1
I0202 09:44:46.794304 120755 TorchCommNCCLXBootstrap.cpp:83] [TC] Found 2 CUDA devices
I0202 09:44:46.794325 120755 TorchCommNCCLXBootstrap.cpp:86] [TC] User did not provide device ID; using device cuda:0
[E202 09:44:47.148137560 TCPStore.cpp:277] [c10d] The server socket on 29500 has failed to bind. TORCHELASTIC_USE_AGENT_STORE is enabled so ignoring the error.
I0202 09:44:47.079793 120755 CudaWrap.cc:267] cudaDriverVersion 13000
I0202 09:44:47.079841 120755 CudaWrap.cc:269] CUDART_VERSION 12090
I0202 17:44:47.079884548.079884 120755 showVersion:653] pcl-kini02:120755:120755 [0][main] NCCL INFO NCCL version 2.28.9x-git-107bbda+cuda12.9
I0202 09:44:47.082183 120756 CudaWrap.cc:267] cudaDriverVersion 13000
I0202 09:44:47.082224 120756 CudaWrap.cc:269] CUDART_VERSION 12090
W0202 17:44:47.289372953.289372 120755 DataSink.h:26] pcl-kini02:120755:120755 [0][main] NCCL WARN Empty sink for dataset: nccl_profiler_slow_rank. No logging will be done.
W0202 17:44:47.289491478.289491 120755 DataSink.h:26] pcl-kini02:120755:120755 [0][main] NCCL WARN Empty sink for dataset: nccl_profiler_algo. No logging will be done.
W0202 17:44:47.289638639.289638 120755 DataTable.cc:40] pcl-kini02:120755:120755 [0][main] NCCL WARN Failed to create scuba file: filesystem error: cannot create directories: Permission denied [/logs]
W0202 17:44:47.289708714.289708 120755 DataTable.cc:40] pcl-kini02:120755:120755 [0][main] NCCL WARN Failed to create scuba file: filesystem error: cannot create directories: Permission denied [/logs]
W0202 17:44:47.289803979.289803 120755 DataTableWrapper.cc:89] pcl-kini02:120755:120755 [0][main] NCCL WARN Could not find table name: nccl_structured_logging
W0202 17:44:47.289803979.289803 120755 DataTableWrapper.cc:89] pcl-kini02:120755:120755 [0][main] NCCL WARN
W0202 17:44:47.289833013.289833 120755 ncclIbInit:852] pcl-kini02:120755:120755 [0][main] NCCL WARN /__w/torchcomms/torchcomms/meta-pytorch/torchcomms/comms/ncclx/v2_28/src/transport/net_ib.cc:852 -> 3
W0202 17:44:47.293188688.293188 120756 DataSink.h:26] pcl-kini02:120756:120756 [1][main] NCCL WARN Empty sink for dataset: nccl_profiler_slow_rank. No logging will be done.
W0202 17:44:47.293288186.293288 120756 DataSink.h:26] pcl-kini02:120756:120756 [1][main] NCCL WARN Empty sink for dataset: nccl_profiler_algo. No logging will be done.
W0202 17:44:47.293428325.293428 120756 DataTable.cc:40] pcl-kini02:120756:120756 [1][main] NCCL WARN Failed to create scuba file: filesystem error: cannot create directories: Permission denied [/logs]
W0202 17:44:47.293485834.293485 120756 DataTable.cc:40] pcl-kini02:120756:120756 [1][main] NCCL WARN Failed to create scuba file: filesystem error: cannot create directories: Permission denied [/logs]
W0202 17:44:47.293589336.293589 120756 DataTableWrapper.cc:89] pcl-kini02:120756:120756 [1][main] NCCL WARN Could not find table name: nccl_structured_logging
W0202 17:44:47.293589336.293589 120756 DataTableWrapper.cc:89] pcl-kini02:120756:120756 [1][main] NCCL WARN
W0202 17:44:47.293627446.293627 120756 ncclIbInit:852] pcl-kini02:120756:120756 [1][main] NCCL WARN /__w/torchcomms/torchcomms/meta-pytorch/torchcomms/comms/ncclx/v2_28/src/transport/net_ib.cc:852 -> 3
W0202 17:44:47.551827834.551827 120755 DataTableWrapper.cc:89] pcl-kini02:120755:120755 [0][main] NCCL WARN Could not find table name: nccl_memory_logging
W0202 17:44:47.551827834.551827 120755 DataTableWrapper.cc:89] pcl-kini02:120755:120755 [0][main] NCCL WARN
I0202 09:44:47.556586 120755 CommStateX.cc:297] CommStateX: set rankTopology with system
I0202 09:44:47.556583 120756 CommStateX.cc:297] CommStateX: set rankTopology with system
Rank 0/2: Running on device 0
Rank 1/2: Running on device 1
ncclx
<torch._C._CUDAPluggableAllocator object at 0x147ba4dcd230>
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
ncclx
<torch._C._CUDAPluggableAllocator object at 0x15429282d2f0>
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:1')
E0202 17:44:48.003116073.003116 120756 rma.cc:19] pcl-kini02:120756:120756 [1][main] NCCL ERROR Invalid window handle in ncclPut
W0202 17:44:48.003340242.003340 120756 ncclPut:81] pcl-kini02:120756:120756 [1][main] NCCL WARN /__w/torchcomms/torchcomms/meta-pytorch/torchcomms/comms/ncclx/v2_28/meta/rma/rma.cc:81 -> 5
F0202 09:44:48.003360 120756 TorchCommWindowNCCLX.cpp:130] Check failed: nccl_api_->winPut( tensor.data_ptr(), tensor.numel(), torch_comm_->getNcclDataType(tensor), dstRank, targetOffsetNelems, win_, stream) == ncclSuccess (5 vs. 0)
*** Check failure stack trace: ***
    @     0x15439ca88652  google::LogMessage::Fail()
    @     0x15439ca885b0  google::LogMessage::SendToLog()
    @     0x15439ca87f07  google::LogMessage::Flush()
    @     0x15439ca8b3fc  google::LogMessageFatal::~LogMessageFatal()
    @     0x15427e92819f  torch::comms::TorchCommWindowNCCLX::put()
    @     0x1543a28994e8  _ZNO8pybind116detail15argument_loaderIJRN5torch5comms15TorchCommWindowERKN2at6TensorEimbSt8optionalISt13unordered_mapINSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEESH_St4hashISH_ESt8equal_toISH_ESaISt4pairIKSH_SH_EEEESA_INSt6chrono8durationIlSt5ratioILl1ELl1000EEEEEEE4callIN3c1013intrusive_ptrINS3_9TorchWorkENS10_6detail34intrusive_target_default_null_typeIS12_EEEENS_18gil_scoped_releaseERZL20pybind11_init__commsRNS_7module_EEUlS5_S9_imbSR_SX_E_EENSt9enable_ifIXntsrSt7is_voidIT_E5valueES1E_E4typeEOT1_.isra.0
    @     0x1543a289be73  _ZZN8pybind1112cpp_function10initializeIZL20pybind11_init__commsRNS_7module_EEUlRN5torch5comms15TorchCommWindowERKN2at6TensorEimbSt8optionalISt13unordered_mapINSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEESJ_St4hashISJ_ESt8equal_toISJ_ESaISt4pairIKSJ_SJ_EEEESC_INSt6chrono8durationIlSt5ratioILl1ELl1000EEEEEE_N3c1013intrusive_ptrINS5_9TorchWorkENS11_6detail34intrusive_target_default_null_typeIS13_EEEEJS7_SB_imbST_SZ_EJNS_4nameENS_9is_methodENS_7siblingEA1071_cNS_3argES1C_S1C_S1C_NS_5arg_vES1D_NS_10call_guardIJNS_18gil_scoped_releaseEEEEEEEvOT_PFT0_DpT1_EDpRKT2_ENUlRNS_6detail13function_callEE_4_FUNES1U_
    @     0x1543a28bc36d  pybind11::cpp_function::dispatcher()
    @     0x55a30a252468  cfunction_call
    @     0x55a30a22923c  _PyObject_MakeTpCall.localalias
    @     0x55a30a233581  _PyEval_EvalFrameDefault
    @     0x55a30a2eb93f  PyEval_EvalCode
    @     0x55a30a3276ea  run_eval_code_obj
    @     0x55a30a322215  run_mod
    @     0x55a30a31f220  pyrun_file
    @     0x55a30a31eee6  _PyRun_SimpleFileObject.localalias
    @     0x55a30a31ec14  _PyRun_AnyFileObject.localalias
    @     0x55a30a31b7ae  Py_RunMain.localalias
    @     0x55a30a2d3877  Py_BytesMain
    @     0x1543e9f4230e  __libc_start_call_main
    @     0x1543e9f423c9  __libc_start_main_alias_2
    @     0x55a30a2d36c7  (unknown)
    @              (nil)  (unknown)
W0202 09:44:52.370000 120687 site-packages/torch/distributed/elastic/multiprocessing/api.py:1010] Sending process 120755 closing signal SIGTERM
*** Aborted at 1770054292 (unix time) try "date -d @1770054292" if you are using GNU date ***
PC: @                0x0 (unknown)
*** SIGTERM (@0xb901fc0001d76f) received by PID 120755 (TID 0x147cfbe9f740) from PID 120687; stack trace: ***
    @     0x147cfbee2b40 (unknown)
    @     0x147caf061274 (unknown)
    @     0x147caee1f183 (unknown)
    @     0x147caee5ddeb (unknown)
    @     0x147cafb9c8f7 (unknown)
    @     0x147cafb9cd45 (unknown)
    @     0x147caee2b9f4 (unknown)
    @     0x147caeef1fb3 (unknown)
    @     0x147cafb51765 (unknown)
    @     0x147caefa8027 (unknown)
    @     0x147caef9ef10 cuStreamSynchronize
    @     0x147ced8128b3 (unknown)
    @     0x147ced87ce70 cudaStreamSynchronize
    @     0x147bb54d39b5 at::native::nonzero_cuda_out_impl<>()
    @     0x147bb5487c06 at::native::nonzero_out_cuda()
    @     0x147bb5487f02 at::native::nonzero_cuda()
    @     0x147bb673511d c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x147c053db94d at::_ops::nonzero::call()
    @     0x147c04dbc7cc at::native::make_info()
    @     0x147c04db77b3 at::meta::structured_index_Tensor::meta()
    @     0x147bb6897ad4 at::(anonymous namespace)::wrapper_CUDA_index_out_Tensor_out()
    @     0x147bb6b0f326 at::native::masked_select_out_cuda_impl()
    @     0x147bb6b0f95e at::native::masked_select_cuda()
    @     0x147bb6747530 c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x147c056f2538 at::_ops::masked_select::redispatch()
    @     0x147c0809717b torch::autograd::VariableType::(anonymous namespace)::masked_select()
    @     0x147c08097572 c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x147c05786231 at::_ops::masked_select::call()
    @     0x147c187f63d2 torch::autograd::THPVariable_masked_select()
    @     0x55b2263e7468 cfunction_call
    @     0x55b2263be23c _PyObject_MakeTpCall.localalias
    @     0x55b2263c8581 _PyEval_EvalFrameDefault
E0202 09:44:52.584000 120687 site-packages/torch/distributed/elastic/multiprocessing/api.py:984] failed (exitcode: -6) local_rank: 1 (pid: 120756) of binary: /data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/bin/python3.12
Traceback (most recent call last):
  File "/data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/lib/python3.12/site-packages/torch/distributed/run.py", line 990, in main
    run(args)
  File "/data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/lib/python3.12/site-packages/torch/distributed/run.py", line 981, in run
    elastic_launch(
  File "/data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/nfs_home/bbrock/.miniforge3/envs/torchcomms-pip-2025-02-02/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 317, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
========================================================
./put_01.py FAILED
--------------------------------------------------------
Failures:
[1]:
  time      : 2026-02-02_09:44:52
  host      : pcl-kini02.sc.intel.com
  rank      : 0 (local_rank: 0)
  exitcode  : -15 (pid: 120755)
  error_file: <N/A>
  traceback : Signal 15 (SIGTERM) received by PID 120755
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-02-02_09:44:52
  host      : pcl-kini02.sc.intel.com
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 120756)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 120756
========================================================
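The log mixes two clocks. Lines from TorchComms and torchrun are stamped around 09:44, while lines routed through NCCLX's logger read 17:44. Converting the abort time printed in the SIGTERM trace (a Python equivalent of the suggested `date -d @1770054292`) shows that both refer to the same moment. The 8-hour gap is consistent with NCCLX stamping in UTC and the other loggers using a UTC-8 local clock. The UTC-8 zone is an assumption inferred from that offset, not something the log states.

```python
from datetime import datetime, timedelta, timezone

# Unix time printed by the SIGTERM handler: "*** Aborted at 1770054292 (unix time) ..."
aborted_at = 1770054292

utc = datetime.fromtimestamp(aborted_at, tz=timezone.utc)
# Assumption: the host's local zone is UTC-8 (inferred from the 09:44 vs 17:44 stamps).
local = utc.astimezone(timezone(timedelta(hours=-8)))

print(utc.isoformat())    # → 2026-02-02T17:44:52+00:00
print(local.isoformat())  # → 2026-02-02T09:44:52-08:00
```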
