Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions cuda_core/tests/example_tests/test_basic_examples.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# If we have subcategories of examples in the future, this file can be split along those lines

import glob
import os
from pathlib import Path

import pytest
from cuda.core import Device

from .utils import run_example

samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
sample_files = glob.glob(samples_path + "**/*.py", recursive=True)
# "/" below is pathlib's path-join operator, not division: it navigates into the "examples" directory.
EXAMPLES_DIR = Path(__file__).resolve().parents[2] / "examples"

# Recursively glob for example scripts in the examples directory and sort them for
# deterministic test runs. Relative paths give cleaner output when tests fail.
SAMPLE_FILES = sorted([str(p.relative_to(EXAMPLES_DIR)) for p in EXAMPLES_DIR.glob("**/*.py")])

@pytest.mark.parametrize("example", sample_files)

@pytest.mark.parametrize("example_rel_path", SAMPLE_FILES)
class TestExamples:
def test_example(self, example, deinit_cuda):
run_example(samples_path, example)
if Device().device_id != 0:
Device(0).set_current()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should kick off CI and check if removing this line is OK. IIRC on a multi-GPU system it would fail without this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I think I get this, let me know if I mess up:

  • each thread gets only a single reference to any one Device (hence thread-local singleton pattern)
  • a thread can reference multiple Devices
  • a Device can have multiple CUDA Contexts but a Context can only belong to a single GPU Device
  • Contexts on the same Device are mutually exclusive
  • the driver manages the context stack for a given thread

So, in a multi-GPU example, the driver recruits n devices using a given thread, then runs the kernel and calls deinit_cuda(), popping the context off the context stack.

The Problem

The driver doesn't reset the current_device to 0 when popping multiple shared (cudaDeviceEnablePeerAccess) Device contexts. As a result, when a program asks for a new Device, the driver returns the nth device instead of the 0th.

A Possible Solution

Redundantly set Device(0) as the current device prior to running the example. If the prior example was multi-GPU, we are now back to Device 0; otherwise, the redundant call does nothing.

def test_example(self, example_rel_path: str, deinit_cuda) -> None:
    from cuda.core import Device

    Device(0).set_current()
    run_example(str(EXAMPLES_DIR), example_rel_path)

# deinit_cuda is defined in conftest.py and pops the cuda context automatically.
def test_example(self, example_rel_path: str, deinit_cuda) -> None:
    """Run one example script, making device 0 current beforehand.

    Args:
        example_rel_path: Path of the example script, relative to
            ``EXAMPLES_DIR``.
        deinit_cuda: Fixture defined in conftest.py; requested only for
            its effect of popping the CUDA context.
    """
    from cuda.core import Device

    # A preceding multi-GPU example may have left a different device
    # current, so device 0 is (redundantly) made current every time.
    primary_device = Device(0)
    primary_device.set_current()

    examples_root = str(EXAMPLES_DIR)
    run_example(examples_root, example_rel_path)
51 changes: 32 additions & 19 deletions cuda_core/tests/example_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import gc
import os
import importlib.util
import sys
from pathlib import Path

import pytest

Expand All @@ -12,24 +13,34 @@ class SampleTestError(Exception):
pass


def parse_python_script(filepath):
    """Return the source text of the Python script at ``filepath``.

    Args:
        filepath: Location of the script, given as a ``str`` or any
            ``os.PathLike`` object such as ``pathlib.Path``.

    Returns:
        The entire contents of the file, decoded as UTF-8.

    Raises:
        ValueError: If the path does not end with ``.py``.
    """
    import os

    # Normalise path-like objects to their string form so the suffix check
    # below works for pathlib.Path as well as plain strings.
    path = os.fspath(filepath)
    if not path.endswith(".py"):
        raise ValueError(f"{filepath} not supported")
    with open(path, encoding="utf-8") as f:
        return f.read()
def run_example(parent_dir: str, rel_path_to_example: str, env=None) -> None:
fullpath = Path(parent_dir) / rel_path_to_example
module_name = fullpath.stem

old_sys_path = sys.path.copy()
old_argv = sys.argv

def run_example(samples_path, filename, env=None):
fullpath = os.path.join(samples_path, filename)
script = parse_python_script(fullpath)
try:
old_argv = sys.argv
sys.argv = [fullpath]
old_sys_path = sys.path.copy()
sys.path.append(samples_path)
# TODO: Refactor the examples to give them a common callable `main()` to avoid needing to use exec here?
exec(script, env if env else {}) # noqa: S102
sys.path.append(parent_dir)
sys.argv = [str(fullpath)]

# Collect metadata for file 'module_name' located at 'fullpath'.
spec = importlib.util.spec_from_file_location(module_name, fullpath)

if spec is None or spec.loader is None:
raise ImportError(f"Failed to load spec for {rel_path_to_example}")

# Otherwise convert the spec to a module, then run the module.
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module

# This runs top-level code.
spec.loader.exec_module(module)

# If the module has a main() function, call it.
if hasattr(module, "main"):
module.main()

except ImportError as e:
# for samples requiring any of optional dependencies
for m in ("cupy", "torch"):
Expand All @@ -40,14 +51,16 @@ def run_example(samples_path, filename, env=None):
raise
except SystemExit:
# for samples that early return due to any missing requirements
pytest.skip(f"skip {filename}")
pytest.skip(f"skip {rel_path_to_example}")
except Exception as e:
msg = "\n"
msg += f"Got error ({filename}):\n"
msg += f"Got error ({rel_path_to_example}):\n"
msg += str(e)
raise SampleTestError(msg) from e
finally:
sys.path = old_sys_path
sys.argv = old_argv

# further reduce the memory watermark
sys.modules.pop(module_name, None)
gc.collect()