Skip to content

Fix NVSHMEM IBGDA duplicate-definition in host builds#581

Open
Gregory-Pereira wants to merge 1 commit intodeepseek-ai:mainfrom
Gregory-Pereira:fix/nvshmem-ibgda-multiple-definition
Open

Fix NVSHMEM IBGDA duplicate-definition in host builds#581
Gregory-Pereira wants to merge 1 commit intodeepseek-ai:mainfrom
Gregory-Pereira:fix/nvshmem-ibgda-multiple-definition

Conversation

@Gregory-Pereira
Copy link

Fix NVSHMEM IBGDA duplicate-definition in host builds

Summary

This PR fixes a linker error that occurs when building the DeepEP Python extension with NVSHMEM enabled.

Problem

csrc/kernels/configs.cuh includes NVSHMEM’s internal header:

device_host_transport/nvshmem_common_ibgda.h

That header conditionally defines the global symbol:

nvshmemi_ibgda_device_state_d

When the header is included from a host-compiled translation unit (e.g. deep_ep.cpp built with g++), the symbol is emitted as a definition. Since DeepEP also links against libnvshmem_device.a, which defines the same symbol, the build fails at link time with:

multiple definition of `nvshmemi_ibgda_device_state_d`

This can be reproduced when building DeepEP against newer NVSHMEM releases (e.g. v3.5.19) in environments where configs.cuh is included by both CUDA (.cu) and host (.cpp) sources.

Solution

Restrict inclusion of nvshmem_common_ibgda.h to NVCC compilation contexts only:

#if defined(__CUDACC__)
#include <device_host_transport/nvshmem_common_ibgda.h>
#endif

This ensures:

  • The NVSHMEM IBGDA header is only included during CUDA compilation
  • Host-compiled translation units (e.g. deep_ep.cpp built with g++) do not emit
    a duplicate definition of nvshmemi_ibgda_device_state_d
  • The extension links cleanly when libnvshmem_device.a is present

Why this approach

When included from host-compiled code, nvshmem_common_ibgda.h can emit a
definition of nvshmemi_ibgda_device_state_d, which collides with the definition
provided by libnvshmem_device.a at link time.

Restricting the include to NVCC compilation avoids this duplicate-definition
path and makes the build more robust across NVSHMEM versions and build
environments.

The change is minimal and does not alter existing CUDA behavior.

My build log hitting this

#38 15.06 Cloning into 'deepep'...
#38 15.64 + cd deepep
#38 15.64 + git fetch origin 92fe2deaec24bc92ebd9de276daa6ca9ed602ed4
#38 15.85 From https://github.com/deepseek-ai/DeepEP
#38 15.85  * branch            92fe2deaec24bc92ebd9de276daa6ca9ed602ed4 -> FETCH_HEAD
#38 15.85 + git checkout -q 92fe2deaec24bc92ebd9de276daa6ca9ed602ed4
#38 15.87 + uv build --wheel --no-build-isolation --out-dir /wheels
#38 15.90 Building wheel...
#38 18.11 W0208 00:54:30.289000 215 torch/utils/cpp_extension.py:117] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
#38 18.13 toml section missing PosixPath('pyproject.toml') does not contain a tool.setuptools_scm section
#38 18.14 Build summary:
#38 18.14  > Sources: ['csrc/deep_ep.cpp', 'csrc/kernels/runtime.cu', 'csrc/kernels/layout.cu', 'csrc/kernels/intranode.cu', 'csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']
#38 18.14  > Includes: ['csrc/', '/opt/nvshmem-v3.5.19-1/include']
#38 18.14  > Libraries: ['/opt/nvshmem-v3.5.19-1/lib']
#38 18.14  > Compilation flags: {'cxx': ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes', '-DDISABLE_AGGRESSIVE_PTX_INSTRS'], 'nvcc': ['-O3', '-Xcompiler', '-O3', '-rdc=true', '--ptxas-options=--register-usage-level=10', '-DDISABLE_AGGRESSIVE_PTX_INSTRS'], 'nvcc_dlink': ['-dlink', '-L/opt/nvshmem-v3.5.19-1/lib', '-lnvshmem_device']}
#38 18.14  > Link flags: ['-lcuda', '-l:libnvshmem_host.so', '-l:libnvshmem_device.a', '-Wl,-rpath,/opt/nvshmem-v3.5.19-1/lib']
#38 18.14  > Arch list: 9.0a;10.0+PTX
#38 18.14  > NVSHMEM path: /opt/nvshmem-v3.5.19-1
#38 18.14 
#38 18.14 running bdist_wheel
#38 18.15 running build
#38 18.15 running build_py
#38 18.16 creating build/lib.linux-x86_64-cpython-312/deep_ep
#38 18.16 copying deep_ep/utils.py -> build/lib.linux-x86_64-cpython-312/deep_ep
#38 18.16 copying deep_ep/buffer.py -> build/lib.linux-x86_64-cpython-312/deep_ep
#38 18.16 copying deep_ep/__init__.py -> build/lib.linux-x86_64-cpython-312/deep_ep
#38 18.16 running build_ext
#38 18.17 W0208 00:54:30.351000 215 torch/utils/cpp_extension.py:531] There are no c++ version bounds defined for CUDA version 12.9
#38 18.17 building 'deep_ep_cpp' extension
#38 18.17 creating /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc
#38 18.17 creating /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels
#38 45.99 [1/7] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/runtime.o.d -Icsrc/ -I/opt/nvshmem-v3.5.19-1/include -I/opt/vllm/lib/python3.12/site-packages/torch/include -I/opt/vllm/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/vllm/include -I/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/include/python3.12 -c -c /tmp/deepep/csrc/kernels/runtime.cu -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/runtime.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_90a,code=sm_90a -std=c++17
#38 55.93 [2/7] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/intranode.o.d -Icsrc/ -I/opt/nvshmem-v3.5.19-1/include -I/opt/vllm/lib/python3.12/site-packages/torch/include -I/opt/vllm/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/vllm/include -I/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/include/python3.12 -c -c /tmp/deepep/csrc/kernels/intranode.cu -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/intranode.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_90a,code=sm_90a -std=c++17
#38 74.68 [3/7] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/layout.o.d -Icsrc/ -I/opt/nvshmem-v3.5.19-1/include -I/opt/vllm/lib/python3.12/site-packages/torch/include -I/opt/vllm/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/vllm/include -I/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/include/python3.12 -c -c /tmp/deepep/csrc/kernels/layout.cu -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/layout.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_90a,code=sm_90a -std=c++17
#38 79.46 [4/7] c++ -MMD -MF /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o.d -pthread -fno-strict-overflow -Wsign-compare -Wunreachable-code -DNDEBUG -g -O3 -Wall -O3 -fPIC -fPIC -Icsrc/ -I/opt/nvshmem-v3.5.19-1/include -I/opt/vllm/lib/python3.12/site-packages/torch/include -I/opt/vllm/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/vllm/include -I/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/include/python3.12 -c -c /tmp/deepep/csrc/deep_ep.cpp -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o -O3 -Wno-deprecated-declarations -Wno-unused-variable -Wno-sign-compare -Wno-reorder -Wno-attributes -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -std=c++17
#38 85.46 [5/7] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode_ll.o.d -Icsrc/ -I/opt/nvshmem-v3.5.19-1/include -I/opt/vllm/lib/python3.12/site-packages/torch/include -I/opt/vllm/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/vllm/include -I/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/include/python3.12 -c -c /tmp/deepep/csrc/kernels/internode_ll.cu -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode_ll.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_90a,code=sm_90a -std=c++17
#38 153.7 [6/7] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode.o.d -Icsrc/ -I/opt/nvshmem-v3.5.19-1/include -I/opt/vllm/lib/python3.12/site-packages/torch/include -I/opt/vllm/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/vllm/include -I/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/include/python3.12 -c -c /tmp/deepep/csrc/kernels/internode.cu -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_90a,code=sm_90a -std=c++17
#38 157.2 [7/7] /usr/local/cuda/bin/nvcc /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode_ll.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/intranode.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/layout.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/runtime.o -o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/dlink.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -dlink -L/opt/nvshmem-v3.5.19-1/lib -lnvshmem_device -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=deep_ep_cpp -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_100,code=sm_100 -gencode=arch=compute_90a,code=sm_90a
#38 157.2 c++ -pthread -D__NVSHMEM_NUMBA_SUPPORT__ -shared -Wl,--exclude-libs,ALL -LModules/_hacl -D__NVSHMEM_NUMBA_SUPPORT__ /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode_ll.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/intranode.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/layout.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/runtime.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/dlink.o -L/opt/nvshmem-v3.5.19-1/lib -L/opt/vllm/lib/python3.12/site-packages/torch/lib -L/usr/local/cuda/lib64 -L/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-312/deep_ep_cpp.cpython-312-x86_64-linux-gnu.so -lcuda -l:libnvshmem_host.so -l:libnvshmem_device.a -Wl,-rpath,/opt/nvshmem-v3.5.19-1/lib
#38 157.7 /usr/bin/ld: /opt/nvshmem-v3.5.19-1/lib/libnvshmem_device.a(init_device.cu.o):(.bss+0x380): multiple definition of `nvshmemi_ibgda_device_state_d'; /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o:/opt/nvshmem-v3.5.19-1/include/device_host_transport/nvshmem_common_ibgda.h:351: first defined here
#38 157.7 collect2: error: ld returned 1 exit status
#38 157.7 error: command '/usr/bin/c++' failed with exit code 1
#38 158.0   × Failed to build `/tmp/deepep`
#38 158.0   ├─▶ The build backend returned an error
#38 158.0   ╰─▶ Call to `setuptools.build_meta:__legacy__.build_wheel` failed (exit
#38 158.0       status: 1)
#38 158.0       hint: This usually indicates a problem with the package or the build
#38 158.0       environment.
#38 ERROR: process "/bin/sh -c TARGETPLATFORM=${TARGETPLATFORM} /tmp/build-compiled-wheels.sh &&     rm -f /tmp/build-compiled-wheels.sh" did not complete successfully: exit code: 2
------
 > [builder 33/33] RUN --mount=type=cache,target=/root/.cache/uv     --mount=type=secret,id=aws_access_key_id     --mount=type=secret,id=aws_secret_access_key     TARGETPLATFORM=linux/amd64 /tmp/build-compiled-wheels.sh &&     rm -f /tmp/build-compiled-wheels.sh:
157.2 c++ -pthread -D__NVSHMEM_NUMBA_SUPPORT__ -shared -Wl,--exclude-libs,ALL -LModules/_hacl -D__NVSHMEM_NUMBA_SUPPORT__ /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/internode_ll.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/intranode.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/layout.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/kernels/runtime.o /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/dlink.o -L/opt/nvshmem-v3.5.19-1/lib -L/opt/vllm/lib/python3.12/site-packages/torch/lib -L/usr/local/cuda/lib64 -L/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-312/deep_ep_cpp.cpython-312-x86_64-linux-gnu.so -lcuda -l:libnvshmem_host.so -l:libnvshmem_device.a -Wl,-rpath,/opt/nvshmem-v3.5.19-1/lib
157.7 /usr/bin/ld: /opt/nvshmem-v3.5.19-1/lib/libnvshmem_device.a(init_device.cu.o):(.bss+0x380): multiple definition of `nvshmemi_ibgda_device_state_d'; /tmp/deepep/build/temp.linux-x86_64-cpython-312/csrc/deep_ep.o:/opt/nvshmem-v3.5.19-1/include/device_host_transport/nvshmem_common_ibgda.h:351: first defined here
157.7 collect2: error: ld returned 1 exit status
157.7 error: command '/usr/bin/c++' failed with exit code 1
158.0   × Failed to build `/tmp/deepep`
158.0   ├─▶ The build backend returned an error
158.0   ╰─▶ Call to `setuptools.build_meta:__legacy__.build_wheel` failed (exit
158.0       status: 1)
158.0       hint: This usually indicates a problem with the package or the build
158.0       environment.

cc @tlrmchlsmth @smarterclayton

…ymbol

Signed-off-by: greg pereira <grpereir@redhat.com>
@Gregory-Pereira
Copy link
Author

I see #574 has a better full implementation, if that gets accepted I will drop this minimal patch in favor of that

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant