diff --git a/.docker/Dockerfile.docs b/.docker/Dockerfile.docs index 6674178..17a86c5 100644 --- a/.docker/Dockerfile.docs +++ b/.docker/Dockerfile.docs @@ -1,9 +1,8 @@ -# Docs build/verification image: baleen (CPU-only) + MkDocs toolchain. +# Docs build/verification image: baleen + krill (CPU) + MkDocs toolchain. # Used to run `mkdocs build --strict` so mkdocstrings can import baleen. FROM python:3.11-slim ENV DEBIAN_FRONTEND=noninteractive -ENV BALEEN_NO_CUDA=1 RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential zlib1g-dev libhdf5-dev \ @@ -15,7 +14,11 @@ RUN groupadd -g ${GID} app && useradd -m -u ${UID} -g ${GID} app WORKDIR /app COPY . . -RUN pip install --no-cache-dir ".[docs]" +# krill (engine) is imported by baleen at import time; mkdocstrings needs it. +RUN pip install --no-cache-dir ".[docs]" \ + && pip install --no-cache-dir numpy scipy pyslow5 pyfastx \ + && pip install --no-cache-dir krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/simple/ USER app WORKDIR /work diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 6d9ba01..1fd9249 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - name: Checkout @@ -24,7 +24,12 @@ jobs: - name: Install and test run: | - BALEEN_NO_CUDA=1 pip install ".[test]" + pip install ".[test]" + # krill (DTW + eventalign engine) is not on PyPI — install the CPU + # wheel from the project index. Its runtime deps come from PyPI first. + pip install numpy scipy pyslow5 pyfastx + pip install krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/simple/ pytest build-cpu: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 15cb643..2dd6630 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -33,7 +33,13 @@ jobs: python-version: "3.11" - name: Install docs dependencies - run: BALEEN_NO_CUDA=1 pip install ".[docs]" + run: | + pip install ".[docs]" + # krill (engine) is not on PyPI; mkdocstrings imports baleen, which + # imports it. Install the CPU wheel from the project index. + pip install numpy scipy pyslow5 pyfastx + pip install krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/simple/ - name: Build site (strict) run: mkdocs build --strict diff --git a/CLAUDE.md b/CLAUDE.md index 4bbbf82..660eeb3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -5,16 +5,16 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Build and Test Commands ```bash -# Install package (CUDA auto-detected if nvcc available) +# Install package (pure Python — no C extension to build) pip install . -# Install CPU-only (skip CUDA compilation) -BALEEN_NO_CUDA=1 pip install . - -# Target specific GPU archs (comma-separated compute capabilities without dot) -BALEEN_CUDA_ARCHS=86,90 pip install . -# Or auto-detect installed GPU -BALEEN_CUDA_ARCHS=native pip install . +# The DTW + eventalign engine 'krill' is a required runtime dependency that is +# NOT on PyPI. Install it from the project index (GPU cu122 wheel, or CPU): +pip install krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/cu122/simple/ # GPU +pip install krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/simple/ # CPU +# (Or use a prebuilt baleen Docker image, which bundles krill + slow5tools.) # Run all tests pytest @@ -49,20 +49,19 @@ Conventional commits: `feat:`, `fix:`, `perf:`, `build:`, `bench:`, `ci:`, `refa ## Architecture Overview -Baleen is a CUDA-accelerated DTW (Dynamic Time Warping) and nanopore signal analysis pipeline for detecting RNA modifications by comparing native and IVT (in vitro transcribed) nanopore signals. +Baleen is a GPU-accelerated DTW (Dynamic Time Warping) and nanopore signal analysis pipeline for detecting RNA modifications by comparing native and IVT (in vitro transcribed) nanopore signals. The DTW and eventalign engine is provided by **krill**. ### Package Structure ``` baleen/ ├── __init__.py # Re-exports public API from eventalign -├── _cuda_dtw/ # CUDA DTW implementation with CPU fallback -│ └── __init__.py # Python wrapper (dtw_distance, dtw_pairwise, etc.) +├── _dtw.py # DTW shim delegating to krill (+ GPU memory helpers) └── eventalign/ # Main analysis pipeline ├── __init__.py # Public API exports ├── _pipeline.py # run_pipeline(), save/load_results() ├── _bam.py # BAM parsing, contig stats, filtering - ├── _f5c.py # f5c eventalign CLI wrapper + ├── _eventalign.py # krill eventalign wrapper (f5c-format TSV output) ├── _signal.py # Signal extraction and grouping by position ├── _probability.py # Modification probability algorithms ├── _hierarchical.py # Hierarchical Bayesian + HMM pipeline (V1→V2→V3) @@ -72,9 +71,9 @@ baleen/ ### Data Flow 1. **Input**: Native + IVT BAM/FASTQ/BLOW5 files + reference FASTA -2. **Event alignment**: f5c eventalign produces per-read signal tables per position +2. **Event alignment**: krill aligns each read's signal to its mapped reference subsequence (HMM-free, forced-dense) and emits an f5c-format per-position signal table 3. **Signal grouping**: Group signals by genomic position, find common positions -4. **DTW computation**: Pairwise DTW distance matrices per position (CUDA or tslearn fallback) +4. **DTW computation**: Pairwise DTW distance matrices per position (krill GPU kernel, CPU fallback) 5. **Modification calling**: Three-stage hierarchical pipeline: - V1: Empirical-Bayes null scoring with hierarchical shrinkage - V2: Anchored two-component mixture EM @@ -90,11 +89,14 @@ baleen/ ### DTW Backend Selection -The `_cuda_dtw` module auto-selects backend at import time: -- CUDA (GPU) if `_cuda_dtw` C extension compiled successfully -- CPU (tslearn) fallback otherwise +`baleen/_dtw.py` is a thin shim over krill's bundled DTW (same cuDTW++ kernel +the project previously vendored as the `_cuda_dtw` C extension). krill +auto-selects GPU when a device + GPU wheel are present, else CPU. -Use `use_cuda=True/False` to force backend, or `None` for auto-select. +Use `use_cuda=True/False` to force backend, or `None` for auto-select (mapped +to krill's `use_gpu`). The pure-Python GPU memory-planning helpers +(`estimate_gpu_memory`, `get_device_count`, `get_per_device_memory`) live in +the shim since krill does not expose them. ### Modification Probability Algorithms @@ -110,22 +112,27 @@ Three modes in `_hmm_training.py`: - **Semi-supervised**: Platt-scaling calibrator from labeled positions - **Supervised**: MLE transitions + KDE emissions from labeled trajectories -## CUDA Kernel Architecture +## DTW Engine (krill) + +The DTW kernels (GPU + CPU) live in the **krill** package, not in this repo. +krill ships the same cuDTW++ warp-shuffle kernel baleen previously vendored. +The GPU path is bit-identical to that legacy kernel (verified during the swap); +krill's CPU path resamples long signals to fixed buckets (GPU-consistent), +which differs from the old tslearn fallback only on CPU-only installs. -- **FP32 only** — `DTWDistance` template, always float. FP16 would break Pascal consumer GPUs (1/64 FP32 throughput). -- **Wavefront parallelism**: one thread per row of cost matrix, diagonal sweep. `blockDim.x = 1024` (max threads per block). Three rolling diagonals in shared memory (~12 KB). -- **One block per pair** for pairwise mode; grid.x = num_comparisons. Outer loop over reference sequences is serial. -- **Cost function**: squared Euclidean distance, `sqrt` only at the end. Path matrix = nullptr for pairwise (no memory waste). -- **No Sakoe-Chiba band** — a soft-band variant was tried and reverted because setting out-of-band cells to INF without reducing thread count/diagonals is pure overhead. A real band optimization requires skipping diagonals and sizing `blockDim.x` to `min(1024, 2*band_width+1)`. -- Source files: `dtw.hpp` (kernel), `dtw_api.cpp` (Python-C bridge), `multithreading.cpp` (CPU thread pool). +GPU vs CPU is decided by which krill wheel is installed (cu122 vs plain) plus +device presence. krill exposes `dtw_distance`, `dtw_pairwise`, +`dtw_pairwise_varlen`, `dtw_multi_position_pairwise`, `dtw_backend`, +`dtw_available`. ## External Dependencies -- **f5c**: External CLI tool for nanopore event alignment. Must be on PATH. -- **pysam**: BAM file parsing -- **tslearn**: CPU DTW fallback -- **scipy**: Statistical functions, optimization -- **numba** (optional): JIT-compiled HMM forward-backward kernel (`@njit(cache=True)`), kicks in when installed +- **krill**: DTW + eventalign engine (not on PyPI; install from the project + index — cu122 GPU wheel or plain CPU wheel). Required. +- **slow5tools**: CLI used to index BLOW5 (`slow5tools index`); must be on PATH. +- **pyslow5 / pyfastx / pysam**: BLOW5 signal, reference FASTA, and BAM access. +- **scipy**: Statistical functions, optimization. +- **numba** (optional): JIT-compiled HMM forward-backward kernel (`@njit(cache=True)`), kicks in when installed. # CLAUDE.md diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 1d94153..1cfa793 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,44 +1,34 @@ -# --- Build stage --- -FROM python:3.11-slim AS builder +# CPU production image: baleen + krill (CPU wheel) + slow5tools. +# No f5c, no CUDA build — krill is the DTW + eventalign engine and is pure +# Python to install (a prebuilt wheel from the project index). +FROM python:3.11-slim ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y --no-install-recommends \ - wget build-essential zlib1g-dev libhdf5-dev \ + wget ca-certificates zlib1g \ + && apt-cache search '^libhdf5-[0-9]' | head -1 | awk '{print $1}' \ + | xargs apt-get install -y --no-install-recommends \ && rm -rf /var/lib/apt/lists/* -# Install f5c v1.6 from pre-built binaries (CPU) -RUN wget -q "https://github.com/hasindu2008/f5c/releases/download/v1.6/f5c-v1.6-binaries.tar.gz" \ - && tar xf f5c-v1.6-binaries.tar.gz \ - && cp f5c-v1.6/f5c_x86_64_linux /usr/local/bin/f5c \ - && chmod +x /usr/local/bin/f5c \ - && rm -rf f5c-v1.6 f5c-v1.6-binaries.tar.gz +# slow5tools — baleen indexes BLOW5 via `slow5tools index` (pyslow5 needs .idx). +RUN wget -q "https://github.com/hasindu2008/slow5tools/releases/download/v1.3.0/slow5tools-v1.3.0-x86_64-linux-binaries.tar.gz" \ + && tar xf slow5tools-v1.3.0-x86_64-linux-binaries.tar.gz \ + && cp slow5tools-v1.3.0/slow5tools /usr/local/bin/slow5tools \ + && chmod +x /usr/local/bin/slow5tools \ + && rm -rf slow5tools-v1.3.0 slow5tools-v1.3.0-x86_64-linux-binaries.tar.gz -# Install baleen (CPU only); copy console_script to a known path +# baleen (pure Python now — no C extension to build). WORKDIR /app COPY . . -ENV BALEEN_NO_CUDA=1 -RUN pip install --no-cache-dir . \ - && BALEEN_BIN="$(which baleen 2>/dev/null)" \ - && if [ -z "$BALEEN_BIN" ]; then \ - printf '#!/bin/sh\nexec python3 -m baleen "$@"\n' > /usr/local/bin/baleen; \ - elif [ "$BALEEN_BIN" != "/usr/local/bin/baleen" ]; then \ - cp "$BALEEN_BIN" /usr/local/bin/baleen; \ - fi \ - && chmod +x /usr/local/bin/baleen - -# --- Runtime stage --- -FROM python:3.11-slim - -RUN apt-get update \ - && apt-get install -y --no-install-recommends zlib1g \ - && apt-cache search '^libhdf5-[0-9]' | head -1 | awk '{print $1}' \ - | xargs apt-get install -y --no-install-recommends \ - && rm -rf /var/lib/apt/lists/* - -COPY --from=builder /usr/local/bin/f5c /usr/local/bin/f5c -COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages -COPY --from=builder /usr/local/bin/baleen /usr/local/bin/baleen +RUN pip install --no-cache-dir . + +# krill engine — follow the package's strict rules: +# 1. runtime deps from PyPI first (NEVER `krill[...]`, NEVER --extra-index-url) +# 2. krill itself ONLY from the project index, --no-deps. CPU wheel here. +RUN pip install --no-cache-dir numpy scipy pyslow5 pyfastx \ + && pip install --no-cache-dir krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/simple/ WORKDIR /data ENTRYPOINT ["baleen"] diff --git a/Dockerfile.gpu b/Dockerfile.gpu index 969cae6..47e834d 100644 --- a/Dockerfile.gpu +++ b/Dockerfile.gpu @@ -1,87 +1,42 @@ -# --- Build stage --- -FROM nvidia/cuda:12.6.3-devel-ubuntu22.04 AS builder +# GPU production image: baleen + krill (cu122 GPU wheel) + slow5tools. +# No f5c, no nvcc/CUDA build — krill ships the GPU DTW kernel as a prebuilt +# wheel, so a CUDA *runtime* base (matching the cu122 wheel) is sufficient. +FROM nvidia/cuda:12.2.2-runtime-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive -# Switch to Azure-hosted Ubuntu mirror — GitHub Actions runners live in -# Azure, so this is effectively an internal hop and avoids the intermittent -# archive.ubuntu.com timeouts that kill the build. +# Azure-hosted Ubuntu mirror — GitHub Actions runners live in Azure; avoids the +# intermittent archive.ubuntu.com timeouts that kill the build. RUN sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g; \ s|http://security.ubuntu.com|http://azure.archive.ubuntu.com|g' \ /etc/apt/sources.list \ && apt-get update && apt-get install -y --no-install-recommends \ - python3 python3-pip python3-dev python3-venv \ - wget build-essential zlib1g-dev libhdf5-dev \ - && rm -rf /var/lib/apt/lists/* - -# Install f5c v1.6 from pre-built binaries (CPU) -RUN wget -q "https://github.com/hasindu2008/f5c/releases/download/v1.6/f5c-v1.6-binaries.tar.gz" \ - && tar xf f5c-v1.6-binaries.tar.gz \ - && cp f5c-v1.6/f5c_x86_64_linux /usr/local/bin/f5c \ - && chmod +x /usr/local/bin/f5c \ - && rm -rf f5c-v1.6 f5c-v1.6-binaries.tar.gz - -# Install baleen with CUDA; ensure console_script is at a known path. -# -v surfaces nvcc compile/link output so silent CUDA fallback is visible in CI. -# After install, verify _cuda_dtw*.so was actually built — fail loud otherwise. -WORKDIR /app -COPY . . -RUN pip3 install pytest \ - && pip3 install --no-cache-dir -v . 2>&1 | tee /tmp/pip-install.log \ - && SO_PATH=$(cd / && python3 -c "import baleen._cuda_dtw, os; print(os.path.dirname(baleen._cuda_dtw.__file__))") \ - && if ! ls "$SO_PATH"/_cuda_dtw*.so >/dev/null 2>&1; then \ - echo "ERROR: CUDA extension (.so) was not built — image would be CPU-only." >&2; \ - echo "SO_PATH=$SO_PATH" >&2; \ - echo "find result:" >&2; \ - find /usr/local/lib /usr/lib -name '_cuda_dtw*.so' 2>/dev/null >&2 || true; \ - echo "pip install log tail:" >&2; \ - tail -80 /tmp/pip-install.log >&2; \ - exit 1; \ - fi \ - && echo "OK: _cuda_dtw $(ls $SO_PATH/_cuda_dtw*.so)" \ - && BALEEN_BIN="$(which baleen 2>/dev/null)" \ - && if [ -z "$BALEEN_BIN" ]; then \ - printf '#!/bin/sh\nexec python3 -m baleen "$@"\n' > /usr/local/bin/baleen; \ - elif [ "$BALEEN_BIN" != "/usr/local/bin/baleen" ]; then \ - cp "$BALEEN_BIN" /usr/local/bin/baleen; \ - fi \ - && chmod +x /usr/local/bin/baleen - -# Record the Python version so the runtime stage can verify it matches -RUN python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' \ - > /tmp/python_version.txt - -# --- Runtime stage --- -# NOTE: must use the same Ubuntu version as the builder so Python versions match -FROM nvidia/cuda:12.6.3-runtime-ubuntu22.04 - -ENV DEBIAN_FRONTEND=noninteractive - -RUN sed -i 's|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g; \ - s|http://security.ubuntu.com|http://azure.archive.ubuntu.com|g' \ - /etc/apt/sources.list \ - && apt-get update \ - && apt-get install -y --no-install-recommends python3 python3-pip zlib1g \ + python3 python3-pip wget ca-certificates zlib1g \ && apt-cache search '^libhdf5-[0-9]' | head -1 | awk '{print $1}' \ | xargs apt-get install -y --no-install-recommends \ && rm -rf /var/lib/apt/lists/* -COPY --from=builder /usr/local/bin/f5c /usr/local/bin/f5c -COPY --from=builder /usr/local/bin/baleen /usr/local/bin/baleen -COPY --from=builder /tmp/python_version.txt /tmp/python_version.txt +# slow5tools — baleen indexes BLOW5 via `slow5tools index` (pyslow5 needs .idx). +RUN wget -q "https://github.com/hasindu2008/slow5tools/releases/download/v1.3.0/slow5tools-v1.3.0-x86_64-linux-binaries.tar.gz" \ + && tar xf slow5tools-v1.3.0-x86_64-linux-binaries.tar.gz \ + && cp slow5tools-v1.3.0/slow5tools /usr/local/bin/slow5tools \ + && chmod +x /usr/local/bin/slow5tools \ + && rm -rf slow5tools-v1.3.0 slow5tools-v1.3.0-x86_64-linux-binaries.tar.gz -# Copy installed Python packages — Ubuntu 22.04 ships Python 3.10 -# Verify the builder used the same version before copying -COPY --from=builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages -COPY --from=builder /usr/lib/python3/dist-packages /usr/lib/python3/dist-packages -RUN BUILD_VER=$(cat /tmp/python_version.txt) \ - && RUNTIME_VER=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') \ - && if [ "$BUILD_VER" != "$RUNTIME_VER" ]; then \ - echo "ERROR: Python version mismatch: builder=$BUILD_VER runtime=$RUNTIME_VER" >&2; \ - echo "Update the COPY paths in Dockerfile.gpu to match python$RUNTIME_VER" >&2; \ - exit 1; \ - fi \ - && rm /tmp/python_version.txt +# baleen (pure Python now — no C extension to build). Ubuntu 22.04 ships a +# setuptools that predates PEP 621, so upgrade pip/setuptools first or the +# [project] table is ignored and an empty "UNKNOWN" package is built. +WORKDIR /app +COPY . . +RUN pip3 install --no-cache-dir --upgrade pip setuptools wheel \ + && pip3 install --no-cache-dir . + +# krill engine — follow the package's strict rules: +# 1. runtime deps from PyPI first (NEVER `krill[...]`, NEVER --extra-index-url) +# 2. krill itself ONLY from the project index, --no-deps. cu122 GPU wheel here. +RUN pip3 install --no-cache-dir numpy scipy pyslow5 pyfastx \ + && pip3 install --no-cache-dir krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/cu122/simple/ WORKDIR /data ENTRYPOINT ["baleen"] diff --git a/baleen/_cuda_dtw/__init__.py b/baleen/_cuda_dtw/__init__.py deleted file mode 100644 index 2244518..0000000 --- a/baleen/_cuda_dtw/__init__.py +++ /dev/null @@ -1,509 +0,0 @@ -""" -CUDA-accelerated Dynamic Time Warping (DTW) module - -This module provides GPU-accelerated DTW distance calculation with -automatic CPU fallback when CUDA is not available. - -CPU backend: delegates to tslearn. -""" - -import logging -import subprocess -import numpy as np -from typing import Union, Optional - -_log = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Backend detection -# --------------------------------------------------------------------------- - -try: - from ._cuda_dtw import dtw_distance as _dtw_distance_cuda - from ._cuda_dtw import dtw_pairwise as _dtw_pairwise_cuda - from ._cuda_dtw import dtw_pairwise_varlen as _dtw_pairwise_varlen_cuda - from ._cuda_dtw import cleanup as _cuda_cleanup - - CUDA_AVAILABLE = True -except ImportError: - CUDA_AVAILABLE = False - -try: - from ._cuda_dtw import dtw_multi_position_pairwise as _dtw_multi_position_cuda -except (ImportError, AttributeError): - _dtw_multi_position_cuda = None - -try: - from tslearn.metrics import dtw as _tslearn_dtw - from tslearn.metrics import cdist_dtw as _tslearn_cdist_dtw - - TSLEARN_AVAILABLE = True -except ImportError: - TSLEARN_AVAILABLE = False - -_BACKEND = "cuda" if CUDA_AVAILABLE else "cpu" - -# Detect cuDTW++ (v0.2+) vs legacy OpenDBA kernel -_CUDTW_ACTIVE = False -if CUDA_AVAILABLE: - try: - from ._cuda_dtw import __version__ as _cuda_ver - _CUDTW_ACTIVE = "cudtw" in _cuda_ver - except (ImportError, AttributeError): - pass - -if _BACKEND == "cuda": - if _CUDTW_ACTIVE: - _log.debug("DTW backend: cuda (cuDTW++ warp-shuffle)") - else: - _log.debug("DTW backend: cuda (legacy OpenDBA wavefront)") -else: - _log.debug("DTW backend: cpu (tslearn fallback)") - - -def backend() -> str: - """Return the name of the active DTW backend ('cuda' or 'cpu').""" - return _BACKEND - - -# --------------------------------------------------------------------------- -# Internal helpers -# --------------------------------------------------------------------------- - -_MAX_CUDTW_LEN = 2047 -_BUCKETS = (127, 255, 511, 1023, 2047) - - -def _select_bucket(n: int) -> int: - for b in _BUCKETS: - if n <= b: - return b - return _BUCKETS[-1] # caller is expected to resample before reaching here - - -def _resample_signal(sig, target_len): - """Resample a signal to target_len using scipy.""" - from scipy.signal import resample - return resample(sig.astype(np.float64), target_len).astype(np.float32) - - -# --------------------------------------------------------------------------- -# CPU DTW implementation -# --------------------------------------------------------------------------- - -def _dtw_distance_cpu(seq1: np.ndarray, seq2: np.ndarray) -> float: - """Compute DTW distance on CPU via tslearn.""" - if not TSLEARN_AVAILABLE: - raise RuntimeError( - "tslearn is required for CPU DTW.\n" - "Install it with: pip install tslearn" - ) - s1_2d = seq1.reshape(-1, 1) - s2_2d = seq2.reshape(-1, 1) - return float(_tslearn_dtw(s1_2d, s2_2d)) - - -def _dtw_pairwise_cpu(sequences: np.ndarray) -> np.ndarray: - """Compute pairwise DTW distances on CPU via tslearn.""" - if not TSLEARN_AVAILABLE: - raise RuntimeError( - "tslearn is required for CPU DTW.\n" - "Install it with: pip install tslearn" - ) - dataset_3d = sequences[:, :, np.newaxis] - result = _tslearn_cdist_dtw(dataset_3d) - return np.asarray(result, dtype=np.float64) - - -# --------------------------------------------------------------------------- -# dtw_distance (public API) -# --------------------------------------------------------------------------- - -def dtw_distance( - seq1: Union[np.ndarray, list], - seq2: Union[np.ndarray, list], - use_cuda: Optional[bool] = None, -) -> float: - """ - Compute DTW distance between two sequences. - - Parameters - ---------- - seq1 : array-like - First sequence (will be converted to float32 numpy array) - seq2 : array-like - Second sequence (will be converted to float32 numpy array) - use_cuda : bool or None, optional - Backend selection: - - None (default): auto-select (CUDA if available, else CPU) - - True: force CUDA, raises RuntimeError if unavailable - - False: force CPU - - Returns - ------- - float - DTW distance between seq1 and seq2 - """ - # --- Input conversion --- - if not isinstance(seq1, np.ndarray): - seq1 = np.array(seq1, dtype=np.float32) - else: - seq1 = np.asarray(seq1, dtype=np.float32) - - if not isinstance(seq2, np.ndarray): - seq2 = np.array(seq2, dtype=np.float32) - else: - seq2 = np.asarray(seq2, dtype=np.float32) - - if not seq1.flags["C_CONTIGUOUS"]: - seq1 = np.ascontiguousarray(seq1) - if not seq2.flags["C_CONTIGUOUS"]: - seq2 = np.ascontiguousarray(seq2) - - if seq1.ndim != 1: - raise ValueError(f"seq1 must be 1-dimensional, got shape {seq1.shape}") - if seq2.ndim != 1: - raise ValueError(f"seq2 must be 1-dimensional, got shape {seq2.shape}") - - if len(seq1) == 0 or len(seq2) == 0: - raise ValueError("Sequences cannot be empty") - - # --- Backend dispatch --- - if use_cuda is True: - if not CUDA_AVAILABLE: - raise RuntimeError( - "CUDA backend requested but not available. " - "Install with CUDA support or use use_cuda=False for CPU." - ) - return _dtw_distance_cuda(seq1, seq2) - - if use_cuda is False: - return _dtw_distance_cpu(seq1, seq2) - - # use_cuda is None: auto-select - if CUDA_AVAILABLE: - return _dtw_distance_cuda(seq1, seq2) - - return _dtw_distance_cpu(seq1, seq2) - - -# --------------------------------------------------------------------------- -# dtw_pairwise (public API) -# --------------------------------------------------------------------------- - -def dtw_pairwise( - sequences: Union[np.ndarray, list], - use_cuda: Optional[bool] = None, -) -> np.ndarray: - """ - Compute pairwise DTW distances for a batch of equal-length sequences. - - Parameters - ---------- - sequences : array-like - 2D array of sequences with shape (num_sequences, seq_length). - All sequences must have the same length. - use_cuda : bool or None, optional - Backend selection (None=auto, True=force CUDA, False=force CPU). - - Returns - ------- - np.ndarray - Symmetric distance matrix of shape (num_sequences, num_sequences). - """ - # --- Input conversion --- - if not isinstance(sequences, np.ndarray): - sequences_list = list(sequences) - if sequences_list and hasattr(sequences_list[0], '__len__'): - lengths = {len(s) for s in sequences_list} - if len(lengths) > 1: - raise ValueError( - f"All sequences must have the same length; got lengths {sorted(lengths)}" - ) - sequences = np.array(sequences, dtype=np.float32) - else: - sequences = np.asarray(sequences, dtype=np.float32) - - if sequences.ndim != 2: - raise ValueError(f"sequences must be 2D array, got shape {sequences.shape}") - if sequences.shape[0] < 2: - raise ValueError(f"Need at least 2 sequences, got {sequences.shape[0]}") - if sequences.shape[1] == 0: - raise ValueError("Sequence length cannot be 0") - - # --- Backend dispatch --- - if use_cuda is True: - if not CUDA_AVAILABLE: - raise RuntimeError( - "CUDA backend requested but not available. " - "Install with CUDA support or use use_cuda=False for CPU." - ) - return _dtw_pairwise_cuda(sequences) - - if use_cuda is False: - return _dtw_pairwise_cpu(sequences) - - if CUDA_AVAILABLE: - return _dtw_pairwise_cuda(sequences) - - return _dtw_pairwise_cpu(sequences) - - -def dtw_pairwise_varlen( - signals: list[np.ndarray], - use_cuda: Optional[bool] = None, -) -> np.ndarray: - """ - Compute pairwise DTW distances for variable-length sequences. - - Parameters - ---------- - signals : list of np.ndarray - List of 1D float32 arrays (variable lengths allowed). - use_cuda : bool or None - Backend selection (None=auto, True=force CUDA, False=force CPU). - - Returns - ------- - np.ndarray - Symmetric distance matrix of shape (N, N). - """ - if len(signals) < 2: - raise ValueError(f"Need at least 2 signals, got {len(signals)}") - - prepped = [np.ascontiguousarray(np.asarray(s, dtype=np.float32)) for s in signals] - lengths = np.array([len(s) for s in prepped], dtype=np.int64) - - if any(l == 0 for l in lengths): - raise ValueError("All signals must be non-empty") - - want_cuda = use_cuda is True or (use_cuda is None and CUDA_AVAILABLE) - - if want_cuda: - if not CUDA_AVAILABLE: - raise RuntimeError( - "CUDA backend requested but not available. " - "Install with CUDA support or use use_cuda=False." - ) - - # Resample signals > 2047 when cuDTW++ is active - if _CUDTW_ACTIVE: - max_raw = int(lengths.max()) - if max_raw > _MAX_CUDTW_LEN: - _log.info("Resampling %d signals from max %d to %d for cuDTW++", - len(prepped), max_raw, _MAX_CUDTW_LEN) - prepped = [_resample_signal(s, _MAX_CUDTW_LEN) - if len(s) > _MAX_CUDTW_LEN else s - for s in prepped] - lengths = np.array([len(s) for s in prepped], dtype=np.int64) - - max_len = int(lengths.max()) - n = len(prepped) - padded = np.zeros((n, max_len), dtype=np.float32) - for i, s in enumerate(prepped): - padded[i, :len(s)] = s - result = _dtw_pairwise_varlen_cuda(padded, lengths) - return np.asarray(result, dtype=np.float64) - - n = len(prepped) - result = np.zeros((n, n), dtype=np.float64) - for i in range(n): - for j in range(i + 1, n): - d = _dtw_distance_cpu(prepped[i], prepped[j]) - result[i, j] = d - result[j, i] = d - return result - - -# --------------------------------------------------------------------------- -# dtw_multi_position_pairwise (public API) -# --------------------------------------------------------------------------- - -def dtw_multi_position_pairwise( - position_signals: list[list[np.ndarray]], - use_cuda: Optional[bool] = None, - num_streams: int = 16, - device_id: int = 0, -) -> list[np.ndarray]: - """ - Batch-compute pairwise DTW distances for multiple positions in one GPU call. - - Parameters - ---------- - position_signals : list of list of np.ndarray - position_signals[p][r] is the 1D float32 signal for position p, read r. - use_cuda : bool or None - Backend selection (None=auto, True=force CUDA, False=force CPU). - num_streams : int - Number of CUDA streams for concurrent processing (default 16). - device_id : int - GPU device ID (default 0). - - Returns - ------- - list of np.ndarray - Distance matrices, one per position. Each is (n_p, n_p) float64. - """ - if len(position_signals) < 1: - raise ValueError("Need at least 1 position, got 0") - - prepped: list[list[np.ndarray]] = [] - counts: list[int] = [] - for pos_sigs in position_signals: - ps = [np.ascontiguousarray(np.asarray(s, dtype=np.float32)) for s in pos_sigs] - if any(len(s) == 0 for s in ps): - raise ValueError("All signals must be non-empty") - prepped.append(ps) - counts.append(len(ps)) - - want_cuda = use_cuda is True or (use_cuda is None and CUDA_AVAILABLE) - - if want_cuda: - if not CUDA_AVAILABLE or _dtw_multi_position_cuda is None: - raise RuntimeError( - "CUDA backend requested but not available. " - "Install with CUDA support or use use_cuda=False." - ) - - # Resample signals > 2047 when cuDTW++ is active - if _CUDTW_ACTIVE: - any_long = any( - len(s) > _MAX_CUDTW_LEN for pos_sigs in prepped for s in pos_sigs - ) - if any_long: - _log.info("Resampling signals > %d for cuDTW++", _MAX_CUDTW_LEN) - prepped = [ - [_resample_signal(s, _MAX_CUDTW_LEN) - if len(s) > _MAX_CUDTW_LEN else s - for s in pos_sigs] - for pos_sigs in prepped - ] - - global_max_len = max( - len(s) for pos_sigs in prepped for s in pos_sigs - ) - total_seqs = sum(counts) - - padded = np.zeros((total_seqs, global_max_len), dtype=np.float32) - lengths = np.empty(total_seqs, dtype=np.int64) - idx = 0 - for pos_sigs in prepped: - for s in pos_sigs: - padded[idx, :len(s)] = s - lengths[idx] = len(s) - idx += 1 - - counts_arr = np.array(counts, dtype=np.int64) - - flat_result = _dtw_multi_position_cuda( - padded, lengths, counts_arr, - num_cuda_streams=num_streams, - device_id=device_id, - ) - - result_list: list[np.ndarray] = [] - offset = 0 - for n in counts: - mat = np.asarray(flat_result[offset:offset + n * n], dtype=np.float64).reshape(n, n) - result_list.append(mat) - offset += n * n - return result_list - - # CPU fallback - result_list = [] - for pos_sigs in prepped: - n = len(pos_sigs) - if n < 2: - result_list.append(np.zeros((n, n), dtype=np.float64)) - continue - mat = np.zeros((n, n), dtype=np.float64) - for i in range(n): - for j in range(i + 1, n): - d = _dtw_distance_cpu(pos_sigs[i], pos_sigs[j]) - mat[i, j] = d - mat[j, i] = d - result_list.append(mat) - return result_list - - -# --------------------------------------------------------------------------- -# cleanup / is_available -# --------------------------------------------------------------------------- - -def cleanup(): - """Reset CUDA device and free all GPU resources. No-op on CPU.""" - if not CUDA_AVAILABLE: - return - _cuda_cleanup() - - -def is_available() -> bool: - """Check if CUDA DTW extension is available.""" - return CUDA_AVAILABLE - - -def estimate_gpu_memory(position_signals: list[list[np.ndarray]]) -> int: - """Estimate GPU memory bytes for a multi-position pairwise DTW call. - - The cuDTW++ wrapper allocates: - - d_subjects: total_seqs * global_bucket * 4 bytes (doubles as query source) - - d_out: sum(n_p^2) * 4 bytes (float32 squared cost) - No cost matrix is allocated (cuDTW++ accumulates in registers). - """ - total_seqs = sum(len(ps) for ps in position_signals) - max_len = max(len(s) for ps in position_signals for s in ps) - bucket = _select_bucket(max_len) - - input_bytes = total_seqs * bucket * 4 - output_bytes = sum(len(ps) ** 2 for ps in position_signals) * 4 - lengths_bytes = total_seqs * 8 # host-side h_lengths array - - total = input_bytes + output_bytes + lengths_bytes - return int(total * 1.2) # 20% headroom for stream/kernel overhead - - -def get_device_count() -> int: - """Return number of visible CUDA devices.""" - try: - result = subprocess.run( - ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], - capture_output=True, text=True, timeout=5, - ) - if result.returncode == 0: - return len([l for l in result.stdout.strip().split('\n') if l.strip()]) - except Exception: - pass - return 1 if CUDA_AVAILABLE else 0 - - -def get_per_device_memory() -> list[int]: - """Return total GPU memory in bytes for each visible CUDA device.""" - try: - result = subprocess.run( - ['nvidia-smi', '--query-gpu=memory.total', - '--format=csv,noheader,nounits'], - capture_output=True, text=True, timeout=5, - ) - if result.returncode == 0: - lines = [l.strip() for l in result.stdout.strip().split('\n') if l.strip()] - return [int(mb) * 1024 * 1024 for mb in lines] - except Exception: - pass - if CUDA_AVAILABLE: - return [8 * 1024 ** 3] # default 8 GB - return [] - - -__all__ = [ - "dtw_distance", - "dtw_pairwise", - "dtw_pairwise_varlen", - "dtw_multi_position_pairwise", - "estimate_gpu_memory", - "cleanup", - "is_available", - "backend", - "get_device_count", - "get_per_device_memory", - "CUDA_AVAILABLE", -] diff --git a/baleen/_cuda_dtw/cuda_utils.hpp b/baleen/_cuda_dtw/cuda_utils.hpp deleted file mode 100644 index 5d3fcfd..0000000 --- a/baleen/_cuda_dtw/cuda_utils.hpp +++ /dev/null @@ -1,137 +0,0 @@ -#ifndef __dba_cuda_utils_included -#define __dba_cuda_utils_included - -#include -#include "multithreading.h" - -#define CUDA_THREADBLOCK_MAX_L1CACHE 48000 -// Note that you should not change this to >1028 unless you carefully review all the code for reduction steps that imply 32x32 map-reduce! -#ifndef CUDA_THREADBLOCK_MAX_THREADS -#define CUDA_THREADBLOCK_MAX_THREADS 1024 -#endif -#define CUDA_WARP_WIDTH 32 -#define CUERR(MSG) \ - { \ - cudaError_t err; \ - if ((err = cudaGetLastError()) != cudaSuccess) \ - { \ - std::cerr << "CUDA error: " << cudaGetErrorString(err) << " (" << MSG << ")" << std::endl; \ - exit((int)err); \ - } \ - } -#define FULL_MASK 0xffffffff - -#define DIV_ROUNDUP(numerator, denominator) (((numerator) + (denominator) - 1) / (denominator)) - -// Find the smallest value for a local variable within a warp -template -__inline__ __device__ T warpReduceMin(T val) -{ - for (int offset = CUDA_WARP_WIDTH / 2; offset > 0; offset /= 2) - { - T tmpVal = __shfl_down_sync(FULL_MASK, val, offset); - if (tmpVal < val) - { - val = tmpVal; - } - } - return val; -} - -template -__inline__ __device__ T warpReduceMax(T val) -{ - for (int offset = CUDA_WARP_WIDTH / 2; offset > 0; offset /= 2) - { - T tmpVal = __shfl_down_sync(FULL_MASK, val, offset); - if (tmpVal > val) - { - val = tmpVal; - } - } - return val; -} - -unsigned int *getMaxThreadsPerDevice(int deviceCount) -{ - unsigned int *maxThreads; - cudaMallocHost(&maxThreads, sizeof(unsigned int) * deviceCount); - CUERR("Allocating CPU memory for CUDA device properties"); - cudaDeviceProp deviceProp; - for (int i = 0; i < deviceCount; i++) - { - cudaGetDeviceProperties(&deviceProp, i); - CUERR("Getting GPU device properties"); - // When debugging there are too many registers in the DTW kernel (due to number of local variables to - // track without optimization) and you get failure to launch when using a full thread count complement. -#if DEBUG == 1 - maxThreads[i] = deviceProp.maxThreadsPerBlock / 4; -#else - maxThreads[i] = deviceProp.maxThreadsPerBlock; -#endif - } - return maxThreads; -} - -// Methods below free resources after done using asynchronously called DTW kernels -struct heterogeneous_workload -{ - void *dtwCostSoFar_memptr; // we only free it, so datatype templating is not neccesary - void *newDtwCostSoFar_memptr; // we only free it, so datatype templating is not neccesary - unsigned char *pathMatrix_memptr; - cudaStream_t stream; -}; - -__host__ - CUT_THREADPROC - dtwStreamCleanup(void *void_arg) -{ - heterogeneous_workload *workload = (heterogeneous_workload *)void_arg; - // ... GPU is done with processing, continue on new CPU thread... - - // Free dynamically allocated resources that were associated with data processing done in the stream. - // std::cerr << "Freeing memory" << std::endl; - if (workload->dtwCostSoFar_memptr != 0) - { - cudaFree(workload->dtwCostSoFar_memptr); - CUERR("Freeing DTW intermediate cost values"); - } - if (workload->newDtwCostSoFar_memptr != 0) - { - cudaFree(workload->newDtwCostSoFar_memptr); - CUERR("Freeing new DTW intermediate cost values"); - } - if (workload->pathMatrix_memptr != 0) - { - cudaFree(workload->pathMatrix_memptr); - CUERR("Freeing DTW path matrix"); - } - cudaStreamDestroy(workload->stream); - CUERR("Removing a CUDA stream after completion"); - cudaFreeHost(workload); - CUERR("Freeing host memory for dtwStreamCleanup"); - - CUT_THREADEND; -} - -__host__ void CUDART_CB dtwStreamCleanupLaunch(cudaStream_t stream, cudaError_t status, void *streamResources) -{ - // Check status of GPU after stream operations are done. Die if there was a failure. - CUERR("On callback after DTWDistance calculations completed."); - - // Spawn new CPU worker thread and perform stream resource cleanup on the CPU (since calling the CUDA API from within this callback is not allowed according to the docs). - cutStartThread(dtwStreamCleanup, streamResources); -} - -void addStreamCleanupCallback(void *dtwCostSoFar, void *newDtwCostSoFar, unsigned char *pathMatrix, cudaStream_t stream) -{ - heterogeneous_workload *cleanup_workload = 0; - cudaMallocHost(&cleanup_workload, sizeof(heterogeneous_workload)); - CUERR("Allocating page locked CPU memory for DTW stream callback data"); - cleanup_workload->dtwCostSoFar_memptr = dtwCostSoFar; - cleanup_workload->newDtwCostSoFar_memptr = newDtwCostSoFar; - cleanup_workload->pathMatrix_memptr = pathMatrix; - cleanup_workload->stream = stream; - cudaStreamAddCallback(stream, dtwStreamCleanupLaunch, cleanup_workload, 0); -} -#endif \ No newline at end of file diff --git a/baleen/_cuda_dtw/cudtw/LICENSE b/baleen/_cuda_dtw/cudtw/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/baleen/_cuda_dtw/cudtw/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/baleen/_cuda_dtw/cudtw/include/DTW.hpp b/baleen/_cuda_dtw/cudtw/include/DTW.hpp deleted file mode 100644 index f5da68a..0000000 --- a/baleen/_cuda_dtw/cudtw/include/DTW.hpp +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef CUDTW_DTW_HPP -#define CUDTW_DTW_HPP - -// Simplified cuDTW++ dispatcher for baleen. -// Only database-mode (query_type==0), non-subwarp kernels. -// Query is passed as a global-memory pointer (no constant memory). - -#include "./kernels/SHFL_FULLDTW_127.cuh" -#include "./kernels/SHFL_FULLDTW_255.cuh" -#include "./kernels/SHFL_FULLDTW_511.cuh" -#include "./kernels/SHFL_FULLDTW_1023.cuh" -#include "./kernels/SHFL_FULLDTW_2047.cuh" - -namespace FullDTW { - -template -__host__ -bool dist( - const value_t *Query, - value_t *Subject, - value_t *Dist, - index_t num_entries, - index_t num_features, - cudaStream_t stream = 0) -{ - const dim3 grid(num_entries, 1, 1); - const dim3 block(32, 1, 1); - - if (num_features == 127) { - shfl_FullDTW_127<<>>( - Query, Subject, Dist, num_entries, num_features); - return true; - } - if (num_features == 255) { - shfl_FullDTW_255<<>>( - Query, Subject, Dist, num_entries, num_features); - return true; - } - if (num_features == 511) { - shfl_FullDTW_511<<>>( - Query, Subject, Dist, num_entries, num_features); - return true; - } - if (num_features == 1023) { - shfl_FullDTW_1023<<>>( - Query, Subject, Dist, num_entries, num_features); - return true; - } - if (num_features == 2047) { - shfl_FullDTW_2047<<>>( - Query, Subject, Dist, num_entries, num_features); - return true; - } - - return false; // unsupported length -} - -} // namespace FullDTW - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_1023.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_1023.cuh deleted file mode 100755 index 705c174..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_1023.cuh +++ /dev/null @@ -1,409 +0,0 @@ -#ifndef SHFL_FULLDTW_1023 -#define SHFL_FULLDTW_1023 - -// 32 values per thread no shared memory -template < - typename index_t, - typename value_t> __global__ -void shfl_FullDTW_1023 ( // was DTW_fast_1024_shuffle_kernel_no_shared_memory - const value_t * Query, - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t lane = num_features+1; - const index_t base = blid*num_features; - const index_t WARP_SIZE = 32; - const index_t l = thid; - - - //extern __shared__ value_t Subject_cache[]; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_here4 = INFINITY; // 0; - value_t penalty_here5 = INFINITY; // 0; - value_t penalty_here6 = INFINITY; // 0; - value_t penalty_here7 = INFINITY; // 0; - value_t penalty_here8 = INFINITY; // 0; - value_t penalty_here9 = INFINITY; // 0; - value_t penalty_here10 = INFINITY; // 0; - value_t penalty_here11 = INFINITY; // 0; - value_t penalty_here12 = INFINITY; // 0; - value_t penalty_here13 = INFINITY; // 0; - value_t penalty_here14 = INFINITY; // 0; - value_t penalty_here15 = INFINITY; // 0; - value_t penalty_here16 = INFINITY; // 0; - value_t penalty_here17 = INFINITY; // 0; - value_t penalty_here18 = INFINITY; // 0; - value_t penalty_here19 = INFINITY; // 0; - value_t penalty_here20 = INFINITY; // 0; - value_t penalty_here21 = INFINITY; // 0; - value_t penalty_here22 = INFINITY; // 0; - value_t penalty_here23 = INFINITY; // 0; - value_t penalty_here24 = INFINITY; // 0; - value_t penalty_here25 = INFINITY; // 0; - value_t penalty_here26 = INFINITY; // 0; - value_t penalty_here27 = INFINITY; // 0; - value_t penalty_here28 = INFINITY; // 0; - value_t penalty_here29 = INFINITY; // 0; - value_t penalty_here30 = INFINITY; // 0; - value_t penalty_here31 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - //index_t iter = 0; - - //for (index_t l = thid; l < lane/32; l += blockDim.x) { - - //const index_t iter = l/WARP_SIZE; - if (thid == 0) { - penalty_left = INFINITY; - penalty_diag = INFINITY; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - penalty_here4 = INFINITY; // 0; - penalty_here5 = INFINITY; // 0; - penalty_here6 = INFINITY; // 0; - penalty_here7 = INFINITY; // 0; - penalty_here8 = INFINITY; // 0; - penalty_here9 = INFINITY; // 0; - penalty_here10 = INFINITY; // 0; - penalty_here11 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here13 = INFINITY; // 0; - penalty_here14 = INFINITY; // 0; - penalty_here15 = INFINITY; // 0; - penalty_here16 = INFINITY; // 0; - penalty_here17 = INFINITY; // 0; - penalty_here18 = INFINITY; // 0; - penalty_here19 = INFINITY; // 0; - penalty_here20 = INFINITY; // 0; - penalty_here21 = INFINITY; // 0; - penalty_here22 = INFINITY; // 0; - penalty_here23 = INFINITY; // 0; - penalty_here24 = INFINITY; // 0; - penalty_here25 = INFINITY; // 0; - penalty_here26 = INFINITY; // 0; - penalty_here27 = INFINITY; // 0; - penalty_here28 = INFINITY; // 0; - penalty_here29 = INFINITY; // 0; - penalty_here30 = INFINITY; // 0; - penalty_here31 = INFINITY; // 0; - } - - //const value_t subject_value = Subject[base+l-1]; - const value_t subject_value0 = l == 0 ? 0 : Subject[base+32*l-1]; - const value_t subject_value1 = Subject[base+32*l-0]; - const value_t subject_value2 = Subject[base+32*l+1]; - const value_t subject_value3 = Subject[base+32*l+2]; - const value_t subject_value4 = Subject[base+32*l+3]; - const value_t subject_value5 = Subject[base+32*l+4]; - const value_t subject_value6 = Subject[base+32*l+5]; - const value_t subject_value7 = Subject[base+32*l+6]; - const value_t subject_value8 = Subject[base+32*l+7]; - const value_t subject_value9 = Subject[base+32*l+8]; - const value_t subject_value10 = Subject[base+32*l+9]; - const value_t subject_value11 = Subject[base+32*l+10]; - const value_t subject_value12 = Subject[base+32*l+11]; - const value_t subject_value13 = Subject[base+32*l+12]; - const value_t subject_value14 = Subject[base+32*l+13]; - const value_t subject_value15 = Subject[base+32*l+14]; - - const value_t subject_value16 = Subject[base+32*l+15]; - const value_t subject_value17 = Subject[base+32*l+16]; - const value_t subject_value18 = Subject[base+32*l+17]; - const value_t subject_value19 = Subject[base+32*l+18]; - const value_t subject_value20 = Subject[base+32*l+19]; - const value_t subject_value21 = Subject[base+32*l+20]; - const value_t subject_value22 = Subject[base+32*l+21]; - const value_t subject_value23 = Subject[base+32*l+22]; - const value_t subject_value24 = Subject[base+32*l+23]; - const value_t subject_value25 = Subject[base+32*l+24]; - const value_t subject_value26 = Subject[base+32*l+25]; - const value_t subject_value27 = Subject[base+32*l+26]; - const value_t subject_value28 = Subject[base+32*l+27]; - const value_t subject_value29 = Subject[base+32*l+28]; - const value_t subject_value30 = Subject[base+32*l+29]; - const value_t subject_value31 = Subject[base+32*l+30]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = Query[thid]; - if (thid == 0) query_value = new_query_value; - if (thid == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - //penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here1, 1, 32); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //const index_t j = l; - //if (blid == 0 && thid == 31 && iter == 0) Dist[0] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[2*thid+1] = penalty_here1; - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - - //if (iter && thid == 0) penalty_left = Subject_cache[2+1]; - if (thid == 0) penalty_left = INFINITY; - - for (index_t k = 3; k < lane+WARP_SIZE-1; k++) { - const index_t i = k-l; - //outside = k <= l || i >= lane; - - //const value_t residue = outside ? INFINITY : Query[i-1]-subject_value; - //const value_t residue = outside ? INFINITY : query_value-subject_value; - //if (thid == 0 && iter == 0 && k == 2) penalty_temp = INFINITY; else - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - //if (counter%32 == 0 && counter > 1) new_query_value = Query[i+2*thid-1]; - if (counter%32 == 0) new_query_value = Query[i+2*thid-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //if (thid == 0) if (!outside) Dist[counter] = query_value; else Dist[counter] = 0; - counter++; - - // save the right column - //if (!outside && thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[i] = penalty_here; // TO DO: replace this by shhffles - //if (iter < lane/WARP_SIZE-1 && thid == 31 && k>l) Subject_cache[i] = penalty_here31; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (thid == 0) penalty_left = INFINITY; - - //if (thid == 0 && k>l) if (iter > 0) penalty_left = Subject_cache[i+1]; else penalty_left = INFINITY; - //if (thid == 0 && k>l && iter > 0) penalty_left = Subject_cache[i+1]; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_diag)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - //if (thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[lane+(iter+1)*WARP_SIZE-1-l] = penalty_here31; - //iter++; - - if(thid == blockDim.x-1) Dist[blid] = penalty_here31; -} - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_127.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_127.cuh deleted file mode 100755 index dee438f..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_127.cuh +++ /dev/null @@ -1,164 +0,0 @@ -#ifndef SHFL_FULLDTW_127 -#define SHFL_FULLDTW_127 - -// Using 4 values per thread no shared memory -template < - typename index_t, - typename value_t> __global__ -void shfl_FullDTW_127( // was void DTW_fast_128_shuffle_kernel_no_shared_memory - const value_t * Query, - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t lane = num_features+1; - const index_t base = blid*num_features; - const index_t WARP_SIZE = 32; - const index_t l = thid; - - //extern __shared__ value_t Subject_cache[]; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - //index_t iter = 0; - - //for (index_t l = thid; l < lane/4; l += blockDim.x) { - // const index_t iter = l/WARP_SIZE; - if (thid == 0) { - //if (iter > 0) penalty_left = Subject_cache[1]; else penalty_left = INFINITY; - penalty_left = INFINITY; - penalty_diag = INFINITY; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - - } - - //const value_t subject_value = Subject[base+l-1]; - const value_t subject_value0 = l == 0 ? 0 : Subject[base+4*l-1]; - const value_t subject_value1 = Subject[base+4*l-0]; - const value_t subject_value2 = Subject[base+4*l+1]; - const value_t subject_value3 = Subject[base+4*l+2]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = Query[thid]; - if (thid == 0) query_value = new_query_value; - if (thid == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - //penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here1, 1, 32); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //const index_t j = l; - //if (blid == 0 && thid == 31 && iter == 0) Dist[0] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[2*thid+1] = penalty_here1; - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here3, 1, 32); - - //if (iter && thid == 0) penalty_left = Subject_cache[2+1]; - if (thid == 0) penalty_left = INFINITY; - - - for (index_t k = 3; k < lane+WARP_SIZE-1; k++) { - const index_t i = k-l; - //outside = k <= l || i >= lane; - - //const value_t residue = outside ? INFINITY : Query[i-1]-subject_value; - //const value_t residue = outside ? INFINITY : query_value-subject_value; - //if (thid == 0 && iter == 0 && k == 2) penalty_temp = INFINITY; else - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - - //if (i <= lane && blid == 15) { - // Dist[i*lane + 4*l] = penalty_here0; - // Dist[i*lane + 4*l+1] = penalty_here1; - // Dist[i*lane + 4*l+2] = penalty_here2; - // Dist[i*lane + 4*l+3] = penalty_here3; - //} - // if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - // if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - // if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - //if (counter%32 == 0 && counter > 1) new_query_value = Query[i+2*thid-1]; - if (counter%32 == 0) new_query_value = Query[i+2*thid-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //if (thid == 0) if (!outside) Dist[counter] = query_value; else Dist[counter] = 0; - counter++; - - // save the right column - //if (!outside && thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[i] = penalty_here; // TO DO: replace this by shhffles - //if (iter < lane/WARP_SIZE-1 && thid == 31 && k>l) Subject_cache[i] = penalty_here3; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here3, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (thid == 0) penalty_left = INFINITY; - //if (thid == 0 && k>l) if (iter > 0) penalty_left = Subject_cache[i+1]; else penalty_left = INFINITY; - //if (thid == 0 && k>l && iter > 0) penalty_left = Subject_cache[i+1]; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - - //if (blid == 15) { - // Dist[128*lane + 4*l] = penalty_here0; - // Dist[128*lane + 4*l+1] = penalty_here1; - // Dist[128*lane + 4*l+2] = penalty_here2; - // Dist[128*lane + 4*l+3] = penalty_here3; - //} - - //if (thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[lane+(iter+1)*WARP_SIZE-1-l] = penalty_here3; - //iter++; - - if(thid == blockDim.x-1) Dist[blid] = penalty_here3; -} - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_2047.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_2047.cuh deleted file mode 100755 index c34a4d4..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_2047.cuh +++ /dev/null @@ -1,651 +0,0 @@ -#ifndef SHFL_FULLDTW_2047 -#define SHFL_FULLDTW_2047 - - -// 64 values per thread no shared memory -template < - typename index_t, - typename value_t> __global__ -void shfl_FullDTW_2047( - const value_t * Query, - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t lane = num_features+1; - const index_t base = blid*num_features; - const index_t l = thid; - //extern __shared__ value_t Subject_cache[]; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_here4 = INFINITY; // 0; - value_t penalty_here5 = INFINITY; // 0; - value_t penalty_here6 = INFINITY; // 0; - value_t penalty_here7 = INFINITY; // 0; - value_t penalty_here8 = INFINITY; // 0; - value_t penalty_here9 = INFINITY; // 0; - value_t penalty_here10 = INFINITY; // 0; - value_t penalty_here11 = INFINITY; // 0; - value_t penalty_here12 = INFINITY; // 0; - value_t penalty_here13 = INFINITY; // 0; - value_t penalty_here14 = INFINITY; // 0; - value_t penalty_here15 = INFINITY; // 0; - value_t penalty_here16 = INFINITY; // 0; - value_t penalty_here17 = INFINITY; // 0; - value_t penalty_here18 = INFINITY; // 0; - value_t penalty_here19 = INFINITY; // 0; - value_t penalty_here20 = INFINITY; // 0; - value_t penalty_here21 = INFINITY; // 0; - value_t penalty_here22 = INFINITY; // 0; - value_t penalty_here23 = INFINITY; // 0; - value_t penalty_here24 = INFINITY; // 0; - value_t penalty_here25 = INFINITY; // 0; - value_t penalty_here26 = INFINITY; // 0; - value_t penalty_here27 = INFINITY; // 0; - value_t penalty_here28 = INFINITY; // 0; - value_t penalty_here29 = INFINITY; // 0; - value_t penalty_here30 = INFINITY; // 0; - value_t penalty_here31 = INFINITY; // 0; - value_t penalty_here32 = INFINITY; // 0; - value_t penalty_here33 = INFINITY; // 0; - value_t penalty_here34 = INFINITY; // 0; - value_t penalty_here35 = INFINITY; // 0; - value_t penalty_here36 = INFINITY; // 0; - value_t penalty_here37 = INFINITY; // 0; - value_t penalty_here38 = INFINITY; // 0; - value_t penalty_here39 = INFINITY; // 0; - value_t penalty_here40 = INFINITY; // 0; - value_t penalty_here41 = INFINITY; // 0; - value_t penalty_here42 = INFINITY; // 0; - value_t penalty_here43 = INFINITY; // 0; - value_t penalty_here44 = INFINITY; // 0; - value_t penalty_here45 = INFINITY; // 0; - value_t penalty_here46 = INFINITY; // 0; - value_t penalty_here47 = INFINITY; // 0; - value_t penalty_here48 = INFINITY; // 0; - value_t penalty_here49 = INFINITY; // 0; - value_t penalty_here50 = INFINITY; // 0; - value_t penalty_here51 = INFINITY; // 0; - value_t penalty_here52 = INFINITY; // 0; - value_t penalty_here53 = INFINITY; // 0; - value_t penalty_here54 = INFINITY; // 0; - value_t penalty_here55 = INFINITY; // 0; - value_t penalty_here56 = INFINITY; // 0; - value_t penalty_here57 = INFINITY; // 0; - value_t penalty_here58 = INFINITY; // 0; - value_t penalty_here59 = INFINITY; // 0; - value_t penalty_here60 = INFINITY; // 0; - value_t penalty_here61 = INFINITY; // 0; - value_t penalty_here62 = INFINITY; // 0; - value_t penalty_here63 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - if (thid == 0) { - penalty_left = INFINITY; - penalty_diag = INFINITY; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - penalty_here4 = INFINITY; // 0; - penalty_here5 = INFINITY; // 0; - penalty_here6 = INFINITY; // 0; - penalty_here7 = INFINITY; // 0; - penalty_here8 = INFINITY; // 0; - penalty_here9 = INFINITY; // 0; - penalty_here10 = INFINITY; // 0; - penalty_here11 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here13 = INFINITY; // 0; - penalty_here14 = INFINITY; // 0; - penalty_here15 = INFINITY; // 0; - penalty_here16 = INFINITY; // 0; - penalty_here17 = INFINITY; // 0; - penalty_here18 = INFINITY; // 0; - penalty_here19 = INFINITY; // 0; - penalty_here20 = INFINITY; // 0; - penalty_here21 = INFINITY; // 0; - penalty_here22 = INFINITY; // 0; - penalty_here23 = INFINITY; // 0; - penalty_here24 = INFINITY; // 0; - penalty_here25 = INFINITY; // 0; - penalty_here26 = INFINITY; // 0; - penalty_here27 = INFINITY; // 0; - penalty_here28 = INFINITY; // 0; - penalty_here29 = INFINITY; // 0; - penalty_here30 = INFINITY; // 0; - penalty_here31 = INFINITY; // 0; - } - - //const value_t subject_value = Subject[base+l-1]; - //const value_t subject_value0 = Subject[base+64*l-1]; - const value_t subject_value0 = l == 0 ? 0 : Subject[base+64*l-1]; - const value_t subject_value1 = Subject[base+64*l-0]; - const value_t subject_value2 = Subject[base+64*l+1]; - const value_t subject_value3 = Subject[base+64*l+2]; - const value_t subject_value4 = Subject[base+64*l+3]; - const value_t subject_value5 = Subject[base+64*l+4]; - const value_t subject_value6 = Subject[base+64*l+5]; - const value_t subject_value7 = Subject[base+64*l+6]; - const value_t subject_value8 = Subject[base+64*l+7]; - const value_t subject_value9 = Subject[base+64*l+8]; - const value_t subject_value10 = Subject[base+64*l+9]; - const value_t subject_value11 = Subject[base+64*l+10]; - const value_t subject_value12 = Subject[base+64*l+11]; - const value_t subject_value13 = Subject[base+64*l+12]; - const value_t subject_value14 = Subject[base+64*l+13]; - const value_t subject_value15 = Subject[base+64*l+14]; - - const value_t subject_value16 = Subject[base+64*l+15]; - const value_t subject_value17 = Subject[base+64*l+16]; - const value_t subject_value18 = Subject[base+64*l+17]; - const value_t subject_value19 = Subject[base+64*l+18]; - const value_t subject_value20 = Subject[base+64*l+19]; - const value_t subject_value21 = Subject[base+64*l+20]; - const value_t subject_value22 = Subject[base+64*l+21]; - const value_t subject_value23 = Subject[base+64*l+22]; - const value_t subject_value24 = Subject[base+64*l+23]; - const value_t subject_value25 = Subject[base+64*l+24]; - const value_t subject_value26 = Subject[base+64*l+25]; - const value_t subject_value27 = Subject[base+64*l+26]; - const value_t subject_value28 = Subject[base+64*l+27]; - const value_t subject_value29 = Subject[base+64*l+28]; - const value_t subject_value30 = Subject[base+64*l+29]; - const value_t subject_value31 = Subject[base+64*l+30]; - - const value_t subject_value32 = Subject[base+64*l+31]; - const value_t subject_value33 = Subject[base+64*l+32]; - const value_t subject_value34 = Subject[base+64*l+33]; - const value_t subject_value35 = Subject[base+64*l+34]; - const value_t subject_value36 = Subject[base+64*l+35]; - const value_t subject_value37 = Subject[base+64*l+36]; - const value_t subject_value38 = Subject[base+64*l+37]; - const value_t subject_value39 = Subject[base+64*l+38]; - const value_t subject_value40 = Subject[base+64*l+39]; - const value_t subject_value41 = Subject[base+64*l+40]; - const value_t subject_value42 = Subject[base+64*l+41]; - const value_t subject_value43 = Subject[base+64*l+42]; - const value_t subject_value44 = Subject[base+64*l+43]; - const value_t subject_value45 = Subject[base+64*l+44]; - const value_t subject_value46 = Subject[base+64*l+45]; - const value_t subject_value47 = Subject[base+64*l+46]; - - const value_t subject_value48 = Subject[base+64*l+47]; - const value_t subject_value49 = Subject[base+64*l+48]; - const value_t subject_value50 = Subject[base+64*l+49]; - const value_t subject_value51 = Subject[base+64*l+50]; - const value_t subject_value52 = Subject[base+64*l+51]; - const value_t subject_value53 = Subject[base+64*l+52]; - const value_t subject_value54 = Subject[base+64*l+53]; - const value_t subject_value55 = Subject[base+64*l+54]; - const value_t subject_value56 = Subject[base+64*l+55]; - const value_t subject_value57 = Subject[base+64*l+56]; - const value_t subject_value58 = Subject[base+64*l+57]; - const value_t subject_value59 = Subject[base+64*l+58]; - const value_t subject_value60 = Subject[base+64*l+59]; - const value_t subject_value61 = Subject[base+64*l+60]; - const value_t subject_value62 = Subject[base+64*l+61]; - const value_t subject_value63 = Subject[base+64*l+62]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = Query[thid]; - if (thid == 0) query_value = new_query_value; - if (thid == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //const index_t j = l; - //if (blid == 0 && thid == 31 && iter == 0) Dist[0] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[2*thid+1] = penalty_here1; - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_temp1 = penalty_here31; - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - penalty_temp0 = penalty_here32; - penalty_here32 = (query_value-subject_value32) * (query_value-subject_value32) + min(penalty_here31, min(penalty_here32, penalty_temp1)); - penalty_temp1 = penalty_here33; - penalty_here33 = (query_value-subject_value33) * (query_value-subject_value33) + min(penalty_here32, min(penalty_here33, penalty_temp0)); - penalty_temp0 = penalty_here34; - penalty_here34 = (query_value-subject_value34) * (query_value-subject_value34) + min(penalty_here33, min(penalty_here34, penalty_temp1)); - penalty_temp1 = penalty_here35; - penalty_here35 = (query_value-subject_value35) * (query_value-subject_value35) + min(penalty_here34, min(penalty_here35, penalty_temp0)); - penalty_temp0 = penalty_here36; - penalty_here36 = (query_value-subject_value36) * (query_value-subject_value36) + min(penalty_here35, min(penalty_here36, penalty_temp1)); - penalty_temp1 = penalty_here37; - penalty_here37 = (query_value-subject_value37) * (query_value-subject_value37) + min(penalty_here36, min(penalty_here37, penalty_temp0)); - penalty_temp0 = penalty_here38; - penalty_here38 = (query_value-subject_value38) * (query_value-subject_value38) + min(penalty_here37, min(penalty_here38, penalty_temp1)); - penalty_temp1 = penalty_here39; - penalty_here39 = (query_value-subject_value39) * (query_value-subject_value39) + min(penalty_here38, min(penalty_here39, penalty_temp0)); - penalty_temp0 = penalty_here40; - penalty_here40 = (query_value-subject_value40) * (query_value-subject_value40) + min(penalty_here39, min(penalty_here40, penalty_temp1)); - penalty_temp1 = penalty_here41; - penalty_here41 = (query_value-subject_value41) * (query_value-subject_value41) + min(penalty_here40, min(penalty_here41, penalty_temp0)); - penalty_temp0 = penalty_here42; - penalty_here42 = (query_value-subject_value42) * (query_value-subject_value42) + min(penalty_here41, min(penalty_here42, penalty_temp1)); - penalty_temp1 = penalty_here43; - penalty_here43 = (query_value-subject_value43) * (query_value-subject_value43) + min(penalty_here42, min(penalty_here43, penalty_temp0)); - penalty_temp0 = penalty_here44; - penalty_here44 = (query_value-subject_value44) * (query_value-subject_value44) + min(penalty_here43, min(penalty_here44, penalty_temp1)); - penalty_temp1 = penalty_here45; - penalty_here45 = (query_value-subject_value45) * (query_value-subject_value45) + min(penalty_here44, min(penalty_here45, penalty_temp0)); - penalty_temp0 = penalty_here46; - penalty_here46 = (query_value-subject_value46) * (query_value-subject_value46) + min(penalty_here45, min(penalty_here46, penalty_temp1)); - penalty_temp1 = penalty_here47; - penalty_here47 = (query_value-subject_value47) * (query_value-subject_value47) + min(penalty_here46, min(penalty_here47, penalty_temp0)); - - penalty_temp0 = penalty_here48; - penalty_here48 = (query_value-subject_value48) * (query_value-subject_value48) + min(penalty_here47, min(penalty_here48, penalty_temp1)); - penalty_temp1 = penalty_here49; - penalty_here49 = (query_value-subject_value49) * (query_value-subject_value49) + min(penalty_here48, min(penalty_here49, penalty_temp0)); - penalty_temp0 = penalty_here50; - penalty_here50 = (query_value-subject_value50) * (query_value-subject_value50) + min(penalty_here49, min(penalty_here50, penalty_temp1)); - penalty_temp1 = penalty_here51; - penalty_here51 = (query_value-subject_value51) * (query_value-subject_value51) + min(penalty_here50, min(penalty_here51, penalty_temp0)); - penalty_temp0 = penalty_here52; - penalty_here52 = (query_value-subject_value52) * (query_value-subject_value52) + min(penalty_here51, min(penalty_here52, penalty_temp1)); - penalty_temp1 = penalty_here53; - penalty_here53 = (query_value-subject_value53) * (query_value-subject_value53) + min(penalty_here52, min(penalty_here53, penalty_temp0)); - penalty_temp0 = penalty_here54; - penalty_here54 = (query_value-subject_value54) * (query_value-subject_value54) + min(penalty_here53, min(penalty_here54, penalty_temp1)); - penalty_temp1 = penalty_here55; - penalty_here55 = (query_value-subject_value55) * (query_value-subject_value55) + min(penalty_here54, min(penalty_here55, penalty_temp0)); - penalty_temp0 = penalty_here56; - penalty_here56 = (query_value-subject_value56) * (query_value-subject_value56) + min(penalty_here55, min(penalty_here56, penalty_temp1)); - penalty_temp1 = penalty_here57; - penalty_here57 = (query_value-subject_value57) * (query_value-subject_value57) + min(penalty_here56, min(penalty_here57, penalty_temp0)); - penalty_temp0 = penalty_here58; - penalty_here58 = (query_value-subject_value58) * (query_value-subject_value58) + min(penalty_here57, min(penalty_here58, penalty_temp1)); - penalty_temp1 = penalty_here59; - penalty_here59 = (query_value-subject_value59) * (query_value-subject_value59) + min(penalty_here58, min(penalty_here59, penalty_temp0)); - penalty_temp0 = penalty_here60; - penalty_here60 = (query_value-subject_value60) * (query_value-subject_value60) + min(penalty_here59, min(penalty_here60, penalty_temp1)); - penalty_temp1 = penalty_here61; - penalty_here61 = (query_value-subject_value61) * (query_value-subject_value61) + min(penalty_here60, min(penalty_here61, penalty_temp0)); - penalty_temp0 = penalty_here62; - penalty_here62 = (query_value-subject_value62) * (query_value-subject_value62) + min(penalty_here61, min(penalty_here62, penalty_temp1)); - penalty_here63 = (query_value-subject_value63) * (query_value-subject_value63) + min(penalty_here62, min(penalty_here63, penalty_temp0)); - - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - - if (thid == 0) penalty_left = INFINITY; - - for (index_t k = 3; k < lane+32-1; k++) { - const index_t i = k-l; - //outside = k <= l || i >= lane; - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_temp1 = penalty_here31; - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - penalty_temp0 = penalty_here32; - penalty_here32 = (query_value-subject_value32) * (query_value-subject_value32) + min(penalty_here31, min(penalty_here32, penalty_temp1)); - penalty_temp1 = penalty_here33; - penalty_here33 = (query_value-subject_value33) * (query_value-subject_value33) + min(penalty_here32, min(penalty_here33, penalty_temp0)); - penalty_temp0 = penalty_here34; - penalty_here34 = (query_value-subject_value34) * (query_value-subject_value34) + min(penalty_here33, min(penalty_here34, penalty_temp1)); - penalty_temp1 = penalty_here35; - penalty_here35 = (query_value-subject_value35) * (query_value-subject_value35) + min(penalty_here34, min(penalty_here35, penalty_temp0)); - penalty_temp0 = penalty_here36; - penalty_here36 = (query_value-subject_value36) * (query_value-subject_value36) + min(penalty_here35, min(penalty_here36, penalty_temp1)); - penalty_temp1 = penalty_here37; - penalty_here37 = (query_value-subject_value37) * (query_value-subject_value37) + min(penalty_here36, min(penalty_here37, penalty_temp0)); - penalty_temp0 = penalty_here38; - penalty_here38 = (query_value-subject_value38) * (query_value-subject_value38) + min(penalty_here37, min(penalty_here38, penalty_temp1)); - penalty_temp1 = penalty_here39; - penalty_here39 = (query_value-subject_value39) * (query_value-subject_value39) + min(penalty_here38, min(penalty_here39, penalty_temp0)); - penalty_temp0 = penalty_here40; - penalty_here40 = (query_value-subject_value40) * (query_value-subject_value40) + min(penalty_here39, min(penalty_here40, penalty_temp1)); - penalty_temp1 = penalty_here41; - penalty_here41 = (query_value-subject_value41) * (query_value-subject_value41) + min(penalty_here40, min(penalty_here41, penalty_temp0)); - penalty_temp0 = penalty_here42; - penalty_here42 = (query_value-subject_value42) * (query_value-subject_value42) + min(penalty_here41, min(penalty_here42, penalty_temp1)); - penalty_temp1 = penalty_here43; - penalty_here43 = (query_value-subject_value43) * (query_value-subject_value43) + min(penalty_here42, min(penalty_here43, penalty_temp0)); - penalty_temp0 = penalty_here44; - penalty_here44 = (query_value-subject_value44) * (query_value-subject_value44) + min(penalty_here43, min(penalty_here44, penalty_temp1)); - penalty_temp1 = penalty_here45; - penalty_here45 = (query_value-subject_value45) * (query_value-subject_value45) + min(penalty_here44, min(penalty_here45, penalty_temp0)); - penalty_temp0 = penalty_here46; - penalty_here46 = (query_value-subject_value46) * (query_value-subject_value46) + min(penalty_here45, min(penalty_here46, penalty_temp1)); - penalty_temp1 = penalty_here47; - penalty_here47 = (query_value-subject_value47) * (query_value-subject_value47) + min(penalty_here46, min(penalty_here47, penalty_temp0)); - - penalty_temp0 = penalty_here48; - penalty_here48 = (query_value-subject_value48) * (query_value-subject_value48) + min(penalty_here47, min(penalty_here48, penalty_temp1)); - penalty_temp1 = penalty_here49; - penalty_here49 = (query_value-subject_value49) * (query_value-subject_value49) + min(penalty_here48, min(penalty_here49, penalty_temp0)); - penalty_temp0 = penalty_here50; - penalty_here50 = (query_value-subject_value50) * (query_value-subject_value50) + min(penalty_here49, min(penalty_here50, penalty_temp1)); - penalty_temp1 = penalty_here51; - penalty_here51 = (query_value-subject_value51) * (query_value-subject_value51) + min(penalty_here50, min(penalty_here51, penalty_temp0)); - penalty_temp0 = penalty_here52; - penalty_here52 = (query_value-subject_value52) * (query_value-subject_value52) + min(penalty_here51, min(penalty_here52, penalty_temp1)); - penalty_temp1 = penalty_here53; - penalty_here53 = (query_value-subject_value53) * (query_value-subject_value53) + min(penalty_here52, min(penalty_here53, penalty_temp0)); - penalty_temp0 = penalty_here54; - penalty_here54 = (query_value-subject_value54) * (query_value-subject_value54) + min(penalty_here53, min(penalty_here54, penalty_temp1)); - penalty_temp1 = penalty_here55; - penalty_here55 = (query_value-subject_value55) * (query_value-subject_value55) + min(penalty_here54, min(penalty_here55, penalty_temp0)); - penalty_temp0 = penalty_here56; - penalty_here56 = (query_value-subject_value56) * (query_value-subject_value56) + min(penalty_here55, min(penalty_here56, penalty_temp1)); - penalty_temp1 = penalty_here57; - penalty_here57 = (query_value-subject_value57) * (query_value-subject_value57) + min(penalty_here56, min(penalty_here57, penalty_temp0)); - penalty_temp0 = penalty_here58; - penalty_here58 = (query_value-subject_value58) * (query_value-subject_value58) + min(penalty_here57, min(penalty_here58, penalty_temp1)); - penalty_temp1 = penalty_here59; - penalty_here59 = (query_value-subject_value59) * (query_value-subject_value59) + min(penalty_here58, min(penalty_here59, penalty_temp0)); - penalty_temp0 = penalty_here60; - penalty_here60 = (query_value-subject_value60) * (query_value-subject_value60) + min(penalty_here59, min(penalty_here60, penalty_temp1)); - penalty_temp1 = penalty_here61; - penalty_here61 = (query_value-subject_value61) * (query_value-subject_value61) + min(penalty_here60, min(penalty_here61, penalty_temp0)); - penalty_temp0 = penalty_here62; - penalty_here62 = (query_value-subject_value62) * (query_value-subject_value62) + min(penalty_here61, min(penalty_here62, penalty_temp1)); - penalty_here63 = (query_value-subject_value63) * (query_value-subject_value63) + min(penalty_here62, min(penalty_here63, penalty_temp0)); - - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - if (counter%32 == 0) new_query_value = Query[i+2*thid-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //if (thid == 0) if (!outside) Dist[counter] = query_value; else Dist[counter] = 0; - counter++; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here63, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (thid == 0) penalty_left = INFINITY; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_diag)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_temp1 = penalty_here31; - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - penalty_temp0 = penalty_here32; - penalty_here32 = (query_value-subject_value32) * (query_value-subject_value32) + min(penalty_here31, min(penalty_here32, penalty_temp1)); - penalty_temp1 = penalty_here33; - penalty_here33 = (query_value-subject_value33) * (query_value-subject_value33) + min(penalty_here32, min(penalty_here33, penalty_temp0)); - penalty_temp0 = penalty_here34; - penalty_here34 = (query_value-subject_value34) * (query_value-subject_value34) + min(penalty_here33, min(penalty_here34, penalty_temp1)); - penalty_temp1 = penalty_here35; - penalty_here35 = (query_value-subject_value35) * (query_value-subject_value35) + min(penalty_here34, min(penalty_here35, penalty_temp0)); - penalty_temp0 = penalty_here36; - penalty_here36 = (query_value-subject_value36) * (query_value-subject_value36) + min(penalty_here35, min(penalty_here36, penalty_temp1)); - penalty_temp1 = penalty_here37; - penalty_here37 = (query_value-subject_value37) * (query_value-subject_value37) + min(penalty_here36, min(penalty_here37, penalty_temp0)); - penalty_temp0 = penalty_here38; - penalty_here38 = (query_value-subject_value38) * (query_value-subject_value38) + min(penalty_here37, min(penalty_here38, penalty_temp1)); - penalty_temp1 = penalty_here39; - penalty_here39 = (query_value-subject_value39) * (query_value-subject_value39) + min(penalty_here38, min(penalty_here39, penalty_temp0)); - penalty_temp0 = penalty_here40; - penalty_here40 = (query_value-subject_value40) * (query_value-subject_value40) + min(penalty_here39, min(penalty_here40, penalty_temp1)); - penalty_temp1 = penalty_here41; - penalty_here41 = (query_value-subject_value41) * (query_value-subject_value41) + min(penalty_here40, min(penalty_here41, penalty_temp0)); - penalty_temp0 = penalty_here42; - penalty_here42 = (query_value-subject_value42) * (query_value-subject_value42) + min(penalty_here41, min(penalty_here42, penalty_temp1)); - penalty_temp1 = penalty_here43; - penalty_here43 = (query_value-subject_value43) * (query_value-subject_value43) + min(penalty_here42, min(penalty_here43, penalty_temp0)); - penalty_temp0 = penalty_here44; - penalty_here44 = (query_value-subject_value44) * (query_value-subject_value44) + min(penalty_here43, min(penalty_here44, penalty_temp1)); - penalty_temp1 = penalty_here45; - penalty_here45 = (query_value-subject_value45) * (query_value-subject_value45) + min(penalty_here44, min(penalty_here45, penalty_temp0)); - penalty_temp0 = penalty_here46; - penalty_here46 = (query_value-subject_value46) * (query_value-subject_value46) + min(penalty_here45, min(penalty_here46, penalty_temp1)); - penalty_temp1 = penalty_here47; - penalty_here47 = (query_value-subject_value47) * (query_value-subject_value47) + min(penalty_here46, min(penalty_here47, penalty_temp0)); - - penalty_temp0 = penalty_here48; - penalty_here48 = (query_value-subject_value48) * (query_value-subject_value48) + min(penalty_here47, min(penalty_here48, penalty_temp1)); - penalty_temp1 = penalty_here49; - penalty_here49 = (query_value-subject_value49) * (query_value-subject_value49) + min(penalty_here48, min(penalty_here49, penalty_temp0)); - penalty_temp0 = penalty_here50; - penalty_here50 = (query_value-subject_value50) * (query_value-subject_value50) + min(penalty_here49, min(penalty_here50, penalty_temp1)); - penalty_temp1 = penalty_here51; - penalty_here51 = (query_value-subject_value51) * (query_value-subject_value51) + min(penalty_here50, min(penalty_here51, penalty_temp0)); - penalty_temp0 = penalty_here52; - penalty_here52 = (query_value-subject_value52) * (query_value-subject_value52) + min(penalty_here51, min(penalty_here52, penalty_temp1)); - penalty_temp1 = penalty_here53; - penalty_here53 = (query_value-subject_value53) * (query_value-subject_value53) + min(penalty_here52, min(penalty_here53, penalty_temp0)); - penalty_temp0 = penalty_here54; - penalty_here54 = (query_value-subject_value54) * (query_value-subject_value54) + min(penalty_here53, min(penalty_here54, penalty_temp1)); - penalty_temp1 = penalty_here55; - penalty_here55 = (query_value-subject_value55) * (query_value-subject_value55) + min(penalty_here54, min(penalty_here55, penalty_temp0)); - penalty_temp0 = penalty_here56; - penalty_here56 = (query_value-subject_value56) * (query_value-subject_value56) + min(penalty_here55, min(penalty_here56, penalty_temp1)); - penalty_temp1 = penalty_here57; - penalty_here57 = (query_value-subject_value57) * (query_value-subject_value57) + min(penalty_here56, min(penalty_here57, penalty_temp0)); - penalty_temp0 = penalty_here58; - penalty_here58 = (query_value-subject_value58) * (query_value-subject_value58) + min(penalty_here57, min(penalty_here58, penalty_temp1)); - penalty_temp1 = penalty_here59; - penalty_here59 = (query_value-subject_value59) * (query_value-subject_value59) + min(penalty_here58, min(penalty_here59, penalty_temp0)); - penalty_temp0 = penalty_here60; - penalty_here60 = (query_value-subject_value60) * (query_value-subject_value60) + min(penalty_here59, min(penalty_here60, penalty_temp1)); - penalty_temp1 = penalty_here61; - penalty_here61 = (query_value-subject_value61) * (query_value-subject_value61) + min(penalty_here60, min(penalty_here61, penalty_temp0)); - penalty_temp0 = penalty_here62; - penalty_here62 = (query_value-subject_value62) * (query_value-subject_value62) + min(penalty_here61, min(penalty_here62, penalty_temp1)); - penalty_here63 = (query_value-subject_value63) * (query_value-subject_value63) + min(penalty_here62, min(penalty_here63, penalty_temp0)); - - if(thid == blockDim.x-1) Dist[blid] = penalty_here63; -} - - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_255.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_255.cuh deleted file mode 100755 index f3f5bed..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_255.cuh +++ /dev/null @@ -1,189 +0,0 @@ -#ifndef SHFL_FULLDTW_255 -#define SHFL_FULLDTW_255 - -// 8 values per thread no shared memory -template < - typename index_t, - typename value_t> __global__ -void shfl_FullDTW_255( // was DTW_fast_256_shuffle_kernel_no_shared_memory - const value_t * Query, - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t lane = num_features+1; - const index_t base = blid*num_features; - const index_t WARP_SIZE = 32; - const index_t l = thid; - - - //extern __shared__ value_t Subject_cache[]; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_here4 = INFINITY; // 0; - value_t penalty_here5 = INFINITY; // 0; - value_t penalty_here6 = INFINITY; // 0; - value_t penalty_here7 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - //index_t iter = 0; - - //for (index_t l = thid; l < lane/8; l += blockDim.x) { - // const index_t iter = l/WARP_SIZE; - if (thid == 0) { - //if (iter > 0) penalty_left = Subject_cache[1]; else penalty_left = INFINITY; - penalty_left = INFINITY; - penalty_diag = INFINITY; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - penalty_here4 = INFINITY; // 0; - penalty_here5 = INFINITY; // 0; - penalty_here6 = INFINITY; // 0; - penalty_here7 = INFINITY; // 0; - } - - //const value_t subject_value = Subject[base+l-1]; - const value_t subject_value0 = l == 0 ? 0 : Subject[base+8*l-1]; - const value_t subject_value1 = Subject[base+8*l-0]; - const value_t subject_value2 = Subject[base+8*l+1]; - const value_t subject_value3 = Subject[base+8*l+2]; - const value_t subject_value4 = Subject[base+8*l+3]; - const value_t subject_value5 = Subject[base+8*l+4]; - const value_t subject_value6 = Subject[base+8*l+5]; - const value_t subject_value7 = Subject[base+8*l+6]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = Query[thid]; - if (thid == 0) query_value = new_query_value; - if (thid == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - //penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here1, 1, 32); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //const index_t j = l; - //if (blid == 0 && thid == 31 && iter == 0) Dist[0] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[2*thid+1] = penalty_here1; - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here7, 1, 32); - - //if (iter && thid == 0) penalty_left = Subject_cache[2+1]; - if (thid == 0) penalty_left = INFINITY; - - for (index_t k = 3; k < lane+WARP_SIZE-1; k++) { - const index_t i = k-l; - //outside = k <= l || i >= lane; - - //const value_t residue = outside ? INFINITY : Query[i-1]-subject_value; - //const value_t residue = outside ? INFINITY : query_value-subject_value; - //if (thid == 0 && iter == 0 && k == 2) penalty_temp = INFINITY; else - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - //if (counter%32 == 0 && counter > 1) new_query_value = Query[i+2*thid-1]; - if (counter%32 == 0) new_query_value = Query[i+2*thid-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //if (thid == 0) if (!outside) Dist[counter] = query_value; else Dist[counter] = 0; - counter++; - - // save the right column - //if (!outside && thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[i] = penalty_here; // TO DO: replace this by shhffles - //if (iter < lane/WARP_SIZE-1 && thid == 31 && k>l) Subject_cache[i] = penalty_here7; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here7, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (thid == 0) penalty_left = INFINITY; - - //if (thid == 0 && k>l) if (iter > 0) penalty_left = Subject_cache[i+1]; else penalty_left = INFINITY; - //if (thid == 0 && k>l && iter > 0) penalty_left = Subject_cache[i+1]; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - //if (thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[lane+(iter+1)*WARP_SIZE-1-l] = penalty_here7; - //iter++; - - if(thid == blockDim.x-1) Dist[blid] = penalty_here7; -} - - - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_511.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_511.cuh deleted file mode 100755 index fdf276d..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SHFL_FULLDTW_511.cuh +++ /dev/null @@ -1,267 +0,0 @@ -#ifndef SHFL_FULLDTW_511 -#define SHFL_FULLDTW_511 - -// 16 values per thread no shared memory -template < - typename index_t, - typename value_t> __global__ -void shfl_FullDTW_511( // was DTW_fast_512_shuffle_kernel_no_shared_memory - const value_t * Query, - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t lane = num_features+1; - const index_t base = blid*num_features; - const index_t WARP_SIZE = 32; - const index_t l = thid; - - //extern __shared__ value_t Subject_cache[]; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_here4 = INFINITY; // 0; - value_t penalty_here5 = INFINITY; // 0; - value_t penalty_here6 = INFINITY; // 0; - value_t penalty_here7 = INFINITY; // 0; - value_t penalty_here8 = INFINITY; // 0; - value_t penalty_here9 = INFINITY; // 0; - value_t penalty_here10 = INFINITY; // 0; - value_t penalty_here11 = INFINITY; // 0; - value_t penalty_here12 = INFINITY; // 0; - value_t penalty_here13 = INFINITY; // 0; - value_t penalty_here14 = INFINITY; // 0; - value_t penalty_here15 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - //index_t iter = 0; - - //for (index_t l = thid; l < lane/16; l += blockDim.x) { - // const index_t iter = l/WARP_SIZE; - if (thid == 0) { - penalty_left = INFINITY; - penalty_diag = INFINITY; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - penalty_here4 = INFINITY; // 0; - penalty_here5 = INFINITY; // 0; - penalty_here6 = INFINITY; // 0; - penalty_here7 = INFINITY; // 0; - penalty_here8 = INFINITY; // 0; - penalty_here9 = INFINITY; // 0; - penalty_here10 = INFINITY; // 0; - penalty_here11 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here13 = INFINITY; // 0; - penalty_here14 = INFINITY; // 0; - penalty_here15 = INFINITY; // 0; - } - - //const value_t subject_value = Subject[base+l-1]; - const value_t subject_value0 = l == 0 ? 0 : Subject[base+16*l-1]; - const value_t subject_value1 = Subject[base+16*l-0]; - const value_t subject_value2 = Subject[base+16*l+1]; - const value_t subject_value3 = Subject[base+16*l+2]; - const value_t subject_value4 = Subject[base+16*l+3]; - const value_t subject_value5 = Subject[base+16*l+4]; - const value_t subject_value6 = Subject[base+16*l+5]; - const value_t subject_value7 = Subject[base+16*l+6]; - const value_t subject_value8 = Subject[base+16*l+7]; - const value_t subject_value9 = Subject[base+16*l+8]; - const value_t subject_value10 = Subject[base+16*l+9]; - const value_t subject_value11 = Subject[base+16*l+10]; - const value_t subject_value12 = Subject[base+16*l+11]; - const value_t subject_value13 = Subject[base+16*l+12]; - const value_t subject_value14 = Subject[base+16*l+13]; - const value_t subject_value15 = Subject[base+16*l+14]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = Query[thid]; - if (thid == 0) query_value = new_query_value; - if (thid == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - //penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here1, 1, 32); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //const index_t j = l; - //if (blid == 0 && thid == 31 && iter == 0) Dist[0] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[2*thid+1] = penalty_here1; - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here15, 1, 32); - - //if (iter && thid == 0) penalty_left = Subject_cache[2+1]; - if (thid == 0) penalty_left = INFINITY; - - for (index_t k = 3; k < lane+WARP_SIZE-1; k++) { - const index_t i = k-l; - //outside = k <= l || i >= lane; - - //const value_t residue = outside ? INFINITY : Query[i-1]-subject_value; - //const value_t residue = outside ? INFINITY : query_value-subject_value; - //if (thid == 0 && iter == 0 && k == 2) penalty_temp = INFINITY; else - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - //if (i==2) penalty_temp1 = INFINITY; else penalty_temp1 = penalty_here1; // -> move before main loop!!! - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - //if (counter%32 == 0 && counter > 1) new_query_value = Query[i+2*thid-1]; - if (counter%32 == 0) new_query_value = Query[i+2*thid-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - //if (thid == 0) if (!outside) Dist[counter] = query_value; else Dist[counter] = 0; - counter++; - - // save the right column - //if (!outside && thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[i] = penalty_here; // TO DO: replace this by shhffles - //if (iter < lane/WARP_SIZE-1 && thid == 31 && k>l) Subject_cache[i] = penalty_here15; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here15, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (thid == 0) penalty_left = INFINITY; - - //if (thid == 0 && k>l) if (iter > 0) penalty_left = Subject_cache[i+1]; else penalty_left = INFINITY; - //if (thid == 0 && k>l && iter > 0) penalty_left = Subject_cache[i+1]; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - //if (thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[lane+(iter+1)*WARP_SIZE-1-l] = penalty_here15; - //iter++; - - if(thid == blockDim.x-1) Dist[blid] = penalty_here15; -} - - - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SUB_WARP.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SUB_WARP.cuh deleted file mode 100755 index 8040b55..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SUB_WARP.cuh +++ /dev/null @@ -1,407 +0,0 @@ -#ifndef SUB_WARP_DTW -#define SUB_WARP_DTW - - -// 32 values per thread no shared memory -template < - int group_size, - typename index_t, - typename value_t> __global__ -void sub_warp_DTW( // replace DTW_1024_sub_warp - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t l = thid; - const index_t lane = num_features+1; - //const index_t WARP_SIZE = 32; - //const index_t group_size = WARP_SIZE/(1024/lane); - const index_t base = (32/group_size)*blid*num_features; - - //extern __shared__ value_t Subject_cache[]; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_here4 = INFINITY; // 0; - value_t penalty_here5 = INFINITY; // 0; - value_t penalty_here6 = INFINITY; // 0; - value_t penalty_here7 = INFINITY; // 0; - value_t penalty_here8 = INFINITY; // 0; - value_t penalty_here9 = INFINITY; // 0; - value_t penalty_here10 = INFINITY; // 0; - value_t penalty_here11 = INFINITY; // 0; - value_t penalty_here12 = INFINITY; // 0; - value_t penalty_here13 = INFINITY; // 0; - value_t penalty_here14 = INFINITY; // 0; - value_t penalty_here15 = INFINITY; // 0; - value_t penalty_here16 = INFINITY; // 0; - value_t penalty_here17 = INFINITY; // 0; - value_t penalty_here18 = INFINITY; // 0; - value_t penalty_here19 = INFINITY; // 0; - value_t penalty_here20 = INFINITY; // 0; - value_t penalty_here21 = INFINITY; // 0; - value_t penalty_here22 = INFINITY; // 0; - value_t penalty_here23 = INFINITY; // 0; - value_t penalty_here24 = INFINITY; // 0; - value_t penalty_here25 = INFINITY; // 0; - value_t penalty_here26 = INFINITY; // 0; - value_t penalty_here27 = INFINITY; // 0; - value_t penalty_here28 = INFINITY; // 0; - value_t penalty_here29 = INFINITY; // 0; - value_t penalty_here30 = INFINITY; // 0; - value_t penalty_here31 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - //index_t iter = 0; - - if (thid % group_size == 0) { - penalty_left = INFINITY; - penalty_diag = INFINITY; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - penalty_here4 = INFINITY; // 0; - penalty_here5 = INFINITY; // 0; - penalty_here6 = INFINITY; // 0; - penalty_here7 = INFINITY; // 0; - penalty_here8 = INFINITY; // 0; - penalty_here9 = INFINITY; // 0; - penalty_here10 = INFINITY; // 0; - penalty_here11 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here13 = INFINITY; // 0; - penalty_here14 = INFINITY; // 0; - penalty_here15 = INFINITY; // 0; - penalty_here16 = INFINITY; // 0; - penalty_here17 = INFINITY; // 0; - penalty_here18 = INFINITY; // 0; - penalty_here19 = INFINITY; // 0; - penalty_here20 = INFINITY; // 0; - penalty_here21 = INFINITY; // 0; - penalty_here22 = INFINITY; // 0; - penalty_here23 = INFINITY; // 0; - penalty_here24 = INFINITY; // 0; - penalty_here25 = INFINITY; // 0; - penalty_here26 = INFINITY; // 0; - penalty_here27 = INFINITY; // 0; - penalty_here28 = INFINITY; // 0; - penalty_here29 = INFINITY; // 0; - penalty_here30 = INFINITY; // 0; - penalty_here31 = INFINITY; // 0; - } - - //if (thid >= 8 && thid < 16) { - // subject_value0 = Subject[base+num_features+16*(l-8)-1]; - //const value_t subject_value0 = Subject[base+(thid/group_size)*num_features + 32*(l-(group_size*(thid/group_size)))-1]; - - //const value_t subject_value0 = base+l == 0 ? 0 : Subject[base+32*l-1]; - //const value_t subject_value1 = Subject[base+32*l-0]; - //const value_t subject_value2 = Subject[base+32*l+1]; - - const value_t subject_value0 = base+(thid/group_size)*num_features + (l%group_size) == 0 ? 0 : Subject[base+(thid/group_size)*num_features + 32*(l%group_size)-1]; - //const value_t subject_value0 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)-1]; - const value_t subject_value1 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)-0]; - const value_t subject_value2 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+1]; - const value_t subject_value3 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+2]; - const value_t subject_value4 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+3]; - const value_t subject_value5 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+4]; - const value_t subject_value6 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+5]; - const value_t subject_value7 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+6]; - const value_t subject_value8 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+7]; - const value_t subject_value9 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+8]; - const value_t subject_value10 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+9]; - const value_t subject_value11 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+10]; - const value_t subject_value12 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+11]; - const value_t subject_value13 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+12]; - const value_t subject_value14 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+13]; - const value_t subject_value15 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+14]; - - const value_t subject_value16 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+15]; - const value_t subject_value17 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+16]; - const value_t subject_value18 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+17]; - const value_t subject_value19 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+18]; - const value_t subject_value20 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+19]; - const value_t subject_value21 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+20]; - const value_t subject_value22 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+21]; - const value_t subject_value23 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+22]; - const value_t subject_value24 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+23]; - const value_t subject_value25 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+24]; - const value_t subject_value26 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+25]; - const value_t subject_value27 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+26]; - const value_t subject_value28 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+27]; - const value_t subject_value29 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+28]; - const value_t subject_value30 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+29]; - const value_t subject_value31 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+30]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = cQuery[thid%group_size]; - if (thid % group_size == 0) query_value = new_query_value; - if (thid % group_size == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid % group_size == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - - //penalty_left = INFINITY; - - if (thid % group_size == 0) penalty_left = INFINITY; - - const index_t group_id = thid%group_size; - - for (index_t k = 3; k < lane+group_size-1; k++) { - const index_t i = k-l%group_size; - //outside = k <= l || i >= lane; - - //const value_t residue = outside ? INFINITY : cQuery[i-1]-subject_value; - //const value_t residue = outside ? INFINITY : query_value-subject_value; - //if (thid == 0 && iter == 0 && k == 2) penalty_temp = INFINITY; else - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - //if (counter%32 == 0 && counter > 1) new_query_value = cQuery[i+2*thid-1]; - if (counter%group_size == 0) new_query_value = cQuery[i+2*(thid%group_size)-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (!group_id) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - // save the right column - //if (!outside && thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[i] = penalty_here; // TO DO: replace this by shhffles - //if (iter < lane/WARP_SIZE-1 && thid == 31 && k>l) Subject_cache[i] = penalty_here31; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (!group_id) penalty_left = INFINITY; - - //if (thid == 0 && k>l) if (iter > 0) penalty_left = Subject_cache[i+1]; else penalty_left = INFINITY; - //if (thid == 0 && k>l && iter > 0) penalty_left = Subject_cache[i+1]; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_diag)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - if(thid % group_size == group_size-1) Dist[(32/group_size)*blid+thid/group_size] = penalty_here31; -} - - -#endif diff --git a/baleen/_cuda_dtw/cudtw/include/kernels/SUB_WARP_MULTI_QUERY.cuh b/baleen/_cuda_dtw/cudtw/include/kernels/SUB_WARP_MULTI_QUERY.cuh deleted file mode 100755 index f3b3d32..0000000 --- a/baleen/_cuda_dtw/cudtw/include/kernels/SUB_WARP_MULTI_QUERY.cuh +++ /dev/null @@ -1,408 +0,0 @@ -#ifndef sub_warp_DTW_MULTI_QUERY -#define SUB_WARP_DTW_MULTI_QUERY - - -// 32 values per thread no shared memory -template < - int group_size, - int num_queries, - typename index_t, - typename value_t> __global__ -void sub_warp_DTW_multi_query( - value_t * Subject, - value_t * Dist, - index_t num_entries, - index_t num_features - //index_t num_queries - ) { - - const index_t blid = blockIdx.x; - const index_t thid = threadIdx.x; - const index_t l = thid; - const index_t lane = num_features+1; - const index_t base = (32/group_size)*blid*num_features; - - value_t penalty_left = INFINITY; - value_t penalty_diag = 0; // INFINITY; - value_t penalty_here0 = INFINITY; // 0; - value_t penalty_here1 = INFINITY; // 0; - value_t penalty_here2 = INFINITY; // 0; - value_t penalty_here3 = INFINITY; // 0; - value_t penalty_here4 = INFINITY; // 0; - value_t penalty_here5 = INFINITY; // 0; - value_t penalty_here6 = INFINITY; // 0; - value_t penalty_here7 = INFINITY; // 0; - value_t penalty_here8 = INFINITY; // 0; - value_t penalty_here9 = INFINITY; // 0; - value_t penalty_here10 = INFINITY; // 0; - value_t penalty_here11 = INFINITY; // 0; - value_t penalty_here12 = INFINITY; // 0; - value_t penalty_here13 = INFINITY; // 0; - value_t penalty_here14 = INFINITY; // 0; - value_t penalty_here15 = INFINITY; // 0; - value_t penalty_here16 = INFINITY; // 0; - value_t penalty_here17 = INFINITY; // 0; - value_t penalty_here18 = INFINITY; // 0; - value_t penalty_here19 = INFINITY; // 0; - value_t penalty_here20 = INFINITY; // 0; - value_t penalty_here21 = INFINITY; // 0; - value_t penalty_here22 = INFINITY; // 0; - value_t penalty_here23 = INFINITY; // 0; - value_t penalty_here24 = INFINITY; // 0; - value_t penalty_here25 = INFINITY; // 0; - value_t penalty_here26 = INFINITY; // 0; - value_t penalty_here27 = INFINITY; // 0; - value_t penalty_here28 = INFINITY; // 0; - value_t penalty_here29 = INFINITY; // 0; - value_t penalty_here30 = INFINITY; // 0; - value_t penalty_here31 = INFINITY; // 0; - value_t penalty_temp0; - value_t penalty_temp1; - // Init shared memeory for right column - //for (index_t l = thid; l < lane; l += blockDim.x) - // Subject_cache[l] = INFINITY; - //__syncthreads(); - - //index_t iter = 0; - - //if (thid >= 8 && thid < 16) { - // subject_value0 = Subject[base+num_features+16*(l-8)-1]; - //const value_t subject_value0 = Subject[base+(thid/group_size)*num_features + 32*(l-(group_size*(thid/group_size)))-1]; - - //const value_t subject_value = Subject[base+l-1]; - //const value_t subject_value0 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)-1]; - const value_t subject_value0 = base+(thid/group_size)*num_features + (l%group_size) == 0 ? 0 : Subject[base+(thid/group_size)*num_features + 32*(l%group_size)-1]; - - const value_t subject_value1 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)-0]; - const value_t subject_value2 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+1]; - const value_t subject_value3 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+2]; - const value_t subject_value4 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+3]; - const value_t subject_value5 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+4]; - const value_t subject_value6 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+5]; - const value_t subject_value7 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+6]; - const value_t subject_value8 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+7]; - const value_t subject_value9 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+8]; - const value_t subject_value10 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+9]; - const value_t subject_value11 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+10]; - const value_t subject_value12 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+11]; - const value_t subject_value13 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+12]; - const value_t subject_value14 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+13]; - const value_t subject_value15 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+14]; - - const value_t subject_value16 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+15]; - const value_t subject_value17 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+16]; - const value_t subject_value18 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+17]; - const value_t subject_value19 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+18]; - const value_t subject_value20 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+19]; - const value_t subject_value21 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+20]; - const value_t subject_value22 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+21]; - const value_t subject_value23 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+22]; - const value_t subject_value24 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+23]; - const value_t subject_value25 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+24]; - const value_t subject_value26 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+25]; - const value_t subject_value27 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+26]; - const value_t subject_value28 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+27]; - const value_t subject_value29 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+28]; - const value_t subject_value30 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+29]; - const value_t subject_value31 = Subject[base+(thid/group_size)*num_features + 32*(l%group_size)+30]; - //if (blid == 0) Dist[2*l] = subject_value0; - //if (blid == 0) Dist[2*l+1] = subject_value1; - for (index_t query_number = 0; query_number < num_queries; query_number++) { - - penalty_left = INFINITY; - penalty_diag = 0; - penalty_here0 = INFINITY; // 0; - penalty_here1 = INFINITY; // 0; - penalty_here2 = INFINITY; // 0; - penalty_here3 = INFINITY; // 0; - penalty_here4 = INFINITY; // 0; - penalty_here5 = INFINITY; // 0; - penalty_here6 = INFINITY; // 0; - penalty_here7 = INFINITY; // 0; - penalty_here8 = INFINITY; // 0; - penalty_here9 = INFINITY; // 0; - penalty_here10 = INFINITY; // 0; - penalty_here11 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here12 = INFINITY; // 0; - penalty_here13 = INFINITY; // 0; - penalty_here14 = INFINITY; // 0; - penalty_here15 = INFINITY; // 0; - penalty_here16 = INFINITY; // 0; - penalty_here17 = INFINITY; // 0; - penalty_here18 = INFINITY; // 0; - penalty_here19 = INFINITY; // 0; - penalty_here20 = INFINITY; // 0; - penalty_here21 = INFINITY; // 0; - penalty_here22 = INFINITY; // 0; - penalty_here23 = INFINITY; // 0; - penalty_here24 = INFINITY; // 0; - penalty_here25 = INFINITY; // 0; - penalty_here26 = INFINITY; // 0; - penalty_here27 = INFINITY; // 0; - penalty_here28 = INFINITY; // 0; - penalty_here29 = INFINITY; // 0; - penalty_here30 = INFINITY; // 0; - penalty_here31 = INFINITY; // 0; - - if (thid % group_size == 0) penalty_diag = INFINITY; - - - index_t counter = 1; - value_t query_value = INFINITY; - value_t new_query_value = cQuery[query_number*num_features+thid%group_size]; - if (thid % group_size == 0) query_value = new_query_value; - if (thid % group_size == 0) penalty_here1 = 0; // (query_value - subject_value0)*(query_value - subject_value0); - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = INFINITY; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (thid % group_size == 0) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - - //penalty_left = INFINITY; - - if (thid % group_size == 0) penalty_left = INFINITY; - - const index_t group_id = thid%group_size; - - for (index_t k = 3; k < lane+group_size-1; k++) { - const index_t i = k-l%group_size; - //outside = k <= l || i >= lane; - - //const value_t residue = outside ? INFINITY : cQuery[i-1]-subject_value; - //const value_t residue = outside ? INFINITY : query_value-subject_value; - //if (thid == 0 && iter == 0 && k == 2) penalty_temp = INFINITY; else - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_temp1)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)] = penalty_here0; - //if (blid == 0 && thid == 31 && iter == 0) Dist[2*(k-1)+1] = penalty_here1; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid] = penalty_here0; - //if (blid == 0 && iter == 0) Dist[64*(k-1)+2*thid+1] = penalty_here1; - - //if (counter%32 == 0 && counter > 1) new_query_value = cQuery[i+2*thid-1]; - if (counter%group_size == 0) new_query_value = cQuery[query_number*num_features+i+2*(thid%group_size)-1]; - //if (counter%group_size == 0) new_query_value = cQuery[i+2*(thid%group_size)-1]; - query_value = __shfl_up_sync(0xFFFFFFFF, query_value, 1, 32); - if (!group_id) query_value = new_query_value; - new_query_value = __shfl_down_sync(0xFFFFFFFF, new_query_value, 1, 32); - counter++; - - // save the right column - //if (!outside && thid == 31 && iter < lane/WARP_SIZE-1) Subject_cache[i] = penalty_here; // TO DO: replace this by shhffles - //if (iter < lane/WARP_SIZE-1 && thid == 31 && k>l) Subject_cache[i] = penalty_here31; - - // shuffle the penalty - penalty_diag = penalty_left; - penalty_left = __shfl_up_sync(0xFFFFFFFF, penalty_here31, 1, 32); - //if (thid == 0 && !outside) penalty_left = Subject_cache[i+1]; // TO DO: replace by shuffles - // if (iter > 0 && thid == 0 && k>l) penalty_left = Subject_cache[i+1]; - //if (iter && thid == 0) penalty_left = Subject_cache[i+1]; - //if (!iter && thid == 0) penalty_left = INFINITY; - if (!group_id) penalty_left = INFINITY; - - //if (thid == 0 && k>l) if (iter > 0) penalty_left = Subject_cache[i+1]; else penalty_left = INFINITY; - //if (thid == 0 && k>l && iter > 0) penalty_left = Subject_cache[i+1]; - } - penalty_temp0 = penalty_here0; - penalty_here0 = (query_value-subject_value0) * (query_value-subject_value0) + min(penalty_left, min(penalty_here0, penalty_diag)); - penalty_temp1 = penalty_here1; - penalty_here1 = (query_value-subject_value1) * (query_value-subject_value1) + min(penalty_here0, min(penalty_here1, penalty_temp0)); - penalty_temp0 = penalty_here2; - penalty_here2 = (query_value-subject_value2) * (query_value-subject_value2) + min(penalty_here1, min(penalty_here2, penalty_temp1)); - penalty_temp1 = penalty_here3; - penalty_here3 = (query_value-subject_value3) * (query_value-subject_value3) + min(penalty_here2, min(penalty_here3, penalty_temp0)); - penalty_temp0 = penalty_here4; - penalty_here4 = (query_value-subject_value4) * (query_value-subject_value4) + min(penalty_here3, min(penalty_here4, penalty_temp1)); - penalty_temp1 = penalty_here5; - penalty_here5 = (query_value-subject_value5) * (query_value-subject_value5) + min(penalty_here4, min(penalty_here5, penalty_temp0)); - penalty_temp0 = penalty_here6; - penalty_here6 = (query_value-subject_value6) * (query_value-subject_value6) + min(penalty_here5, min(penalty_here6, penalty_temp1)); - penalty_temp1 = penalty_here7; - penalty_here7 = (query_value-subject_value7) * (query_value-subject_value7) + min(penalty_here6, min(penalty_here7, penalty_temp0)); - penalty_temp0 = penalty_here8; - penalty_here8 = (query_value-subject_value8) * (query_value-subject_value8) + min(penalty_here7, min(penalty_here8, penalty_temp1)); - penalty_temp1 = penalty_here9; - penalty_here9 = (query_value-subject_value9) * (query_value-subject_value9) + min(penalty_here8, min(penalty_here9, penalty_temp0)); - penalty_temp0 = penalty_here10; - penalty_here10 = (query_value-subject_value10) * (query_value-subject_value10) + min(penalty_here9, min(penalty_here10, penalty_temp1)); - penalty_temp1 = penalty_here11; - penalty_here11 = (query_value-subject_value11) * (query_value-subject_value11) + min(penalty_here10, min(penalty_here11, penalty_temp0)); - penalty_temp0 = penalty_here12; - penalty_here12 = (query_value-subject_value12) * (query_value-subject_value12) + min(penalty_here11, min(penalty_here12, penalty_temp1)); - penalty_temp1 = penalty_here13; - penalty_here13 = (query_value-subject_value13) * (query_value-subject_value13) + min(penalty_here12, min(penalty_here13, penalty_temp0)); - penalty_temp0 = penalty_here14; - penalty_here14 = (query_value-subject_value14) * (query_value-subject_value14) + min(penalty_here13, min(penalty_here14, penalty_temp1)); - penalty_temp1 = penalty_here15; - penalty_here15 = (query_value-subject_value15) * (query_value-subject_value15) + min(penalty_here14, min(penalty_here15, penalty_temp0)); - - penalty_temp0 = penalty_here16; - penalty_here16 = (query_value-subject_value16) * (query_value-subject_value16) + min(penalty_here15, min(penalty_here16, penalty_diag)); - penalty_temp1 = penalty_here17; - penalty_here17 = (query_value-subject_value17) * (query_value-subject_value17) + min(penalty_here16, min(penalty_here17, penalty_temp0)); - penalty_temp0 = penalty_here18; - penalty_here18 = (query_value-subject_value18) * (query_value-subject_value18) + min(penalty_here17, min(penalty_here18, penalty_temp1)); - penalty_temp1 = penalty_here19; - penalty_here19 = (query_value-subject_value19) * (query_value-subject_value19) + min(penalty_here18, min(penalty_here19, penalty_temp0)); - penalty_temp0 = penalty_here20; - penalty_here20 = (query_value-subject_value20) * (query_value-subject_value20) + min(penalty_here19, min(penalty_here20, penalty_temp1)); - penalty_temp1 = penalty_here21; - penalty_here21 = (query_value-subject_value21) * (query_value-subject_value21) + min(penalty_here20, min(penalty_here21, penalty_temp0)); - penalty_temp0 = penalty_here22; - penalty_here22 = (query_value-subject_value22) * (query_value-subject_value22) + min(penalty_here21, min(penalty_here22, penalty_temp1)); - penalty_temp1 = penalty_here23; - penalty_here23 = (query_value-subject_value23) * (query_value-subject_value23) + min(penalty_here22, min(penalty_here23, penalty_temp0)); - penalty_temp0 = penalty_here24; - penalty_here24 = (query_value-subject_value24) * (query_value-subject_value24) + min(penalty_here23, min(penalty_here24, penalty_temp1)); - penalty_temp1 = penalty_here25; - penalty_here25 = (query_value-subject_value25) * (query_value-subject_value25) + min(penalty_here24, min(penalty_here25, penalty_temp0)); - penalty_temp0 = penalty_here26; - penalty_here26 = (query_value-subject_value26) * (query_value-subject_value26) + min(penalty_here25, min(penalty_here26, penalty_temp1)); - penalty_temp1 = penalty_here27; - penalty_here27 = (query_value-subject_value27) * (query_value-subject_value27) + min(penalty_here26, min(penalty_here27, penalty_temp0)); - penalty_temp0 = penalty_here28; - penalty_here28 = (query_value-subject_value28) * (query_value-subject_value28) + min(penalty_here27, min(penalty_here28, penalty_temp1)); - penalty_temp1 = penalty_here29; - penalty_here29 = (query_value-subject_value29) * (query_value-subject_value29) + min(penalty_here28, min(penalty_here29, penalty_temp0)); - penalty_temp0 = penalty_here30; - penalty_here30 = (query_value-subject_value30) * (query_value-subject_value30) + min(penalty_here29, min(penalty_here30, penalty_temp1)); - penalty_here31 = (query_value-subject_value31) * (query_value-subject_value31) + min(penalty_here30, min(penalty_here31, penalty_temp0)); - - if(thid % group_size == group_size-1) Dist[query_number*num_entries+(32/group_size)*blockIdx.x+thid/group_size] = penalty_here31; - } - //if(thid % group_size == group_size-1) Dist[(32/group_size)*blid+thid/group_size] = penalty_here31; -} - -#endif diff --git a/baleen/_cuda_dtw/cudtw_wrapper.cu b/baleen/_cuda_dtw/cudtw_wrapper.cu deleted file mode 100644 index fd9af03..0000000 --- a/baleen/_cuda_dtw/cudtw_wrapper.cu +++ /dev/null @@ -1,762 +0,0 @@ -// cudtw_wrapper.cu — cuDTW++ warp-shuffle kernel wrapper for baleen -// -// Query is passed through global memory (no __constant__ serialisation), -// enabling CUDA-stream concurrency across pairwise queries. - -#include -#include -#include -#include -#include -#include -#include - -// -------------------------------------------------------------------------- -// cuDTW++ kernel includes -// -------------------------------------------------------------------------- -typedef float value_t; -typedef uint64_t index_t; - -#include "cudtw/include/DTW.hpp" - -// -------------------------------------------------------------------------- -// Error-checking macro -// -------------------------------------------------------------------------- -#define CUDA_CHECK(call) \ - do { \ - cudaError_t err = (call); \ - if (err != cudaSuccess) { \ - fprintf(stderr, "CUDA error at %s:%d — %s\n", \ - __FILE__, __LINE__, cudaGetErrorString(err)); \ - return -1; \ - } \ - } while (0) - -// -------------------------------------------------------------------------- -// Length bucketing -// -------------------------------------------------------------------------- -static int select_bucket(int len) { - if (len <= 127) return 127; - if (len <= 255) return 255; - if (len <= 511) return 511; - if (len <= 1023) return 1023; - if (len <= 2047) return 2047; - return -1; // too long — caller must resample -} - -// -------------------------------------------------------------------------- -// Single-pair DTW (kept for API compat; not perf-critical) -// -------------------------------------------------------------------------- -int opendba_dtw_cuda( - const float *seq1, size_t len1, - const float *seq2, size_t len2, - float *out_distance) -{ - if (!seq1 || !seq2 || !out_distance || len1 == 0 || len2 == 0) { - fprintf(stderr, "cudtw_wrapper: invalid input\n"); - return -1; - } - - size_t max_len = (len1 > len2) ? len1 : len2; - int bucket = select_bucket((int)max_len); - if (bucket < 0) { - fprintf(stderr, "cudtw_wrapper: sequence too long (%zu), max 2047\n", max_len); - return -1; - } - - float *h_query = (float *)calloc(bucket, sizeof(float)); - float *h_subject = (float *)calloc(bucket, sizeof(float)); - if (!h_query || !h_subject) { free(h_query); free(h_subject); return -1; } - memcpy(h_query, seq1, len1 * sizeof(float)); - memcpy(h_subject, seq2, len2 * sizeof(float)); - - float *d_query = nullptr, *d_subject = nullptr, *d_dist = nullptr; - CUDA_CHECK(cudaMalloc(&d_query, bucket * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_subject, bucket * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_dist, sizeof(float))); - CUDA_CHECK(cudaMemcpy(d_query, h_query, bucket * sizeof(float), - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(d_subject, h_subject, bucket * sizeof(float), - cudaMemcpyHostToDevice)); - free(h_query); free(h_subject); - - bool ok = FullDTW::dist( - d_query, d_subject, d_dist, (index_t)1, (index_t)bucket); - if (!ok) { - cudaFree(d_query); cudaFree(d_subject); cudaFree(d_dist); - fprintf(stderr, "cudtw_wrapper: unsupported bucket %d\n", bucket); - return -1; - } - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaDeviceSynchronize()); - - float sq_cost; - CUDA_CHECK(cudaMemcpy(&sq_cost, d_dist, sizeof(float), cudaMemcpyDeviceToHost)); - *out_distance = sqrtf(sq_cost); - - cudaFree(d_query); - cudaFree(d_subject); - cudaFree(d_dist); - return 0; -} - -// -------------------------------------------------------------------------- -// Cleanup -// -------------------------------------------------------------------------- -void opendba_dtw_cleanup() { - int count = 0; - if (cudaGetDeviceCount(&count) == cudaSuccess) { - for (int i = 0; i < count; i++) { - cudaSetDevice(i); - cudaDeviceReset(); - } - } -} - -// -------------------------------------------------------------------------- -// Internal: stream-concurrent pairwise DTW over one padded subject buffer. -// -// d_subjects: [N * bucket] device buffer (padded to bucket width). -// d_out: [N * N] device output buffer (float32, squared cost). -// Each row i is launched on streams[i % num_streams]: -// Query = d_subjects + i*bucket, Subject = d_subjects, Dist = d_out + i*N. -// Caller is responsible for the final cudaDeviceSynchronize + sqrtf. -// -------------------------------------------------------------------------- -static int launch_pairwise_rows( - const float *d_subjects, float *d_out, - size_t N, int bucket, - cudaStream_t *streams, int num_streams) -{ - for (size_t i = 0; i < N; i++) { - cudaStream_t s = streams[i % (size_t)num_streams]; - bool ok = FullDTW::dist( - d_subjects + i * (size_t)bucket, - const_cast(d_subjects), - d_out + i * N, - (index_t)N, (index_t)bucket, s); - if (!ok) return -1; - } - return 0; -} - -// -------------------------------------------------------------------------- -// Host helper: take a [N*N] float32 squared-cost matrix → sqrtf + zero diag. -// -------------------------------------------------------------------------- -static void finalize_matrix(float *m, size_t N) { - for (size_t i = 0; i < N; i++) { - for (size_t j = 0; j < N; j++) { - m[i * N + j] = (i == j) ? 0.0f : sqrtf(m[i * N + j]); - } - } -} - -// -------------------------------------------------------------------------- -// Batch pairwise DTW (equal-length sequences, single position) -// -------------------------------------------------------------------------- -int opendba_dtw_pairwise_batch( - const float *sequences, - size_t num_sequences, - size_t seq_length, - float *out_distances) -{ - if (!sequences || !out_distances || num_sequences < 2 || seq_length == 0) { - fprintf(stderr, "cudtw_wrapper: invalid pairwise input\n"); - return -1; - } - - int bucket = select_bucket((int)seq_length); - if (bucket < 0) { - fprintf(stderr, "cudtw_wrapper: seq_length %zu > 2047\n", seq_length); - return -1; - } - - const size_t N = num_sequences; - - // Pad all sequences to bucket on host (single contiguous buffer) - float *h_padded = (float *)calloc(N * bucket, sizeof(float)); - if (!h_padded) return -1; - for (size_t i = 0; i < N; i++) - memcpy(h_padded + i * bucket, sequences + i * seq_length, - seq_length * sizeof(float)); - - // Allocate device memory: subjects (doubles as query source) + output - float *d_subjects = nullptr, *d_out = nullptr; - CUDA_CHECK(cudaMalloc(&d_subjects, N * bucket * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_out, N * N * sizeof(float))); - CUDA_CHECK(cudaMemcpy(d_subjects, h_padded, N * bucket * sizeof(float), - cudaMemcpyHostToDevice)); - free(h_padded); - - // Stream pool — sized to min(N, 16) for single-position batch - int num_streams = (N < 16) ? (int)N : 16; - std::vector streams(num_streams); - for (int s = 0; s < num_streams; s++) - CUDA_CHECK(cudaStreamCreate(&streams[s])); - - int rc = launch_pairwise_rows(d_subjects, d_out, N, bucket, - streams.data(), num_streams); - if (rc == 0) rc = (cudaGetLastError() == cudaSuccess) ? 0 : -1; - if (rc == 0) rc = (cudaDeviceSynchronize() == cudaSuccess) ? 0 : -1; - - for (int s = 0; s < num_streams; s++) cudaStreamDestroy(streams[s]); - - if (rc != 0) { - cudaFree(d_subjects); cudaFree(d_out); - return -1; - } - - // One D→H copy, then sqrtf + diagonal zeroing on host - CUDA_CHECK(cudaMemcpy(out_distances, d_out, N * N * sizeof(float), - cudaMemcpyDeviceToHost)); - finalize_matrix(out_distances, N); - - cudaFree(d_subjects); - cudaFree(d_out); - return 0; -} - -// -------------------------------------------------------------------------- -// Variable-length pairwise DTW (single position) -// -------------------------------------------------------------------------- -int opendba_dtw_pairwise_varlen( - const float *sequences, - const size_t *seq_lengths, - size_t num_sequences, - size_t max_length, - float *out_distances) -{ - if (!sequences || !seq_lengths || !out_distances || - num_sequences < 2 || max_length == 0) { - fprintf(stderr, "cudtw_wrapper: invalid varlen input\n"); - return -1; - } - - size_t actual_max = 0; - for (size_t i = 0; i < num_sequences; i++) - if (seq_lengths[i] > actual_max) actual_max = seq_lengths[i]; - - int bucket = select_bucket((int)actual_max); - if (bucket < 0) { - fprintf(stderr, "cudtw_wrapper: max_length %zu > 2047\n", actual_max); - return -1; - } - - const size_t N = num_sequences; - - // Repack to bucket width - float *h_padded = (float *)calloc(N * bucket, sizeof(float)); - if (!h_padded) return -1; - for (size_t i = 0; i < N; i++) - memcpy(h_padded + i * bucket, sequences + i * max_length, - seq_lengths[i] * sizeof(float)); - - float *d_subjects = nullptr, *d_out = nullptr; - CUDA_CHECK(cudaMalloc(&d_subjects, N * bucket * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_out, N * N * sizeof(float))); - CUDA_CHECK(cudaMemcpy(d_subjects, h_padded, N * bucket * sizeof(float), - cudaMemcpyHostToDevice)); - free(h_padded); - - int num_streams = (N < 16) ? (int)N : 16; - std::vector streams(num_streams); - for (int s = 0; s < num_streams; s++) - CUDA_CHECK(cudaStreamCreate(&streams[s])); - - int rc = launch_pairwise_rows(d_subjects, d_out, N, bucket, - streams.data(), num_streams); - if (rc == 0) rc = (cudaGetLastError() == cudaSuccess) ? 0 : -1; - if (rc == 0) rc = (cudaDeviceSynchronize() == cudaSuccess) ? 0 : -1; - - for (int s = 0; s < num_streams; s++) cudaStreamDestroy(streams[s]); - - if (rc != 0) { - cudaFree(d_subjects); cudaFree(d_out); - return -1; - } - - CUDA_CHECK(cudaMemcpy(out_distances, d_out, N * N * sizeof(float), - cudaMemcpyDeviceToHost)); - finalize_matrix(out_distances, N); - - cudaFree(d_subjects); - cudaFree(d_out); - return 0; -} - -// -------------------------------------------------------------------------- -// Multi-position batched pairwise DTW — whole chunk shares one d_subjects / -// one d_out, one H→D, one D→H, one device-sync. Kernels across positions -// and within a position fan out over a stream pool for real concurrency. -// -------------------------------------------------------------------------- -int opendba_dtw_multi_position_pairwise( - const float *all_sequences, - const size_t *all_seq_lengths, - const size_t *position_seq_counts, - size_t num_positions, - size_t global_max_length, - float *out_distances, - int num_cuda_streams, - int device_id) -{ - if (!all_sequences || !all_seq_lengths || !position_seq_counts || - !out_distances || num_positions == 0 || global_max_length == 0) { - fprintf(stderr, "cudtw_wrapper: invalid multi-position input\n"); - return -1; - } - - CUDA_CHECK(cudaSetDevice(device_id)); - - // Pre-scan: per-position bucket + sequence offset + output offset. - // Each position packs its sequences into d_subjects with stride = - // its own bucket (NOT a global max). The cuDTW++ kernel uses - // num_features as both the subject stride and the walk length, so - // buffer stride and kernel num_features MUST match per position — - // otherwise every row beyond 0 reads from padding → garbage. - std::vector buckets(num_positions, 0); - std::vector seq_index_offsets(num_positions, 0); // sequence-count prefix sum - std::vector pos_float_offsets(num_positions, 0); // float offset into d_subjects - std::vector out_offsets(num_positions, 0); - - size_t total_seqs = 0; - size_t total_out = 0; - size_t total_floats = 0; - - for (size_t p = 0; p < num_positions; p++) { - size_t n = position_seq_counts[p]; - seq_index_offsets[p] = total_seqs; - pos_float_offsets[p] = total_floats; - out_offsets[p] = total_out; - - size_t pos_max = 0; - for (size_t i = 0; i < n; i++) { - size_t l = all_seq_lengths[total_seqs + i]; - if (l > pos_max) pos_max = l; - } - int b = (n < 2) ? 0 : select_bucket((int)pos_max); - if (b < 0) { - fprintf(stderr, "cudtw_wrapper: position %zu max_len %zu > 2047\n", - p, pos_max); - return -1; - } - buckets[p] = b; - total_floats += n * (size_t)b; // 0 for n<2 positions - total_seqs += n; - total_out += n * n; - } - - // Degenerate: every position had n<2 — just zero the output and return. - if (total_floats == 0) { - memset(out_distances, 0, total_out * sizeof(float)); - return 0; - } - - // Pack sequences per-position with per-position stride. Skipped entirely - // for n<2 positions (they contribute 0 floats). - float *h_padded = (float *)calloc(total_floats, sizeof(float)); - if (!h_padded) return -1; - - for (size_t p = 0; p < num_positions; p++) { - size_t n = position_seq_counts[p]; - if (n < 2) continue; - size_t b = (size_t)buckets[p]; - size_t seq_off = seq_index_offsets[p]; - size_t pos_base = pos_float_offsets[p]; - for (size_t i = 0; i < n; i++) { - size_t slen = all_seq_lengths[seq_off + i]; - const float *src = all_sequences + (seq_off + i) * global_max_length; - memcpy(h_padded + pos_base + i * b, src, slen * sizeof(float)); - } - } - - float *d_subjects = nullptr, *d_out = nullptr; - CUDA_CHECK(cudaMalloc(&d_subjects, total_floats * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_out, total_out * sizeof(float))); - CUDA_CHECK(cudaMemcpy(d_subjects, h_padded, total_floats * sizeof(float), - cudaMemcpyHostToDevice)); - free(h_padded); - - // Zero the output buffer so positions with n<2 leave a valid zero block. - CUDA_CHECK(cudaMemsetAsync(d_out, 0, total_out * sizeof(float), 0)); - - int ns = num_cuda_streams > 0 ? num_cuda_streams : 16; - std::vector streams(ns); - for (int s = 0; s < ns; s++) - CUDA_CHECK(cudaStreamCreate(&streams[s])); - - // Dispatch: every (position, row) pair hits a stream round-robin. - // Buffer stride for position p is buckets[p] — same as the kernel's - // num_features, so Query/Subject indexing is consistent. - int launch_err = 0; - size_t launch_counter = 0; - for (size_t p = 0; p < num_positions && launch_err == 0; p++) { - size_t n = position_seq_counts[p]; - if (n < 2) continue; - size_t b = (size_t)buckets[p]; - size_t pos_base = pos_float_offsets[p]; - size_t out_off = out_offsets[p]; - for (size_t i = 0; i < n; i++) { - cudaStream_t s = streams[launch_counter % (size_t)ns]; - launch_counter++; - bool ok = FullDTW::dist( - d_subjects + pos_base + i * b, // Query = row i (stride b) - d_subjects + pos_base, // Subject base (stride b) - d_out + out_off + i * n, - (index_t)n, (index_t)b, s); - if (!ok) { launch_err = 1; break; } - } - } - - int rc = launch_err; - if (rc == 0 && cudaGetLastError() != cudaSuccess) rc = -1; - if (rc == 0 && cudaDeviceSynchronize() != cudaSuccess) rc = -1; - - for (int s = 0; s < ns; s++) cudaStreamDestroy(streams[s]); - - if (rc != 0) { - cudaFree(d_subjects); cudaFree(d_out); - return -1; - } - - CUDA_CHECK(cudaMemcpy(out_distances, d_out, total_out * sizeof(float), - cudaMemcpyDeviceToHost)); - - // Finalize each position block: sqrtf + diagonal zero - for (size_t p = 0; p < num_positions; p++) { - size_t n = position_seq_counts[p]; - if (n < 2) continue; - finalize_matrix(out_distances + out_offsets[p], n); - } - - cudaFree(d_subjects); - cudaFree(d_out); - return 0; -} - -// ========================================================================== -// Python C API Bindings -// ========================================================================== - -#include -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include - -// --- dtw_distance --- -// Accepts use_open_start / use_open_end as ignored kwargs for backward -// compat with older Python callers. cuDTW++ has no open-boundary mode. -static PyObject *py_dtw_cuda(PyObject *self, PyObject *args, PyObject *kwargs) { - PyArrayObject *seq1_array = NULL, *seq2_array = NULL; - int use_open_start = 0, use_open_end = 0; - static char *kwlist[] = {(char*)"seq1", (char*)"seq2", - (char*)"use_open_start", (char*)"use_open_end", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|ii", kwlist, - &PyArray_Type, &seq1_array, - &PyArray_Type, &seq2_array, - &use_open_start, &use_open_end)) - return NULL; - (void)use_open_start; (void)use_open_end; - - if (PyArray_NDIM(seq1_array) != 1 || PyArray_NDIM(seq2_array) != 1) { - PyErr_SetString(PyExc_ValueError, "Input arrays must be 1-dimensional"); - return NULL; - } - if (PyArray_TYPE(seq1_array) != NPY_FLOAT32 || - PyArray_TYPE(seq2_array) != NPY_FLOAT32) { - PyErr_SetString(PyExc_TypeError, "Input arrays must be float32"); - return NULL; - } - - npy_intp len1 = PyArray_DIM(seq1_array, 0); - npy_intp len2 = PyArray_DIM(seq2_array, 0); - if (len1 == 0 || len2 == 0) { - PyErr_SetString(PyExc_ValueError, "Input arrays cannot be empty"); - return NULL; - } - - float *s1 = (float *)PyArray_DATA(seq1_array); - float *s2 = (float *)PyArray_DATA(seq2_array); - float distance = 0.0f; - - int rc = opendba_dtw_cuda(s1, (size_t)len1, s2, (size_t)len2, &distance); - if (rc != 0) { - PyErr_SetString(PyExc_RuntimeError, "CUDA DTW computation failed"); - return NULL; - } - return PyFloat_FromDouble((double)distance); -} - -// --- dtw_pairwise --- -static PyObject *py_dtw_pairwise(PyObject *self, PyObject *args, PyObject *kwargs) { - PyArrayObject *seq_array; - int use_open_start = 0, use_open_end = 0; - static char *kwlist[] = {(char*)"sequences", - (char*)"use_open_start", (char*)"use_open_end", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|ii", kwlist, - &PyArray_Type, &seq_array, - &use_open_start, &use_open_end)) - return NULL; - (void)use_open_start; (void)use_open_end; - - if (PyArray_NDIM(seq_array) != 2) { - PyErr_SetString(PyExc_ValueError, "sequences must be 2D"); - return NULL; - } - if (PyArray_TYPE(seq_array) != NPY_FLOAT32) { - PyErr_SetString(PyExc_TypeError, "sequences must be float32"); - return NULL; - } - - npy_intp *dims = PyArray_DIMS(seq_array); - size_t N = (size_t)dims[0], L = (size_t)dims[1]; - if (N < 2) { - PyErr_SetString(PyExc_ValueError, "Need at least 2 sequences"); - return NULL; - } - - npy_intp out_dims[2] = {(npy_intp)N, (npy_intp)N}; - PyArrayObject *out = (PyArrayObject *)PyArray_ZEROS(2, out_dims, NPY_FLOAT32, 0); - if (!out) return NULL; - - int rc = opendba_dtw_pairwise_batch( - (float *)PyArray_DATA(seq_array), N, L, - (float *)PyArray_DATA(out)); - if (rc != 0) { - Py_DECREF(out); - PyErr_SetString(PyExc_RuntimeError, "CUDA pairwise DTW failed"); - return NULL; - } - return (PyObject *)out; -} - -// --- dtw_pairwise_varlen --- -static PyObject *py_dtw_pairwise_varlen(PyObject *self, PyObject *args, PyObject *kwargs) { - PyArrayObject *seq_array, *len_array; - int use_open_start = 0, use_open_end = 0; - static char *kwlist[] = {(char*)"sequences", (char*)"lengths", - (char*)"use_open_start", (char*)"use_open_end", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|ii", kwlist, - &PyArray_Type, &seq_array, - &PyArray_Type, &len_array, - &use_open_start, &use_open_end)) - return NULL; - (void)use_open_start; (void)use_open_end; - - if (PyArray_NDIM(seq_array) != 2) { - PyErr_SetString(PyExc_ValueError, "sequences must be 2D"); - return NULL; - } - if (PyArray_TYPE(seq_array) != NPY_FLOAT32) { - PyErr_SetString(PyExc_TypeError, "sequences must be float32"); - return NULL; - } - if (PyArray_NDIM(len_array) != 1) { - PyErr_SetString(PyExc_ValueError, "lengths must be 1D"); - return NULL; - } - - npy_intp *sdims = PyArray_DIMS(seq_array); - size_t N = (size_t)sdims[0], max_len = (size_t)sdims[1]; - if ((size_t)PyArray_DIM(len_array, 0) != N) { - PyErr_SetString(PyExc_ValueError, "lengths size != num_sequences"); - return NULL; - } - if (N < 2) { - PyErr_SetString(PyExc_ValueError, "Need at least 2 sequences"); - return NULL; - } - - size_t *h_lengths = new size_t[N]; - for (size_t i = 0; i < N; i++) { - long long val; - if (PyArray_TYPE(len_array) == NPY_INT64) - val = *((long long *)PyArray_GETPTR1(len_array, i)); - else if (PyArray_TYPE(len_array) == NPY_INT32) - val = *((int *)PyArray_GETPTR1(len_array, i)); - else { - delete[] h_lengths; - PyErr_SetString(PyExc_TypeError, "lengths must be int32 or int64"); - return NULL; - } - if (val <= 0 || (size_t)val > max_len) { - delete[] h_lengths; - PyErr_Format(PyExc_ValueError, - "length[%zu]=%lld out of range", i, val); - return NULL; - } - h_lengths[i] = (size_t)val; - } - - npy_intp out_dims[2] = {(npy_intp)N, (npy_intp)N}; - PyArrayObject *out = (PyArrayObject *)PyArray_ZEROS(2, out_dims, NPY_FLOAT32, 0); - if (!out) { delete[] h_lengths; return NULL; } - - int rc = opendba_dtw_pairwise_varlen( - (float *)PyArray_DATA(seq_array), h_lengths, N, max_len, - (float *)PyArray_DATA(out)); - delete[] h_lengths; - - if (rc != 0) { - Py_DECREF(out); - PyErr_SetString(PyExc_RuntimeError, "CUDA varlen DTW failed"); - return NULL; - } - return (PyObject *)out; -} - -// --- dtw_multi_position_pairwise --- -static PyObject *py_dtw_multi_position_pairwise( - PyObject *self, PyObject *args, PyObject *kwargs) -{ - PyArrayObject *seq_array, *len_array, *cnt_array; - int use_open_start = 0, use_open_end = 0; - int num_cuda_streams = 16, device_id = 0; - static char *kwlist[] = { - (char*)"sequences", (char*)"lengths", (char*)"counts", - (char*)"use_open_start", (char*)"use_open_end", - (char*)"num_cuda_streams", (char*)"device_id", NULL}; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!|iiii", kwlist, - &PyArray_Type, &seq_array, - &PyArray_Type, &len_array, - &PyArray_Type, &cnt_array, - &use_open_start, &use_open_end, - &num_cuda_streams, &device_id)) - return NULL; - (void)use_open_start; (void)use_open_end; - - if (PyArray_NDIM(seq_array) != 2) { - PyErr_SetString(PyExc_ValueError, "sequences must be 2D"); return NULL; - } - if (PyArray_TYPE(seq_array) != NPY_FLOAT32) { - PyErr_SetString(PyExc_TypeError, "sequences must be float32"); return NULL; - } - if (PyArray_NDIM(len_array) != 1) { - PyErr_SetString(PyExc_ValueError, "lengths must be 1D"); return NULL; - } - if (PyArray_NDIM(cnt_array) != 1) { - PyErr_SetString(PyExc_ValueError, "counts must be 1D"); return NULL; - } - - npy_intp *sdims = PyArray_DIMS(seq_array); - size_t total_seqs = (size_t)sdims[0]; - size_t gml = (size_t)sdims[1]; - size_t num_pos = (size_t)PyArray_DIM(cnt_array, 0); - - if ((size_t)PyArray_DIM(len_array, 0) != total_seqs) { - PyErr_SetString(PyExc_ValueError, "lengths size != total sequences"); - return NULL; - } - - size_t *h_lengths = new size_t[total_seqs]; - for (size_t i = 0; i < total_seqs; i++) { - long long val; - if (PyArray_TYPE(len_array) == NPY_INT64) - val = *((long long *)PyArray_GETPTR1(len_array, i)); - else if (PyArray_TYPE(len_array) == NPY_INT32) - val = *((int *)PyArray_GETPTR1(len_array, i)); - else { - delete[] h_lengths; - PyErr_SetString(PyExc_TypeError, "lengths must be int32 or int64"); - return NULL; - } - if (val <= 0 || (size_t)val > gml) { - delete[] h_lengths; - PyErr_Format(PyExc_ValueError, "length[%zu]=%lld out of range", i, val); - return NULL; - } - h_lengths[i] = (size_t)val; - } - - size_t *h_counts = new size_t[num_pos]; - size_t check_total = 0; - for (size_t p = 0; p < num_pos; p++) { - long long val; - if (PyArray_TYPE(cnt_array) == NPY_INT64) - val = *((long long *)PyArray_GETPTR1(cnt_array, p)); - else if (PyArray_TYPE(cnt_array) == NPY_INT32) - val = *((int *)PyArray_GETPTR1(cnt_array, p)); - else { - delete[] h_lengths; delete[] h_counts; - PyErr_SetString(PyExc_TypeError, "counts must be int32 or int64"); - return NULL; - } - h_counts[p] = (size_t)val; - check_total += (size_t)val; - } - if (check_total != total_seqs) { - delete[] h_lengths; delete[] h_counts; - PyErr_Format(PyExc_ValueError, - "sum(counts)=%zu != total_sequences=%zu", - check_total, total_seqs); - return NULL; - } - - size_t total_out = 0; - for (size_t p = 0; p < num_pos; p++) - total_out += h_counts[p] * h_counts[p]; - - npy_intp out_dim = (npy_intp)total_out; - PyArrayObject *out = (PyArrayObject *)PyArray_ZEROS(1, &out_dim, NPY_FLOAT32, 0); - if (!out) { delete[] h_lengths; delete[] h_counts; return NULL; } - - int rc = opendba_dtw_multi_position_pairwise( - (float *)PyArray_DATA(seq_array), h_lengths, h_counts, - num_pos, gml, - (float *)PyArray_DATA(out), num_cuda_streams, device_id); - - delete[] h_lengths; - delete[] h_counts; - - if (rc != 0) { - Py_DECREF(out); - PyErr_SetString(PyExc_RuntimeError, "CUDA multi-position DTW failed"); - return NULL; - } - return (PyObject *)out; -} - -// --- cleanup --- -static PyObject *py_dtw_cleanup(PyObject *self, PyObject *args) { - opendba_dtw_cleanup(); - Py_RETURN_NONE; -} - -// ========================================================================== -// Module table -// ========================================================================== -static PyMethodDef DtwMethods[] = { - {"dtw_distance", (PyCFunction)py_dtw_cuda, METH_VARARGS | METH_KEYWORDS, - "Compute DTW distance between two sequences (cuDTW++ warp-shuffle)."}, - {"dtw_pairwise", (PyCFunction)py_dtw_pairwise, METH_VARARGS | METH_KEYWORDS, - "Compute pairwise DTW distances (cuDTW++ warp-shuffle)."}, - {"dtw_pairwise_varlen", (PyCFunction)py_dtw_pairwise_varlen, - METH_VARARGS | METH_KEYWORDS, - "Compute pairwise DTW distances for variable-length sequences."}, - {"dtw_multi_position_pairwise", (PyCFunction)py_dtw_multi_position_pairwise, - METH_VARARGS | METH_KEYWORDS, - "Compute pairwise DTW for multiple positions (cuDTW++ warp-shuffle)."}, - {"cleanup", py_dtw_cleanup, METH_NOARGS, - "Reset CUDA device and free resources."}, - {NULL, NULL, 0, NULL} -}; - -static struct PyModuleDef dtwmodule = { - PyModuleDef_HEAD_INIT, - "_cuda_dtw", - "cuDTW++ warp-shuffle DTW for baleen", - -1, - DtwMethods -}; - -PyMODINIT_FUNC PyInit__cuda_dtw(void) { - import_array(); - if (PyErr_Occurred()) return NULL; - - PyObject *m = PyModule_Create(&dtwmodule); - if (!m) return NULL; - - PyModule_AddIntConstant(m, "__version_major__", 0); - PyModule_AddIntConstant(m, "__version_minor__", 3); - PyModule_AddStringConstant(m, "__version__", "0.3.1-cudtw"); - - return m; -} diff --git a/baleen/_cuda_dtw/dtw.hpp b/baleen/_cuda_dtw/dtw.hpp deleted file mode 100644 index 40100b5..0000000 --- a/baleen/_cuda_dtw/dtw.hpp +++ /dev/null @@ -1,318 +0,0 @@ -#ifndef __dtw_hpp_included -#define __dtw_hpp_included - -#include "cuda_utils.hpp" -#include "limits.hpp" // for device side numeric_limits min() and max() - -using namespace cudahack; // for device side numeric_limits - -// sentinel value for the start of the DTW alignment, the stop condition for backtracking (ergo has no corresponding moveI or moveJ) -#define NIL 255 -#define DIAGONAL 1 -#define RIGHT 2 -#define UP 3 -// Special move designations that do not differently affect backtracking algorithm per se, but does affect cost (open=no accumulation of cost for rightward move). -#define OPEN_RIGHT 4 -#define NIL_OPEN_RIGHT 5 - -// For two series I & J, encode that the cost matrix DTW path (i,j) backtracking index decrement options for the DTW steps declared above are: -// unset (0) => (-1, -1), DIAGONAL => (-1,-1), RIGHT => (0,-1), UP => (-1,0), OPEN_RIGHT => (0,-1), OPEN_RIGHT and NIL_OPEN_RIGHT as per RIGHT -__device__ __constant__ short moveI[] = {-1, -1, 0, -1, 0, 0, 0}; -__device__ __constant__ short moveJ[] = {-1, -1, -1, 0, -1, -1, -1}; - -// How to find the 1D index of (X,Y) in the pitched (i.e. coalescing memory access aligned) memory for the DTW path matrix -#define pitchedCoord(Column, Row, mem_pitch) ((size_t)((Row) * (mem_pitch)) + (Column)) - -#define ARITH_SERIES_SUM(n) (((n) * (n + 1)) / 2) - -// Need this because you cannot template dynamically allocated kernel memory in CUDA, as per https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory -template -__device__ T *shared_memory_proxy() -{ - extern __shared__ unsigned char memory[]; - return reinterpret_cast(memory); -} - -/** - * Compute the distance between a given pair of sequences along every White-Neely step pattern option, for the given vertical swath of the cost matrix. - * Here "First" sequence is on the Y axis, "Second" sequence is on the X axis with respect to the DTW's up, right and diagonal move options. - */ -template -__global__ void DTWDistance(const T *first_seq_input, const size_t first_seq_input_length, const T *second_seq_input, const size_t second_seq_input_length, const size_t first_seq_index, - const size_t offset_within_second_seq, const T *gpu_sequences, const size_t maxSeqLength, const size_t num_sequences, const size_t *gpu_sequence_lengths, - T *dtwCostSoFar, T *newDtwCostSoFar, unsigned char *pathMatrix, const size_t pathMemPitch, T *dtwPairwiseDistances, const int use_open_start, const int use_open_end) -{ - // We need temporary storage for three diagonals of the wavefront calculation of the cost matrix to calculate the optimal path steps as a diagonal "wavefront" until we iterate - // through every position of the first sequence. - T *costs = shared_memory_proxy(); - - // Which two are we comparing in this threadblock? - // See if there is anything to process in this thread block - const size_t second_seq_length = second_seq_input ? second_seq_input_length : gpu_sequence_lengths[first_seq_index + blockIdx.x + 1]; - if (offset_within_second_seq >= second_seq_length) - { - return; // all threads in the threadblock will return - } - - const size_t first_seq_length = first_seq_input ? first_seq_input_length : gpu_sequence_lengths[first_seq_index]; - const T *first_seq = first_seq_input ? first_seq_input : &gpu_sequences[first_seq_index * maxSeqLength]; - const T *second_seq = second_seq_input ? second_seq_input : &gpu_sequences[(first_seq_index + blockIdx.x + 1) * maxSeqLength]; - - // Point to the correct spot in global memory where the costs are being stored. - dtwCostSoFar = &dtwCostSoFar[first_seq_length * blockIdx.x]; - if (newDtwCostSoFar != 0) - newDtwCostSoFar = &newDtwCostSoFar[first_seq_length * blockIdx.x]; - - // Each thread will be using the same second sequence value throughout the rest of the kernel, so store it as a local variable for efficiency. - const T second_seq_thread_val = offset_within_second_seq + threadIdx.x >= second_seq_length ? 0 : second_seq[offset_within_second_seq + threadIdx.x]; - // printf("offset_within_second_seq: %i, second_seq_thread_val: %f\n", offset_within_second_seq, second_seq_thread_val); - - // Possible shortcut: If we are allowing open right moves only at the end of the alignment, - // and the top row of the matrix is in that state, and it's the lowest cost option in this column - // of the DTW cost matrix, there is no need to continue computing further grid X elements. Why? - // Because any other path will just get more expensive, while the open right move can continue - // eating up the subject to get to the upper right corner of the DTW matrix without any additional cost. - // This will speed up prefix searches in particular, where the prefix (1st seq) length is a small proportion of the 2nd's. - // We don't even need to be keeping the pathMatrix to know we are in that state, because it's the overriding move choice (see used_open_right_end_cost below) - // when you're at the top of the matrix in open end mode. - if (offset_within_second_seq > first_seq_length && use_open_end && !use_open_start) - { - // Check if the search has already been abrogated by a previous kernel call (further left in the DTW matrix calculation) - if (dtwCostSoFar[0] == numeric_limits::max()) - { - if (pathMatrix != 0 && offset_within_second_seq + threadIdx.x < second_seq_length) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + threadIdx.x, first_seq_length - 1, pathMemPitch)] = OPEN_RIGHT; - } - return; - } - - // Otherwise map/reduce within this kernel to pretty efficiently find the minimum value across the 1D dtwCostSoFar array without variable length threadblock shared memory. - T minval = threadIdx.x < first_seq_length ? dtwCostSoFar[threadIdx.x] : numeric_limits::max(); - for (int i = 1; i * blockDim.x < first_seq_length; i++) - { - // Assign each thread to find the minimum values strided (by # threads doing work) across the length of the first sequence. - if (i * blockDim.x + threadIdx.x < first_seq_length && minval > dtwCostSoFar[i * blockDim.x + threadIdx.x]) - { - // Hopefully mostly coalesced memory access - minval = dtwCostSoFar[i * blockDim.x + threadIdx.x]; - } - } - minval = warpReduceMin(minval); // across the warp - int lane = threadIdx.x % CUDA_WARP_WIDTH; - int wid = threadIdx.x / CUDA_WARP_WIDTH; - T *warp_minvals = costs; // threadblock shared memory space is the same as the costs pseudo 2D array since we won't need it outside of this block - if (!lane) - warp_minvals[wid] = minval; - __syncthreads(); - // Get in-bounds values only for final threadblock reduction, calculated by the first warp's threads (threadblock may not be full). - if (!wid) - { - minval = (threadIdx.x < blockDim.x / CUDA_WARP_WIDTH) ? warp_minvals[lane] : numeric_limits::max(); - warp_minvals[0] = warpReduceMin(minval); // across all threads in the block - } - __syncthreads(); - - // Top row value is the lowest for this column, only need to populate the open_right move for correct backtracking and cumulative cost calcs - if (dtwCostSoFar[first_seq_length - 1] == warp_minvals[0]) - { - if (pathMatrix != 0 && offset_within_second_seq + threadIdx.x < second_seq_length) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + threadIdx.x, first_seq_length - 1, pathMemPitch)] = OPEN_RIGHT; - } - // Make sure bottom row's DTW cost so far is set to the max possible. - // This is how we indicate that we've decided to abrogated the rest of the search. - if (threadIdx.x == 0 && newDtwCostSoFar != 0) - { - newDtwCostSoFar[0] = numeric_limits::max(); - } - // As we've made a final determination for the cost, record it to GPU memory if we've been given a spot for it. - if (dtwPairwiseDistances != 0 && threadIdx.x == 0 && newDtwCostSoFar != 0) - { - // If the alignment has open right end, the medoid calculations will always be biased towards the shortest sequences since the open state is "free", - // which is troublesome for retaining consensus features in clusters. To remove this bias, we will normalize the distance matrix to be relative to the length of the - // shorter sequence with the assumption on average that the shorter sequence is the one generating "free" - // alignment ends that longer sequences can't compete with. - T normalized_pair_distance = (T)(sqrtf(newDtwCostSoFar[first_seq_length - 1]) / first_seq_length); - - // 1D index for row into distances upper left pairs triangle is the total size of the triangle, minus all those that haven't been processed yet. - dtwPairwiseDistances[ARITH_SERIES_SUM(num_sequences - 1) - ARITH_SERIES_SUM(num_sequences - first_seq_index - 1) + blockIdx.x] = normalized_pair_distance; - } - return; - } - } - - if (threadIdx.x == 0) - { - // Populate the bottom row of the vertical swath on every kernel invocation, this can't be done in parallel. - const T first_seq_start_val = first_seq[0]; - if (offset_within_second_seq == 0) - { - costs[0] = 0; - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord(0, 0, pathMemPitch)] = use_open_start ? NIL_OPEN_RIGHT : NIL; // sentinel for path backtracking algorithm termination - } - } - else - { - costs[0] = dtwCostSoFar[0]; - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0), 0, pathMemPitch)] = use_open_start ? OPEN_RIGHT : RIGHT; - } - } - costs[0] += use_open_start ? 0 : (first_seq_start_val - second_seq_thread_val) * (first_seq_start_val - second_seq_thread_val); - int col; - for (col = 1; col < blockDim.x && offset_within_second_seq + col < second_seq_length; col++) - { - T diff = use_open_start ? 0 : first_seq_start_val - second_seq[offset_within_second_seq + col]; - costs[col + blockDim.x * (col % 3)] = costs[(col - 1) + blockDim.x * ((col - 1) % 3)] + diff * diff; - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + col, 0, pathMemPitch)] = use_open_start ? OPEN_RIGHT : RIGHT; - } - } - if (newDtwCostSoFar != 0) - newDtwCostSoFar[0] = costs[(col - 1)]; - } - - int i; // Indicates the ordinal of the diagonal of the wave front cost values being calculated - for (i = 1; i < first_seq_length + blockDim.x; i++) - { - - if (offset_within_second_seq + threadIdx.x < second_seq_length && // We're within the sequence bounds? - threadIdx.x < i && // The diff still corresponds to a spot in the cost matrix? - i - threadIdx.x < first_seq_length) - { - // NOTE: these were originally `volatile` but there is no cross-thread - // access here — up/diag/right_cost and diff are strictly thread-local - // scratch values. Dropping `volatile` lets the compiler keep them in - // registers instead of spilling to shared memory on every read/write, - // which gives a measurable speedup on Ampere+. - T up_cost = numeric_limits::max(); - T diag_cost = numeric_limits::max(); - T right_cost = numeric_limits::max(); - T diff = first_seq[i - threadIdx.x] - second_seq_thread_val; - - // The left edge of cost matrix vertical swath is a special case as we need to - // access previously global mem stored intermediate costs. - int used_open_right_end_cost = 0; - if (threadIdx.x == 0) - { - // Straight up is always an option - up_cost = costs[blockDim.x * ((i - 1) % 3)] + diff * diff; - if (offset_within_second_seq != 0) - { - // All three steps are possible, two drawn from previous intermediate results - right_cost = dtwCostSoFar[i]; - diag_cost = dtwCostSoFar[i - 1]; - if (i - threadIdx.x < first_seq_length - 1 || !use_open_end) - { - right_cost += diff * diff; - diag_cost += diff * diff; - } - else - { - used_open_right_end_cost = 1; - } - } - } - // For all other threads all the input data is stored locally in costs[]. - else - { - up_cost = costs[threadIdx.x + blockDim.x * ((i - 1) % 3)] + diff * diff; - right_cost = costs[(threadIdx.x - 1) + blockDim.x * ((i - 1) % 3)] + diff * diff; - diag_cost = costs[(threadIdx.x - 1) + blockDim.x * ((i - 2) % 3)] + diff * diff; - } - - // Use the White-Neely step pattern (a diagonal move is preferred to right-up or up-right if costs are equivalent). - if (use_open_end && i - threadIdx.x == first_seq_length - 1 && threadIdx.x != 0) - { - // No extra cost to consume a sequence element from the first sequence, just copy it over from the previous column. - right_cost = costs[(threadIdx.x - 1) + blockDim.x * ((i - 1) % 3)]; - used_open_right_end_cost = 1; - } - // char move; - if (diag_cost > up_cost) - { - if (up_cost > right_cost) - { - costs[threadIdx.x + blockDim.x * (i % 3)] = right_cost; - // Implicitly, if we aren't tracking the new cost so far, we assume we're doing a striped backtracing of the path so no path offset - // is required as it's being built and printed/used one stripe at a time. - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + threadIdx.x, i - threadIdx.x, pathMemPitch)] = used_open_right_end_cost ? OPEN_RIGHT : RIGHT; - } - // move = 'R'; - } - else - { - costs[threadIdx.x + blockDim.x * (i % 3)] = up_cost; - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + threadIdx.x, i - threadIdx.x, pathMemPitch)] = UP; - } - // move = 'U'; - } - } - else - { - if (diag_cost > right_cost) - { - costs[threadIdx.x + blockDim.x * (i % 3)] = right_cost; - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + threadIdx.x, i - threadIdx.x, pathMemPitch)] = used_open_right_end_cost ? OPEN_RIGHT : RIGHT; - } - // move = 'R'; - } - else - { - costs[threadIdx.x + blockDim.x * (i % 3)] = diag_cost; - if (pathMatrix != 0) - { - pathMatrix[pitchedCoord((newDtwCostSoFar ? offset_within_second_seq : 0) + threadIdx.x, i - threadIdx.x, pathMemPitch)] = DIAGONAL; - } - // move = 'D'; - } - } - // if(first_seq_index == 0) printf("0, %i, %hi, %f, %f, %f, %f, %c\n", i, threadIdx.x, up_cost, right_cost, diag_cost, diff*diff, move); - // if(first_seq_index == 1) printf("1, %i, %hi, %f, %f, %f, %f, %c\n", i, threadIdx.x, up_cost, right_cost, diag_cost, diff*diff, move); - // Right edge is a special case as we need to store back out intermediate result to global mem - // for the use of the next kernel call with a larger offset_within_second_seq. - if (newDtwCostSoFar != 0 && (threadIdx.x == blockDim.x - 1 || offset_within_second_seq + threadIdx.x == second_seq_length - 1)) - { - newDtwCostSoFar[i - threadIdx.x] = costs[threadIdx.x + blockDim.x * (i % 3)]; - } - } - - // To ensure all required previous costs from neighbouring threads are calculated and available for the next iteration. - __syncthreads(); - } - // If this is the end of the second sequence, we now know the total cost of the alignment and can populate - // global var dtwPairwiseDistances. This is more efficient than doing a round trip on the PCI bus to the CPU for the same purpose. - if (offset_within_second_seq + blockDim.x >= second_seq_length) - { - if (dtwPairwiseDistances != 0 && threadIdx.x == 0 && newDtwCostSoFar != 0) - { - // 1D index for row into distances upper left pairs triangle is the total size of the triangle, minus all those that haven't been processed yet. - int result_index = ARITH_SERIES_SUM(num_sequences - 1) - ARITH_SERIES_SUM(num_sequences - first_seq_index - 1) + blockIdx.x; - - // If the alignment has one open end, the medoid calculations will always be biased towards the shortest sequences since the open state is "free", - // which is troublesome for retaining consensus features in clusters. To remove this bias, we will normalize the distance matrix to be relative to the length of the - // shorter sequence with the assumption on average that the shorter sequence is the one generating "free" alignment ends that longer sequences can't compete with. - if (use_open_end && !use_open_start || !use_open_end && use_open_start) - { - dtwPairwiseDistances[result_index] = (T)(sqrtf(newDtwCostSoFar[first_seq_length - 1]) / first_seq_length); - } - else - { // use the distance as-is (similar length sequences will tend to cluster together) - dtwPairwiseDistances[result_index] = (T)sqrtf(newDtwCostSoFar[first_seq_length - 1]); - } - } - } -} - -#endif \ No newline at end of file diff --git a/baleen/_cuda_dtw/dtw_api.cpp b/baleen/_cuda_dtw/dtw_api.cpp deleted file mode 100644 index bb9284c..0000000 --- a/baleen/_cuda_dtw/dtw_api.cpp +++ /dev/null @@ -1,1272 +0,0 @@ -#include "dtw_api.h" -#include "dtw.hpp" -#include "cuda_utils.hpp" -#include -#include -#include -#include - -// Define CUDA_CHECK macro for error checking -#define CUDA_CHECK(call) \ - do \ - { \ - cudaError_t err = call; \ - if (err != cudaSuccess) \ - { \ - std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " - " \ - << cudaGetErrorString(err) << std::endl; \ - return -1; \ - } \ - } while (0) - -// ============================================================================ -// Cached Device Properties (per-device) -// ============================================================================ - -#include -static std::unordered_map g_max_threads_map; - -static int ensure_device_props(int device_id = 0) { - if (g_max_threads_map.find(device_id) == g_max_threads_map.end()) { - CUDA_CHECK(cudaSetDevice(device_id)); - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); - g_max_threads_map[device_id] = prop.maxThreadsPerBlock; - } - return 0; -} - -// 复用 OpenDBA 的核函数启动逻辑,仅适配单对序列计算 -int opendba_dtw_cuda( - const float *seq1, size_t len1, - const float *seq2, size_t len2, - int use_open_start, - int use_open_end, - float *out_distance) -{ - // 1. 输入校验 - if (!seq1 || !seq2 || !out_distance || len1 == 0 || len2 == 0) - { - fprintf(stderr, "Invalid input parameters\n"); - return -1; - } - - // 2. 设备内存分配(复用 OpenDBA 的内存对齐逻辑) - float *d_seq1, *d_seq2, *d_dtw_cost, *d_new_dtw_cost; - unsigned char *d_path_matrix; - float *d_pairwise_dist; - size_t path_mem_pitch; - - // 分配序列内存 - CUDA_CHECK(cudaMalloc(&d_seq1, len1 * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_seq2, len2 * sizeof(float))); - // 分配 DTW 计算所需临时内存 - // CRITICAL: d_dtw_cost and d_new_dtw_cost must be sized by len1 (first_seq_length) - // because the kernel indexes them up to first_seq_length-1 - CUDA_CHECK(cudaMalloc(&d_dtw_cost, len1 * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_new_dtw_cost, len1 * sizeof(float))); - CUDA_CHECK(cudaMallocPitch(&d_path_matrix, &path_mem_pitch, len2 * sizeof(unsigned char), len1)); - CUDA_CHECK(cudaMalloc(&d_pairwise_dist, sizeof(float))); - - // 3. 主机→设备数据拷贝 - CUDA_CHECK(cudaMemcpy(d_seq1, seq1, len1 * sizeof(float), cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(d_seq2, seq2, len2 * sizeof(float), cudaMemcpyHostToDevice)); - // 初始化临时内存(must match allocation sizes above) - CUDA_CHECK(cudaMemset(d_dtw_cost, 0, len1 * sizeof(float))); - CUDA_CHECK(cudaMemset(d_new_dtw_cost, 0, len1 * sizeof(float))); - CUDA_CHECK(cudaMemset(d_path_matrix, 0, path_mem_pitch * len1)); - CUDA_CHECK(cudaMemset(d_pairwise_dist, 0, sizeof(float))); - - // 4. 启动 OpenDBA 原版 DTW 核函数(参数严格对齐) - // Get device properties to determine thread count - ensure_device_props(0); - int max_threads = g_max_threads_map[0]; - - dim3 thread_block(max_threads, 1, 1); - size_t shared_mem = thread_block.x * 3 * sizeof(float); // 复用 OpenDBA 的共享内存计算 - - // CRITICAL: The wavefront algorithm processes the second sequence in chunks of blockDim.x - // We must call the kernel multiple times, advancing offset_within_second_seq each time - float *d_current_cost = d_dtw_cost; - float *d_next_cost = d_new_dtw_cost; - - for (size_t offset = 0; offset < len2; offset += max_threads) - { - DTWDistance<<<1, thread_block, shared_mem>>>( - d_seq1, len1, - d_seq2, len2, - 0, offset, // Advance offset through second sequence - (const float *)nullptr, 0, 0, // 多序列相关参数置空/0 - (const size_t *)nullptr, - d_current_cost, - d_next_cost, - d_path_matrix, - path_mem_pitch, - d_pairwise_dist, - use_open_start, - use_open_end); - CUDA_CHECK(cudaGetLastError()); // 检查核函数启动错误 - - // Swap buffers for next iteration - float *temp = d_current_cost; - d_current_cost = d_next_cost; - d_next_cost = temp; - } - - CUDA_CHECK(cudaDeviceSynchronize()); // 等待所有核函数执行完成 - - // 5. 设备→主机拷贝结果 - // After the loop, d_current_cost contains the final result - // (kernel writes to d_next_cost, then we swap, so result is in d_current_cost after swap) - float final_cost; - CUDA_CHECK(cudaMemcpy(&final_cost, &d_current_cost[len1 - 1], sizeof(float), cudaMemcpyDeviceToHost)); - - // Apply same normalization/distance calculation as the kernel would - // Note: The kernel computes squared differences, so we take sqrt here - if ((use_open_end && !use_open_start) || (!use_open_end && use_open_start)) - { - *out_distance = sqrtf(final_cost) / len1; // Normalized by sequence length - } - else - { - *out_distance = sqrtf(final_cost); // Raw distance - } - - // 6. 释放设备内存 - CUDA_CHECK(cudaFree(d_seq1)); - CUDA_CHECK(cudaFree(d_seq2)); - CUDA_CHECK(cudaFree(d_dtw_cost)); - CUDA_CHECK(cudaFree(d_new_dtw_cost)); - CUDA_CHECK(cudaFree(d_path_matrix)); - CUDA_CHECK(cudaFree(d_pairwise_dist)); - - return 0; -} - -void opendba_dtw_cleanup() -{ - g_max_threads_map.clear(); - int count = 0; - if (cudaGetDeviceCount(&count) == cudaSuccess) { - for (int i = 0; i < count; i++) { - cudaSetDevice(i); - cudaDeviceReset(); - } - } -} - -// ============================================================================ -// Batch Pairwise DTW -// ============================================================================ - -int opendba_dtw_pairwise_batch( - const float *sequences, - size_t num_sequences, - size_t seq_length, - int use_open_start, - int use_open_end, - float *out_distances) -{ - if (!sequences || !out_distances || num_sequences < 2 || seq_length == 0) - { - fprintf(stderr, "Invalid input parameters for batch DTW\n"); - return -1; - } - - // Allocate GPU memory for all sequences - float *d_sequences; - size_t *d_seq_lengths; - float *d_distances; - - size_t total_seq_size = num_sequences * seq_length * sizeof(float); - size_t num_pairs = (num_sequences * (num_sequences - 1)) / 2; // Upper triangle - - CUDA_CHECK(cudaMalloc(&d_sequences, total_seq_size)); - CUDA_CHECK(cudaMalloc(&d_seq_lengths, num_sequences * sizeof(size_t))); - CUDA_CHECK(cudaMalloc(&d_distances, num_pairs * sizeof(float))); - - // Copy sequences to GPU - CUDA_CHECK(cudaMemcpy(d_sequences, sequences, total_seq_size, cudaMemcpyHostToDevice)); - - // Initialize sequence lengths (all same) - size_t *h_seq_lengths = new size_t[num_sequences]; - for (size_t i = 0; i < num_sequences; i++) - { - h_seq_lengths[i] = seq_length; - } - CUDA_CHECK(cudaMemcpy(d_seq_lengths, h_seq_lengths, num_sequences * sizeof(size_t), cudaMemcpyHostToDevice)); - delete[] h_seq_lengths; - - // Allocate temporary buffers for DTW computation - ensure_device_props(0); - int max_threads = g_max_threads_map[0]; - - // Query available GPU memory - size_t free_memory, total_memory; - CUDA_CHECK(cudaMemGetInfo(&free_memory, &total_memory)); - - // Maximum pairs we'll compute in parallel (when processing first sequence) - size_t max_pairs_parallel = num_sequences - 1; - - // Memory required for cost buffers: seq_length * max_pairs_parallel * 2 * sizeof(float) - // Example: 1000 len × 100 pairs × 8 bytes = 800 KB (reasonable) - // With 20GB GPU, you could do: 10000 len × 1000 pairs × 8 bytes = 80 MB (still tiny!) - - size_t cost_buffer_size = seq_length * max_pairs_parallel * sizeof(float); - size_t total_temp_memory = cost_buffer_size * 2; // Two buffers - - if (getenv("DTW_DEBUG")) - { - fprintf(stderr, "=== DTW Batch Pairwise Memory Usage ===\n"); - fprintf(stderr, "GPU: %.2f GB total, %.2f GB free\n", - total_memory / 1024.0 / 1024.0 / 1024.0, - free_memory / 1024.0 / 1024.0 / 1024.0); - fprintf(stderr, "Input sequences: %.2f MB (%zu × %zu)\n", - total_seq_size / 1024.0 / 1024.0, num_sequences, seq_length); - fprintf(stderr, "Cost buffers: %.2f MB\n", total_temp_memory / 1024.0 / 1024.0); - fprintf(stderr, "Output distances: %.2f KB (%zu pairs)\n", - num_pairs * sizeof(float) / 1024.0, num_pairs); - fprintf(stderr, "Total GPU memory used: %.2f MB (%.1f%% of free memory)\n", - (total_seq_size + total_temp_memory + num_pairs * sizeof(float)) / 1024.0 / 1024.0, - 100.0 * (total_seq_size + total_temp_memory) / free_memory); - fprintf(stderr, "\nNote: Low memory usage is by design for stability.\n"); - fprintf(stderr, " With your GPU, you could process:\n"); - fprintf(stderr, " - 1000 sequences × 10000 length (~800 MB)\n"); - fprintf(stderr, " - 5000 sequences × 2000 length (~800 MB)\n"); - fprintf(stderr, "=======================================\n"); - } - - float *d_dtw_cost, *d_new_dtw_cost; - - // Allocate cost buffers: seq_length per parallel pair - CUDA_CHECK(cudaMalloc(&d_dtw_cost, seq_length * max_pairs_parallel * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_new_dtw_cost, seq_length * max_pairs_parallel * sizeof(float))); - - // Don't allocate path matrix for pairwise - we only need distances, not alignments - // This saves HUGE amounts of memory (would be seq_length² per pair!) - unsigned char *d_path_matrix = nullptr; - size_t path_mem_pitch = 0; - - // Compute pairwise distances - dim3 thread_block(max_threads, 1, 1); - size_t shared_mem = thread_block.x * 3 * sizeof(float); - - if (getenv("DTW_DEBUG")) - { - fprintf(stderr, "\n=== Starting DTW Pairwise Computation ===\n"); - fprintf(stderr, "Total pairs to compute: %zu\n", num_pairs); - fprintf(stderr, "Wavefront chunks per sequence: %zu\n", (seq_length + max_threads - 1) / max_threads); - fprintf(stderr, "=========================================\n\n"); - } - - size_t total_pairs_completed = 0; - auto start_time = std::chrono::high_resolution_clock::now(); - - // Process each reference sequence - for (size_t i = 0; i < num_sequences - 1; i++) - { - size_t num_comparisons = num_sequences - i - 1; - - // Initialize buffers for all comparisons with this reference sequence - CUDA_CHECK(cudaMemset(d_dtw_cost, 0, seq_length * num_comparisons * sizeof(float))); - CUDA_CHECK(cudaMemset(d_new_dtw_cost, 0, seq_length * num_comparisons * sizeof(float))); - - float *d_current_cost = d_dtw_cost; - float *d_next_cost = d_new_dtw_cost; - - // Process sequence in wavefront chunks - size_t num_chunks = (seq_length + max_threads - 1) / max_threads; - for (size_t chunk_idx = 0, offset = 0; offset < seq_length; chunk_idx++, offset += max_threads) - { - // Detailed progress for very long sequences - if (getenv("DTW_DEBUG") && seq_length > 10000 && chunk_idx % 100 == 0) - { - fprintf(stderr, " [Seq %zu] Chunk %zu/%zu (offset %zu/%zu)\n", - i, chunk_idx, num_chunks, offset, seq_length); - fflush(stderr); - } - - // Launch kernel with num_comparisons blocks to compute all pairs with sequence i - DTWDistance<<>>( - nullptr, seq_length, // Don't pass individual seqs, use batch arrays - nullptr, seq_length, - i, offset, // Reference sequence index and offset - d_sequences, seq_length, num_sequences, - d_seq_lengths, - d_current_cost, - d_next_cost, - d_path_matrix, - path_mem_pitch, - d_distances, - use_open_start, - use_open_end); - CUDA_CHECK(cudaGetLastError()); - - // Swap buffers - float *temp = d_current_cost; - d_current_cost = d_next_cost; - d_next_cost = temp; - } - - // Update completed pairs count - total_pairs_completed += num_comparisons; - - // Synchronize and report progress AFTER GPU work is done - if (i % 10 == 0 || getenv("DTW_DEBUG")) - { - CUDA_CHECK(cudaDeviceSynchronize()); // Wait for GPU to finish - - auto now = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast(now - start_time).count(); - - if (elapsed > 0) // Avoid division by zero - { - float progress = 100.0 * total_pairs_completed / num_pairs; - float pairs_per_sec = total_pairs_completed / (elapsed / 1000.0); - size_t remaining_pairs = num_pairs - total_pairs_completed; - float eta_sec = remaining_pairs / pairs_per_sec; - - fprintf(stderr, "[Progress] Ref seq %3zu/%zu | Completed: %6zu/%zu pairs (%.1f%%) | " - "Speed: %.1f pairs/sec | ETA: %.1f sec\n", - i + 1, num_sequences - 1, total_pairs_completed, num_pairs, progress, - pairs_per_sec, eta_sec); - fflush(stderr); - } - } - } - - CUDA_CHECK(cudaDeviceSynchronize()); - - auto end_time = std::chrono::high_resolution_clock::now(); - auto total_elapsed = std::chrono::duration_cast(end_time - start_time).count(); - - if (getenv("DTW_DEBUG") || true) // Always show completion - { - fprintf(stderr, "\n[Complete] Computed %zu pairs in %.2f seconds (%.1f pairs/sec)\n", - num_pairs, total_elapsed / 1000.0, num_pairs / (total_elapsed / 1000.0)); - fflush(stderr); - } - - // Copy results back (upper triangle format) - float *h_upper_triangle = new float[num_pairs]; - CUDA_CHECK(cudaMemcpy(h_upper_triangle, d_distances, num_pairs * sizeof(float), cudaMemcpyDeviceToHost)); - - // Convert upper triangle to full symmetric matrix - size_t pair_idx = 0; - for (size_t i = 0; i < num_sequences; i++) - { - out_distances[i * num_sequences + i] = 0.0f; // Diagonal - for (size_t j = i + 1; j < num_sequences; j++) - { - float dist = h_upper_triangle[pair_idx++]; - out_distances[i * num_sequences + j] = dist; - out_distances[j * num_sequences + i] = dist; // Symmetric - } - } - delete[] h_upper_triangle; - - // Cleanup - CUDA_CHECK(cudaFree(d_sequences)); - CUDA_CHECK(cudaFree(d_seq_lengths)); - CUDA_CHECK(cudaFree(d_distances)); - CUDA_CHECK(cudaFree(d_dtw_cost)); - CUDA_CHECK(cudaFree(d_new_dtw_cost)); - // d_path_matrix is nullptr, don't free it - - return 0; -} - -// ============================================================================ -// Variable-length Batch Pairwise DTW -// ============================================================================ - -int opendba_dtw_pairwise_varlen( - const float *sequences, - const size_t *seq_lengths, - size_t num_sequences, - size_t max_length, - int use_open_start, - int use_open_end, - float *out_distances) -{ - if (!sequences || !seq_lengths || !out_distances || num_sequences < 2 || max_length == 0) - { - fprintf(stderr, "Invalid input parameters for varlen batch DTW\n"); - return -1; - } - - float *d_sequences; - size_t *d_seq_lengths; - float *d_distances; - - size_t total_seq_size = num_sequences * max_length * sizeof(float); - size_t num_pairs = (num_sequences * (num_sequences - 1)) / 2; - - CUDA_CHECK(cudaMalloc(&d_sequences, total_seq_size)); - CUDA_CHECK(cudaMalloc(&d_seq_lengths, num_sequences * sizeof(size_t))); - CUDA_CHECK(cudaMalloc(&d_distances, num_pairs * sizeof(float))); - - CUDA_CHECK(cudaMemcpy(d_sequences, sequences, total_seq_size, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(d_seq_lengths, seq_lengths, num_sequences * sizeof(size_t), cudaMemcpyHostToDevice)); - - ensure_device_props(0); - int max_threads = g_max_threads_map[0]; - - size_t max_pairs_parallel = num_sequences - 1; - - float *d_dtw_cost, *d_new_dtw_cost; - - // Cost buffers use max_length for stride so all blocks index correctly - // (the kernel indexes as dtwCostSoFar[first_seq_length * blockIdx.x], - // but first_seq_length varies per reference — we use max_length to be safe) - CUDA_CHECK(cudaMalloc(&d_dtw_cost, max_length * max_pairs_parallel * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_new_dtw_cost, max_length * max_pairs_parallel * sizeof(float))); - - unsigned char *d_path_matrix = nullptr; - size_t path_mem_pitch = 0; - - dim3 thread_block(max_threads, 1, 1); - size_t shared_mem = thread_block.x * 3 * sizeof(float); - - for (size_t i = 0; i < num_sequences - 1; i++) - { - size_t num_comparisons = num_sequences - i - 1; - - CUDA_CHECK(cudaMemset(d_dtw_cost, 0, max_length * num_comparisons * sizeof(float))); - CUDA_CHECK(cudaMemset(d_new_dtw_cost, 0, max_length * num_comparisons * sizeof(float))); - - float *d_current_cost = d_dtw_cost; - float *d_next_cost = d_new_dtw_cost; - - // Iterate over wavefront chunks up to max_length; - // shorter second sequences exit early inside the kernel - for (size_t offset = 0; offset < max_length; offset += max_threads) - { - DTWDistance<<>>( - nullptr, max_length, - nullptr, max_length, - i, offset, - d_sequences, max_length, num_sequences, - d_seq_lengths, - d_current_cost, - d_next_cost, - d_path_matrix, - path_mem_pitch, - d_distances, - use_open_start, - use_open_end); - CUDA_CHECK(cudaGetLastError()); - - float *temp = d_current_cost; - d_current_cost = d_next_cost; - d_next_cost = temp; - } - } - - CUDA_CHECK(cudaDeviceSynchronize()); - - float *h_upper_triangle = new float[num_pairs]; - CUDA_CHECK(cudaMemcpy(h_upper_triangle, d_distances, num_pairs * sizeof(float), cudaMemcpyDeviceToHost)); - - size_t pair_idx = 0; - for (size_t i = 0; i < num_sequences; i++) - { - out_distances[i * num_sequences + i] = 0.0f; - for (size_t j = i + 1; j < num_sequences; j++) - { - float dist = h_upper_triangle[pair_idx++]; - out_distances[i * num_sequences + j] = dist; - out_distances[j * num_sequences + i] = dist; - } - } - delete[] h_upper_triangle; - - CUDA_CHECK(cudaFree(d_sequences)); - CUDA_CHECK(cudaFree(d_seq_lengths)); - CUDA_CHECK(cudaFree(d_distances)); - CUDA_CHECK(cudaFree(d_dtw_cost)); - CUDA_CHECK(cudaFree(d_new_dtw_cost)); - - return 0; -} - -// ============================================================================ -// Multi-Position Batched Pairwise DTW (CUDA Streams) -// ============================================================================ - -int opendba_dtw_multi_position_pairwise( - const float *all_sequences, - const size_t *all_seq_lengths, - const size_t *position_seq_counts, - size_t num_positions, - size_t global_max_length, - int use_open_start, - int use_open_end, - float *out_distances, - int num_cuda_streams, - int device_id) -{ - if (!all_sequences || !all_seq_lengths || !position_seq_counts || - !out_distances || num_positions == 0 || global_max_length == 0) - { - fprintf(stderr, "Invalid input parameters for multi-position batch DTW\n"); - return -1; - } - - CUDA_CHECK(cudaSetDevice(device_id)); - ensure_device_props(device_id); - int max_threads = g_max_threads_map[device_id]; - - // Compute offsets and sizes - size_t total_sequences = 0; - size_t total_out_floats = 0; - size_t max_n = 0; - for (size_t p = 0; p < num_positions; p++) - { - size_t n = position_seq_counts[p]; - total_sequences += n; - total_out_floats += n * n; // full matrix per position - if (n > max_n) max_n = n; - } - - // Compute per-position offsets - size_t *seq_offsets = new size_t[num_positions]; - size_t *dist_offsets = new size_t[num_positions]; // into upper-triangle buffer - size_t *out_offsets = new size_t[num_positions]; // into output full-matrix buffer - { - size_t seq_off = 0, dist_off = 0, out_off = 0; - for (size_t p = 0; p < num_positions; p++) - { - seq_offsets[p] = seq_off; - dist_offsets[p] = dist_off; - out_offsets[p] = out_off; - size_t n = position_seq_counts[p]; - seq_off += n; - dist_off += (n * (n - 1)) / 2; - out_off += n * n; - } - } - - size_t total_pairs = 0; - for (size_t p = 0; p < num_positions; p++) - { - size_t n = position_seq_counts[p]; - total_pairs += (n * (n - 1)) / 2; - } - - // Allocate GPU memory — all at once - float *d_sequences = nullptr; - size_t *d_seq_lengths = nullptr; - float *d_distances = nullptr; - - CUDA_CHECK(cudaMalloc(&d_sequences, total_sequences * global_max_length * sizeof(float))); - CUDA_CHECK(cudaMalloc(&d_seq_lengths, total_sequences * sizeof(size_t))); - if (total_pairs > 0) - { - CUDA_CHECK(cudaMalloc(&d_distances, total_pairs * sizeof(float))); - } - - // Upload all data in one transfer - CUDA_CHECK(cudaMemcpy(d_sequences, all_sequences, - total_sequences * global_max_length * sizeof(float), - cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(d_seq_lengths, all_seq_lengths, - total_sequences * sizeof(size_t), - cudaMemcpyHostToDevice)); - - // Create streams - if (num_cuda_streams < 1) num_cuda_streams = 1; - if (num_cuda_streams > 256) num_cuda_streams = 256; - - cudaStream_t *streams = new cudaStream_t[num_cuda_streams]; - for (int s = 0; s < num_cuda_streams; s++) - { - CUDA_CHECK(cudaStreamCreate(&streams[s])); - } - - // Allocate per-stream cost buffers - size_t max_comparisons = (max_n > 0) ? max_n - 1 : 0; - size_t cost_buf_size = max_comparisons * global_max_length * sizeof(float); - - float **d_cost_a = new float*[num_cuda_streams]; - float **d_cost_b = new float*[num_cuda_streams]; - for (int s = 0; s < num_cuda_streams; s++) - { - d_cost_a[s] = nullptr; - d_cost_b[s] = nullptr; - if (cost_buf_size > 0) - { - CUDA_CHECK(cudaMalloc(&d_cost_a[s], cost_buf_size)); - CUDA_CHECK(cudaMalloc(&d_cost_b[s], cost_buf_size)); - } - } - - // Kernel launch config - dim3 thread_block(max_threads, 1, 1); - size_t shared_mem = thread_block.x * 3 * sizeof(float); - - // Process all positions across streams - for (size_t p = 0; p < num_positions; p++) - { - size_t n = position_seq_counts[p]; - if (n < 2) continue; - - int s = p % num_cuda_streams; - size_t seq_offset = seq_offsets[p]; - - for (size_t i = 0; i < n - 1; i++) - { - size_t num_comparisons = n - i - 1; - - // Reset cost buffers before each reference sequence - cudaMemsetAsync(d_cost_a[s], 0, - num_comparisons * global_max_length * sizeof(float), - streams[s]); - cudaMemsetAsync(d_cost_b[s], 0, - num_comparisons * global_max_length * sizeof(float), - streams[s]); - - float *d_current_cost = d_cost_a[s]; - float *d_next_cost = d_cost_b[s]; - - for (size_t offset = 0; offset < global_max_length; offset += max_threads) - { - DTWDistance<<>>( - nullptr, global_max_length, - nullptr, global_max_length, - i, offset, - &d_sequences[seq_offset * global_max_length], - global_max_length, n, - &d_seq_lengths[seq_offset], - d_current_cost, - d_next_cost, - (unsigned char *)nullptr, 0, // no path matrix - &d_distances[dist_offsets[p]], - use_open_start, - use_open_end); - CUDA_CHECK(cudaGetLastError()); - - float *tmp = d_current_cost; - d_current_cost = d_next_cost; - d_next_cost = tmp; - } - } - } - - // Sync all streams - CUDA_CHECK(cudaDeviceSynchronize()); - - // Copy results back and convert upper-triangle to full matrices - if (total_pairs > 0) - { - float *h_upper_triangle = new float[total_pairs]; - CUDA_CHECK(cudaMemcpy(h_upper_triangle, d_distances, - total_pairs * sizeof(float), - cudaMemcpyDeviceToHost)); - - for (size_t p = 0; p < num_positions; p++) - { - size_t n = position_seq_counts[p]; - size_t out_off = out_offsets[p]; - size_t pair_off = dist_offsets[p]; - - // Fill full symmetric matrix - size_t pair_idx = 0; - for (size_t i = 0; i < n; i++) - { - out_distances[out_off + i * n + i] = 0.0f; // diagonal - for (size_t j = i + 1; j < n; j++) - { - float dist = h_upper_triangle[pair_off + pair_idx]; - out_distances[out_off + i * n + j] = dist; - out_distances[out_off + j * n + i] = dist; - pair_idx++; - } - } - } - - delete[] h_upper_triangle; - } - else - { - // Handle positions with n <= 1 (fill diagonal zeros) - for (size_t p = 0; p < num_positions; p++) - { - size_t n = position_seq_counts[p]; - size_t out_off = out_offsets[p]; - for (size_t i = 0; i < n; i++) - { - for (size_t j = 0; j < n; j++) - { - out_distances[out_off + i * n + j] = (i == j) ? 0.0f : 0.0f; - } - } - } - } - - // Cleanup - for (int s = 0; s < num_cuda_streams; s++) - { - if (d_cost_a[s]) cudaFree(d_cost_a[s]); - if (d_cost_b[s]) cudaFree(d_cost_b[s]); - cudaStreamDestroy(streams[s]); - } - delete[] d_cost_a; - delete[] d_cost_b; - delete[] streams; - - cudaFree(d_sequences); - cudaFree(d_seq_lengths); - if (d_distances) cudaFree(d_distances); - - delete[] seq_offsets; - delete[] dist_offsets; - delete[] out_offsets; - - return 0; -} - -// ============================================================================ -// Python C API Bindings -// ============================================================================ - -#include -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include - -/** - * Python wrapper for opendba_dtw_cuda - * - * Args: - * seq1: numpy array of floats (1D) - * seq2: numpy array of floats (1D) - * use_open_start: boolean (default False) - * use_open_end: boolean (default False) - * - * Returns: - * float: DTW distance - */ -static PyObject *py_dtw_cuda(PyObject *self, PyObject *args, PyObject *kwargs) -{ - PyArrayObject *seq1_array = NULL, *seq2_array = NULL; - int use_open_start = 0; - int use_open_end = 0; - - static char *kwlist[] = {(char *)"seq1", (char *)"seq2", - (char *)"use_open_start", (char *)"use_open_end", NULL}; - - // Parse arguments - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|ii", kwlist, - &PyArray_Type, &seq1_array, - &PyArray_Type, &seq2_array, - &use_open_start, &use_open_end)) - { - return NULL; - } - - // Validate input arrays - if (PyArray_NDIM(seq1_array) != 1 || PyArray_NDIM(seq2_array) != 1) - { - PyErr_SetString(PyExc_ValueError, "Input arrays must be 1-dimensional"); - return NULL; - } - - if (PyArray_TYPE(seq1_array) != NPY_FLOAT32 || PyArray_TYPE(seq2_array) != NPY_FLOAT32) - { - PyErr_SetString(PyExc_TypeError, "Input arrays must be of type float32"); - return NULL; - } - - // Get array dimensions and data - npy_intp len1 = PyArray_DIM(seq1_array, 0); - npy_intp len2 = PyArray_DIM(seq2_array, 0); - - if (len1 == 0 || len2 == 0) - { - PyErr_SetString(PyExc_ValueError, "Input arrays cannot be empty"); - return NULL; - } - - // Get pointers to array data - float *seq1_data = (float *)PyArray_DATA(seq1_array); - float *seq2_data = (float *)PyArray_DATA(seq2_array); - - // Allocate output - float distance = 0.0f; - - // Call CUDA function - int result = opendba_dtw_cuda( - seq1_data, (size_t)len1, - seq2_data, (size_t)len2, - use_open_start, - use_open_end, - &distance); - - if (result != 0) - { - PyErr_SetString(PyExc_RuntimeError, "CUDA DTW computation failed"); - return NULL; - } - - // Return the distance as a Python float - return PyFloat_FromDouble((double)distance); -} - -/** - * Python wrapper for batch pairwise DTW - */ -static PyObject *py_dtw_pairwise(PyObject *self, PyObject *args, PyObject *kwargs) -{ - PyArrayObject *sequences_array; - int use_open_start = 0; - int use_open_end = 0; - - static char *kwlist[] = {(char *)"sequences", (char *)"use_open_start", (char *)"use_open_end", NULL}; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|ii", kwlist, - &PyArray_Type, &sequences_array, - &use_open_start, &use_open_end)) - { - return NULL; - } - - // Validate input array - if (PyArray_NDIM(sequences_array) != 2) - { - PyErr_SetString(PyExc_ValueError, "sequences must be a 2D array (num_sequences, seq_length)"); - return NULL; - } - - if (PyArray_TYPE(sequences_array) != NPY_FLOAT32) - { - PyErr_SetString(PyExc_TypeError, "sequences must be float32 dtype"); - return NULL; - } - - npy_intp *dims = PyArray_DIMS(sequences_array); - size_t num_sequences = (size_t)dims[0]; - size_t seq_length = (size_t)dims[1]; - - if (num_sequences < 2) - { - PyErr_SetString(PyExc_ValueError, "Need at least 2 sequences"); - return NULL; - } - - float *sequences_data = (float *)PyArray_DATA(sequences_array); - - // Allocate output distance matrix - npy_intp out_dims[2] = {(npy_intp)num_sequences, (npy_intp)num_sequences}; - PyArrayObject *distance_matrix = (PyArrayObject *)PyArray_ZEROS(2, out_dims, NPY_FLOAT32, 0); - if (distance_matrix == NULL) - { - return NULL; - } - - float *distances_data = (float *)PyArray_DATA(distance_matrix); - - // Call CUDA function - int result = opendba_dtw_pairwise_batch( - sequences_data, num_sequences, seq_length, - use_open_start, use_open_end, - distances_data); - - if (result != 0) - { - Py_DECREF(distance_matrix); - PyErr_SetString(PyExc_RuntimeError, "CUDA batch DTW computation failed"); - return NULL; - } - - return (PyObject *)distance_matrix; -} - -static PyObject *py_dtw_pairwise_varlen(PyObject *self, PyObject *args, PyObject *kwargs) -{ - PyArrayObject *sequences_array; - PyArrayObject *lengths_array; - int use_open_start = 0; - int use_open_end = 0; - - static char *kwlist[] = {(char *)"sequences", (char *)"lengths", - (char *)"use_open_start", (char *)"use_open_end", NULL}; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!|ii", kwlist, - &PyArray_Type, &sequences_array, - &PyArray_Type, &lengths_array, - &use_open_start, &use_open_end)) - { - return NULL; - } - - if (PyArray_NDIM(sequences_array) != 2) - { - PyErr_SetString(PyExc_ValueError, "sequences must be a 2D array (num_sequences, max_length)"); - return NULL; - } - - if (PyArray_TYPE(sequences_array) != NPY_FLOAT32) - { - PyErr_SetString(PyExc_TypeError, "sequences must be float32 dtype"); - return NULL; - } - - if (PyArray_NDIM(lengths_array) != 1) - { - PyErr_SetString(PyExc_ValueError, "lengths must be a 1D array"); - return NULL; - } - - npy_intp *seq_dims = PyArray_DIMS(sequences_array); - size_t num_sequences = (size_t)seq_dims[0]; - size_t max_length = (size_t)seq_dims[1]; - - if ((size_t)PyArray_DIM(lengths_array, 0) != num_sequences) - { - PyErr_SetString(PyExc_ValueError, "lengths array size must match num_sequences"); - return NULL; - } - - if (num_sequences < 2) - { - PyErr_SetString(PyExc_ValueError, "Need at least 2 sequences"); - return NULL; - } - - size_t *h_lengths = new size_t[num_sequences]; - for (size_t i = 0; i < num_sequences; i++) - { - long long val; - if (PyArray_TYPE(lengths_array) == NPY_INT64) - val = *((long long *)PyArray_GETPTR1(lengths_array, i)); - else if (PyArray_TYPE(lengths_array) == NPY_INT32) - val = *((int *)PyArray_GETPTR1(lengths_array, i)); - else - { - delete[] h_lengths; - PyErr_SetString(PyExc_TypeError, "lengths must be int32 or int64 dtype"); - return NULL; - } - if (val <= 0 || (size_t)val > max_length) - { - delete[] h_lengths; - PyErr_Format(PyExc_ValueError, - "length[%zu]=%lld out of range (1..%zu)", i, val, max_length); - return NULL; - } - h_lengths[i] = (size_t)val; - } - - float *sequences_data = (float *)PyArray_DATA(sequences_array); - - npy_intp out_dims[2] = {(npy_intp)num_sequences, (npy_intp)num_sequences}; - PyArrayObject *distance_matrix = (PyArrayObject *)PyArray_ZEROS(2, out_dims, NPY_FLOAT32, 0); - if (distance_matrix == NULL) - { - delete[] h_lengths; - return NULL; - } - - float *distances_data = (float *)PyArray_DATA(distance_matrix); - - int result = opendba_dtw_pairwise_varlen( - sequences_data, h_lengths, num_sequences, max_length, - use_open_start, use_open_end, - distances_data); - - delete[] h_lengths; - - if (result != 0) - { - Py_DECREF(distance_matrix); - PyErr_SetString(PyExc_RuntimeError, "CUDA varlen batch DTW computation failed"); - return NULL; - } - - return (PyObject *)distance_matrix; -} - -static PyObject *py_dtw_multi_position_pairwise(PyObject *self, PyObject *args, PyObject *kwargs) -{ - PyArrayObject *sequences_array; - PyArrayObject *lengths_array; - PyArrayObject *counts_array; - int use_open_start = 0; - int use_open_end = 0; - int num_cuda_streams = 16; - int device_id = 0; - - static char *kwlist[] = { - (char *)"sequences", (char *)"lengths", (char *)"counts", - (char *)"use_open_start", (char *)"use_open_end", - (char *)"num_cuda_streams", (char *)"device_id", NULL}; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!|iiii", kwlist, - &PyArray_Type, &sequences_array, - &PyArray_Type, &lengths_array, - &PyArray_Type, &counts_array, - &use_open_start, &use_open_end, - &num_cuda_streams, &device_id)) - { - return NULL; - } - - // Validate sequences array: 2D float32 - if (PyArray_NDIM(sequences_array) != 2) - { - PyErr_SetString(PyExc_ValueError, - "sequences must be a 2D array (total_sequences, global_max_length)"); - return NULL; - } - if (PyArray_TYPE(sequences_array) != NPY_FLOAT32) - { - PyErr_SetString(PyExc_TypeError, "sequences must be float32 dtype"); - return NULL; - } - - // Validate lengths array: 1D int64 - if (PyArray_NDIM(lengths_array) != 1) - { - PyErr_SetString(PyExc_ValueError, "lengths must be a 1D array"); - return NULL; - } - - // Validate counts array: 1D int64 - if (PyArray_NDIM(counts_array) != 1) - { - PyErr_SetString(PyExc_ValueError, "counts must be a 1D array"); - return NULL; - } - - npy_intp *seq_dims = PyArray_DIMS(sequences_array); - size_t total_sequences = (size_t)seq_dims[0]; - size_t global_max_length = (size_t)seq_dims[1]; - size_t num_positions = (size_t)PyArray_DIM(counts_array, 0); - - if ((size_t)PyArray_DIM(lengths_array, 0) != total_sequences) - { - PyErr_SetString(PyExc_ValueError, - "lengths array size must match total number of sequences"); - return NULL; - } - - // Convert lengths to size_t array - size_t *h_lengths = new size_t[total_sequences]; - for (size_t i = 0; i < total_sequences; i++) - { - long long val; - if (PyArray_TYPE(lengths_array) == NPY_INT64) - val = *((long long *)PyArray_GETPTR1(lengths_array, i)); - else if (PyArray_TYPE(lengths_array) == NPY_INT32) - val = *((int *)PyArray_GETPTR1(lengths_array, i)); - else - { - delete[] h_lengths; - PyErr_SetString(PyExc_TypeError, "lengths must be int32 or int64 dtype"); - return NULL; - } - if (val <= 0 || (size_t)val > global_max_length) - { - delete[] h_lengths; - PyErr_Format(PyExc_ValueError, - "length[%zu]=%lld out of range (1..%zu)", i, val, global_max_length); - return NULL; - } - h_lengths[i] = (size_t)val; - } - - // Convert counts to size_t array - size_t *h_counts = new size_t[num_positions]; - size_t check_total = 0; - for (size_t p = 0; p < num_positions; p++) - { - long long val; - if (PyArray_TYPE(counts_array) == NPY_INT64) - val = *((long long *)PyArray_GETPTR1(counts_array, p)); - else if (PyArray_TYPE(counts_array) == NPY_INT32) - val = *((int *)PyArray_GETPTR1(counts_array, p)); - else - { - delete[] h_lengths; - delete[] h_counts; - PyErr_SetString(PyExc_TypeError, "counts must be int32 or int64 dtype"); - return NULL; - } - if (val < 0) - { - delete[] h_lengths; - delete[] h_counts; - PyErr_Format(PyExc_ValueError, "counts[%zu]=%lld must be non-negative", p, val); - return NULL; - } - h_counts[p] = (size_t)val; - check_total += (size_t)val; - } - - if (check_total != total_sequences) - { - delete[] h_lengths; - delete[] h_counts; - PyErr_Format(PyExc_ValueError, - "sum(counts)=%zu != total sequences=%zu", check_total, total_sequences); - return NULL; - } - - // Compute total output size - size_t total_out_floats = 0; - for (size_t p = 0; p < num_positions; p++) - total_out_floats += h_counts[p] * h_counts[p]; - - // Allocate output as flat 1D array - npy_intp out_dim = (npy_intp)total_out_floats; - PyArrayObject *out_array = (PyArrayObject *)PyArray_ZEROS(1, &out_dim, NPY_FLOAT32, 0); - if (out_array == NULL) - { - delete[] h_lengths; - delete[] h_counts; - return NULL; - } - - float *sequences_data = (float *)PyArray_DATA(sequences_array); - float *out_data = (float *)PyArray_DATA(out_array); - - int result = opendba_dtw_multi_position_pairwise( - sequences_data, h_lengths, h_counts, - num_positions, global_max_length, - use_open_start, use_open_end, - out_data, num_cuda_streams, device_id); - - delete[] h_lengths; - delete[] h_counts; - - if (result != 0) - { - Py_DECREF(out_array); - PyErr_SetString(PyExc_RuntimeError, "CUDA multi-position batch DTW failed"); - return NULL; - } - - return (PyObject *)out_array; -} - -/** - * Python wrapper for opendba_dtw_cleanup - */ -static PyObject *py_dtw_cleanup(PyObject *self, PyObject *args) -{ - opendba_dtw_cleanup(); - Py_RETURN_NONE; -} - -// Method definitions -static PyMethodDef DtwMethods[] = { - {"dtw_distance", (PyCFunction)py_dtw_cuda, METH_VARARGS | METH_KEYWORDS, - "Compute DTW distance between two sequences using CUDA.\n\n" - "Parameters\n" - "----------\n" - "seq1 : np.ndarray\n" - " First sequence (1D float32 array)\n" - "seq2 : np.ndarray\n" - " Second sequence (1D float32 array)\n" - "use_open_start : bool, optional\n" - " Enable open start boundary (default: False)\n" - "use_open_end : bool, optional\n" - " Enable open end boundary (default: False)\n\n" - "Returns\n" - "-------\n" - "float\n" - " DTW distance between seq1 and seq2\n"}, - {"dtw_pairwise", (PyCFunction)py_dtw_pairwise, METH_VARARGS | METH_KEYWORDS, - "Compute pairwise DTW distances for a batch of sequences using CUDA.\n\n" - "This is much more efficient than computing distances one-by-one,\n" - "as it amortizes GPU memory transfer overhead over many computations.\n\n" - "Parameters\n" - "----------\n" - "sequences : np.ndarray\n" - " 2D array of sequences (num_sequences, seq_length) in float32\n" - " All sequences must have the same length\n" - "use_open_start : bool, optional\n" - " Enable open start boundary (default: False)\n" - "use_open_end : bool, optional\n" - " Enable open end boundary (default: False)\n\n" - "Returns\n" - "-------\n" - "np.ndarray\n" - " Distance matrix (num_sequences, num_sequences) with DTW distances\n" - " Matrix is symmetric with zeros on diagonal\n"}, - {"dtw_pairwise_varlen", (PyCFunction)py_dtw_pairwise_varlen, METH_VARARGS | METH_KEYWORDS, - "Compute pairwise DTW distances for variable-length sequences using CUDA.\n\n" - "Parameters\n" - "----------\n" - "sequences : np.ndarray\n" - " 2D padded array (num_sequences, max_length) in float32\n" - "lengths : np.ndarray\n" - " 1D array of actual sequence lengths (int32 or int64)\n" - "use_open_start : bool, optional\n" - " Enable open start boundary (default: False)\n" - "use_open_end : bool, optional\n" - " Enable open end boundary (default: False)\n\n" - "Returns\n" - "-------\n" - "np.ndarray\n" - " Distance matrix (num_sequences, num_sequences) with DTW distances\n"}, - {"dtw_multi_position_pairwise", (PyCFunction)py_dtw_multi_position_pairwise, - METH_VARARGS | METH_KEYWORDS, - "Compute pairwise DTW distances for multiple positions in one batched GPU call.\n\n" - "Parameters\n" - "----------\n" - "sequences : np.ndarray\n" - " 2D padded array (total_sequences, global_max_length) in float32\n" - "lengths : np.ndarray\n" - " 1D array of actual sequence lengths (int64)\n" - "counts : np.ndarray\n" - " 1D array of sequence counts per position (int64)\n" - "use_open_start : int, optional\n" - " Enable open start boundary (default: 0)\n" - "use_open_end : int, optional\n" - " Enable open end boundary (default: 0)\n" - "num_cuda_streams : int, optional\n" - " Number of CUDA streams (default: 16)\n\n" - "Returns\n" - "-------\n" - "np.ndarray\n" - " Flat 1D array of concatenated distance matrices (float32)\n"}, - {"cleanup", py_dtw_cleanup, METH_NOARGS, - "Reset CUDA device and free all resources.\n\n" - "This should be called when done using CUDA DTW to free GPU resources.\n"}, - {NULL, NULL, 0, NULL} // Sentinel -}; - -// Module definition -static struct PyModuleDef dtwmodule = { - PyModuleDef_HEAD_INIT, - "_cuda_dtw", - "CUDA-accelerated Dynamic Time Warping (DTW) computation\n\n" - "This module provides GPU-accelerated DTW distance calculation using CUDA.\n" - "It supports open start and open end boundary conditions.\n", - -1, - DtwMethods}; - -// Module initialization function -PyMODINIT_FUNC PyInit__cuda_dtw(void) -{ - // Import NumPy API - import_array(); - if (PyErr_Occurred()) - { - return NULL; - } - - PyObject *module = PyModule_Create(&dtwmodule); - if (module == NULL) - { - return NULL; - } - - // Add module-level constants - PyModule_AddIntConstant(module, "__version_major__", 0); - PyModule_AddIntConstant(module, "__version_minor__", 3); - PyModule_AddStringConstant(module, "__version__", "0.3.0"); - - return module; -} \ No newline at end of file diff --git a/baleen/_cuda_dtw/dtw_api.h b/baleen/_cuda_dtw/dtw_api.h deleted file mode 100644 index 7c256ae..0000000 --- a/baleen/_cuda_dtw/dtw_api.h +++ /dev/null @@ -1,108 +0,0 @@ -#ifndef DTW_C_API_H -#define DTW_C_API_H - -#ifdef __cplusplus -extern "C" -{ -#endif - - /** - * @brief 调用 OpenDBA 原版 CUDA DTW 计算两个序列的距离 - * @param seq1 主机端浮点序列1(float 类型) - * @param len1 序列1长度 - * @param seq2 主机端浮点序列2(float 类型) - * @param len2 序列2长度 - * @param use_open_start 是否启用 open start 边界 - * @param use_open_end 是否启用 open end 边界 - * @param out_distance 输出 DTW 距离(主机端,float 类型) - * @return 0=成功,非0=错误码(1=内存分配失败,2=核函数启动失败,3=数据拷贝失败) - */ - int opendba_dtw_cuda( - const float *seq1, size_t len1, - const float *seq2, size_t len2, - int use_open_start, - int use_open_end, - float *out_distance); - - /** - * @brief Compute pairwise DTW distances for a batch of sequences (all same length) - * @param sequences Flattened array of sequences (num_sequences * seq_length floats) - * @param num_sequences Number of sequences - * @param seq_length Length of each sequence (all must be same length) - * @param use_open_start Whether to use open start boundary - * @param use_open_end Whether to use open end boundary - * @param out_distances Output pairwise distance matrix (num_sequences * num_sequences floats) - * @return 0=success, non-zero=error - */ - int opendba_dtw_pairwise_batch( - const float *sequences, - size_t num_sequences, - size_t seq_length, - int use_open_start, - int use_open_end, - float *out_distances); - - /** - * @brief Compute pairwise DTW distances for variable-length sequences - * - * Sequences are stored in a padded 2D layout (num_sequences × max_length), - * with actual lengths given separately. Only the first seq_lengths[i] - * elements of row i are used; the rest are ignored. - * - * @param sequences Padded array (num_sequences * max_length floats) - * @param seq_lengths Actual length of each sequence (num_sequences values) - * @param num_sequences Number of sequences - * @param max_length Padded row width (>= max of seq_lengths) - * @param use_open_start Whether to use open start boundary - * @param use_open_end Whether to use open end boundary - * @param out_distances Output matrix (num_sequences * num_sequences floats) - * @return 0=success, non-zero=error - */ - int opendba_dtw_pairwise_varlen( - const float *sequences, - const size_t *seq_lengths, - size_t num_sequences, - size_t max_length, - int use_open_start, - int use_open_end, - float *out_distances); - - /** - * @brief Compute pairwise DTW for multiple positions in one batched GPU call. - * - * All positions' padded signals are concatenated into a single flat array. - * CUDA streams enable concurrent processing of different positions. - * - * @param all_sequences Concatenated padded signals: sum(n_i) * global_max_length floats - * @param all_seq_lengths Actual lengths for all sequences: sum(n_i) values - * @param position_seq_counts Number of sequences per position: num_positions values - * @param num_positions Number of genomic positions - * @param global_max_length Max signal length across all positions (padding width) - * @param use_open_start Open start boundary - * @param use_open_end Open end boundary - * @param out_distances Output: concatenated full distance matrices (sum(n_i^2) floats) - * @param num_cuda_streams Number of CUDA streams for concurrency - * @return 0=success, non-zero=error - */ - int opendba_dtw_multi_position_pairwise( - const float *all_sequences, - const size_t *all_seq_lengths, - const size_t *position_seq_counts, - size_t num_positions, - size_t global_max_length, - int use_open_start, - int use_open_end, - float *out_distances, - int num_cuda_streams, - int device_id); - - /** - * @brief 清理 CUDA 资源 - */ - void opendba_dtw_cleanup(); - -#ifdef __cplusplus -} -#endif - -#endif // DTW_C_API_H \ No newline at end of file diff --git a/baleen/_cuda_dtw/limits.hpp b/baleen/_cuda_dtw/limits.hpp deleted file mode 100644 index 7c568b4..0000000 --- a/baleen/_cuda_dtw/limits.hpp +++ /dev/null @@ -1,68 +0,0 @@ -#ifndef __CUDA_LIMITS_HPP -#define __CUDA_LIMITS_HPP - -#if defined(_WIN32) -typedef unsigned short ushort; -typedef unsigned int uint; -#endif - -#include -#include - -namespace cudahack -{ - - template - struct numeric_limits; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static short min() { return SHRT_MIN; } - __device__ __forceinline__ static short max() { return SHRT_MAX; } - }; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static ushort min() { return 0; } - __device__ __forceinline__ static ushort max() { return USHRT_MAX; } - }; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static int min() { return INT_MIN; } - __device__ __forceinline__ static int max() { return INT_MAX; } - }; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static uint min() { return 0; } - __device__ __forceinline__ static uint max() { return UINT_MAX; } - }; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static unsigned long long int min() { return 0; } - __device__ __forceinline__ static unsigned long long int max() { return ULLONG_MAX; } - }; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static float min() { return FLT_MIN; } - __device__ __forceinline__ static float max() { return FLT_MAX; } - }; - - template <> - struct numeric_limits - { - __device__ __forceinline__ static double min() { return DBL_MIN; } - __device__ __forceinline__ static double max() { return DBL_MAX; } - }; -} - -#endif \ No newline at end of file diff --git a/baleen/_cuda_dtw/multithreading.cpp b/baleen/_cuda_dtw/multithreading.cpp deleted file mode 100644 index 72e7824..0000000 --- a/baleen/_cuda_dtw/multithreading.cpp +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Copyright 1993-2012 NVIDIA Corporation. All rights reserved. - * - * See https://docs.nvidia.com/cuda/eula/index.html#nvidia-driver-license - * - * Please refer to the NVIDIA end user license agreement (EULA) associated - * with this source code for terms and conditions that govern your use of - * this software. Any use, reproduction, disclosure, or distribution of - * this software and related documentation outside the terms of the EULA - * is strictly prohibited. - * - */ - -#include "multithreading.h" - -#if _WIN32 -// Create thread -CUTThread cutStartThread(CUT_THREADROUTINE func, void *data) -{ - return CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)func, data, 0, NULL); -} - -// Wait for thread to finish -void cutEndThread(CUTThread thread) -{ - WaitForSingleObject(thread, INFINITE); - CloseHandle(thread); -} - -// Destroy thread -void cutDestroyThread(CUTThread thread) -{ - TerminateThread(thread, 0); - CloseHandle(thread); -} - -// Wait for multiple threads -void cutWaitForThreads(const CUTThread *threads, int num) -{ - WaitForMultipleObjects(num, threads, true, INFINITE); - - for (int i = 0; i < num; i++) - { - CloseHandle(threads[i]); - } -} - -// Create barrier. -CUTBarrier cutCreateBarrier(int releaseCount) -{ - CUTBarrier barrier; - - InitializeCriticalSection(&barrier.criticalSection); - barrier.barrierEvent = CreateEvent(NULL, TRUE, FALSE, TEXT("BarrierEvent")); - barrier.count = 0; - barrier.releaseCount = releaseCount; - - return barrier; -} - -// Increment barrier. (excution continues) -void cutIncrementBarrier(CUTBarrier *barrier) -{ - int myBarrierCount; - EnterCriticalSection(&barrier->criticalSection); - myBarrierCount = ++barrier->count; - LeaveCriticalSection(&barrier->criticalSection); - - if (myBarrierCount >= barrier->releaseCount) - { - SetEvent(barrier->barrierEvent); - } -} - -// Wait for barrier release. -void cutWaitForBarrier(CUTBarrier *barrier) -{ - WaitForSingleObject(barrier->barrierEvent, INFINITE); -} - -// Destory barrier -void cutDestroyBarrier(CUTBarrier *barrier) -{ -} - -#else -// Create thread -CUTThread cutStartThread(CUT_THREADROUTINE func, void *data) -{ - pthread_t thread; - pthread_create(&thread, NULL, func, data); - return thread; -} - -// Wait for thread to finish -void cutEndThread(CUTThread thread) -{ - pthread_join(thread, NULL); -} - -// Destroy thread -void cutDestroyThread(CUTThread thread) -{ - pthread_cancel(thread); -} - -// Wait for multiple threads -void cutWaitForThreads(const CUTThread *threads, int num) -{ - for (int i = 0; i < num; i++) - { - cutEndThread(threads[i]); - } -} - -// Create barrier. -CUTBarrier cutCreateBarrier(int releaseCount) -{ - CUTBarrier barrier; - - barrier.count = 0; - barrier.releaseCount = releaseCount; - - pthread_mutex_init(&barrier.mutex, 0); - pthread_cond_init(&barrier.conditionVariable, 0); - - return barrier; -} - -// Increment barrier. (excution continues) -void cutIncrementBarrier(CUTBarrier *barrier) -{ - int myBarrierCount; - pthread_mutex_lock(&barrier->mutex); - myBarrierCount = ++barrier->count; - pthread_mutex_unlock(&barrier->mutex); - - if (myBarrierCount >= barrier->releaseCount) - { - pthread_cond_signal(&barrier->conditionVariable); - } -} - -// Wait for barrier release. -void cutWaitForBarrier(CUTBarrier *barrier) -{ - pthread_mutex_lock(&barrier->mutex); - - while (barrier->count < barrier->releaseCount) - { - pthread_cond_wait(&barrier->conditionVariable, &barrier->mutex); - } - - pthread_mutex_unlock(&barrier->mutex); -} - -// Destory barrier -void cutDestroyBarrier(CUTBarrier *barrier) -{ - pthread_mutex_destroy(&barrier->mutex); - pthread_cond_destroy(&barrier->conditionVariable); -} - -#endif \ No newline at end of file diff --git a/baleen/_cuda_dtw/multithreading.h b/baleen/_cuda_dtw/multithreading.h deleted file mode 100644 index ab7cfc0..0000000 --- a/baleen/_cuda_dtw/multithreading.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 1993-2012 NVIDIA Corporation. All rights reserved. - * - * Please refer to the NVIDIA end user license agreement (EULA) associated - * with this source code for terms and conditions that govern your use of - * this software. Any use, reproduction, disclosure, or distribution of - * this software and related documentation outside the terms of the EULA - * is strictly prohibited. - * - */ - -#ifndef MULTITHREADING_H -#define MULTITHREADING_H - -// Simple portable thread library. - -#if _WIN32 -// Windows threads. -#include - -typedef HANDLE CUTThread; -typedef unsigned(WINAPI *CUT_THREADROUTINE)(void *); - -struct CUTBarrier -{ - CRITICAL_SECTION criticalSection; - HANDLE barrierEvent; - int releaseCount; - int count; -}; - -#define CUT_THREADPROC unsigned WINAPI -#define CUT_THREADEND return 0 - -#else -// POSIX threads. -#include - -typedef pthread_t CUTThread; -typedef void *(*CUT_THREADROUTINE)(void *); - -#define CUT_THREADPROC void * -#define CUT_THREADEND return 0 - -struct CUTBarrier -{ - pthread_mutex_t mutex; - pthread_cond_t conditionVariable; - int releaseCount; - int count; -}; - -#endif - -#ifdef __cplusplus -extern "C" -{ -#endif - - // Create thread. - CUTThread cutStartThread(CUT_THREADROUTINE, void *data); - - // Wait for thread to finish. - void cutEndThread(CUTThread thread); - - // Destroy thread. - void cutDestroyThread(CUTThread thread); - - // Wait for multiple threads. - void cutWaitForThreads(const CUTThread *threads, int num); - - // Create barrier. - CUTBarrier cutCreateBarrier(int releaseCount); - - // Increment barrier. (excution continues) - void cutIncrementBarrier(CUTBarrier *barrier); - - // Wait for barrier release. - void cutWaitForBarrier(CUTBarrier *barrier); - - // Destory barrier - void cutDestroyBarrier(CUTBarrier *barrier); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // MULTITHREADING_H \ No newline at end of file diff --git a/baleen/_dtw.py b/baleen/_dtw.py new file mode 100644 index 0000000..ac249d1 --- /dev/null +++ b/baleen/_dtw.py @@ -0,0 +1,166 @@ +"""DTW backend — thin delegation to krill's bundled DTW. + +krill ships the same DTW kernels py-baleen historically maintained as the +``_cuda_dtw`` C-extension (cuDTW++ on GPU, CPU fallback otherwise). This module +re-exports that surface with py-baleen's historical ``use_cuda=`` keyword +(krill uses ``use_gpu=``) and keeps the pure-Python GPU memory-planning helpers +the pipeline's chunking relies on (krill does not expose them). +""" +from __future__ import annotations + +import logging +import subprocess +from typing import Optional + +import numpy as np + +try: + import krill +except ModuleNotFoundError as exc: # pragma: no cover - install-time guard + raise ModuleNotFoundError( + "baleen requires the 'krill' engine (DTW + eventalign), which is not on " + "PyPI. Install it from the project index, e.g.:\n" + " pip install krill --no-deps " + "--index-url https://loganylchen.github.io/krill-dist/cu122/simple/ " + "(GPU/cu122)\n" + " pip install krill --no-deps " + "--index-url https://loganylchen.github.io/krill-dist/simple/ (CPU)\n" + "or use a prebuilt baleen Docker image." + ) from exc + +_log = logging.getLogger(__name__) + +# True if krill's GPU DTW is usable right now (extension built + device present). +CUDA_AVAILABLE: bool = bool(krill.dtw_available()) + + +def backend() -> str: + """Return the backend krill auto-selection will use now ('gpu' or 'cpu').""" + return krill.dtw_backend() + + +def is_available() -> bool: + """Check if GPU DTW is usable right now.""" + return CUDA_AVAILABLE + + +def cleanup() -> None: + """Release any cached GPU DTW resources.""" + krill.dtw_cleanup() + + +# --------------------------------------------------------------------------- +# DTW entry points (use_cuda -> use_gpu adapters over krill) +# --------------------------------------------------------------------------- + +def dtw_distance(seq1, seq2, use_cuda: Optional[bool] = None) -> float: + """DTW distance between two 1-D signals.""" + return float(krill.dtw_distance(seq1, seq2, use_gpu=use_cuda)) + + +def dtw_pairwise(sequences, use_cuda: Optional[bool] = None) -> np.ndarray: + """Pairwise DTW matrix for a batch of equal-length signals.""" + return krill.dtw_pairwise(sequences, use_gpu=use_cuda) + + +def dtw_pairwise_varlen(signals, use_cuda: Optional[bool] = None) -> np.ndarray: + """Pairwise DTW matrix for variable-length signals.""" + return krill.dtw_pairwise_varlen(signals, use_gpu=use_cuda) + + +def dtw_multi_position_pairwise( + position_signals, + use_cuda: Optional[bool] = None, + num_streams: int = 16, + device_id: int = 0, +) -> list: + """Pairwise DTW matrices for several positions in one batched call. + + Returns one (N_p, N_p) float64 matrix per position. + """ + return krill.dtw_multi_position_pairwise( + position_signals, + use_gpu=use_cuda, + num_streams=num_streams, + device_id=device_id, + ) + + +# --------------------------------------------------------------------------- +# GPU memory planning (pure-Python; krill does not expose these) +# --------------------------------------------------------------------------- + +_BUCKETS = (127, 255, 511, 1023, 2047) + + +def _select_bucket(n: int) -> int: + for b in _BUCKETS: + if n <= b: + return b + return _BUCKETS[-1] # caller resamples before reaching here + + +def estimate_gpu_memory(position_signals: list[list[np.ndarray]]) -> int: + """Estimate GPU memory bytes for a multi-position pairwise DTW call. + + The cuDTW++ kernel allocates: + - d_subjects: total_seqs * global_bucket * 4 bytes + - d_out: sum(n_p^2) * 4 bytes (float32 squared cost) + """ + total_seqs = sum(len(ps) for ps in position_signals) + max_len = max(len(s) for ps in position_signals for s in ps) + bucket = _select_bucket(max_len) + + input_bytes = total_seqs * bucket * 4 + output_bytes = sum(len(ps) ** 2 for ps in position_signals) * 4 + lengths_bytes = total_seqs * 8 # host-side h_lengths array + + total = input_bytes + output_bytes + lengths_bytes + return int(total * 1.2) # 20% headroom for stream/kernel overhead + + +def get_device_count() -> int: + """Return number of visible CUDA devices.""" + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], + capture_output=True, text=True, timeout=5, + ) + if result.returncode == 0: + return len([l for l in result.stdout.strip().split('\n') if l.strip()]) + except Exception: + pass + return 1 if CUDA_AVAILABLE else 0 + + +def get_per_device_memory() -> list[int]: + """Return total GPU memory in bytes for each visible CUDA device.""" + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=memory.total', + '--format=csv,noheader,nounits'], + capture_output=True, text=True, timeout=5, + ) + if result.returncode == 0: + lines = [l.strip() for l in result.stdout.strip().split('\n') if l.strip()] + return [int(mb) * 1024 * 1024 for mb in lines] + except Exception: + pass + if CUDA_AVAILABLE: + return [8 * 1024 ** 3] # default 8 GB + return [] + + +__all__ = [ + "CUDA_AVAILABLE", + "backend", + "is_available", + "cleanup", + "dtw_distance", + "dtw_pairwise", + "dtw_pairwise_varlen", + "dtw_multi_position_pairwise", + "estimate_gpu_memory", + "get_device_count", + "get_per_device_memory", +] diff --git a/baleen/cli.py b/baleen/cli.py index 433b636..460d253 100644 --- a/baleen/cli.py +++ b/baleen/cli.py @@ -39,7 +39,7 @@ def _parse_cuda_devices(spec: str) -> list[int]: """Parse device spec: '0', '0,1', '0-3', 'all'.""" - from baleen._cuda_dtw import get_device_count + from baleen._dtw import get_device_count spec = spec.strip().lower() if spec == "all": @@ -167,19 +167,19 @@ def _add_run_args(parser: argparse.ArgumentParser) -> None: help="Per-read P(mod) threshold for counting a read as modified (default: 0.9)", ) - # f5c options - f5c = parser.add_argument_group("f5c options") - f5c.add_argument( - "--f5c-threads", type=int, default=None, - help="CPU threads per f5c eventalign call (default: auto = total_cores / threads)", + # eventalign options + eventalign = parser.add_argument_group("eventalign options") + eventalign.add_argument( + "--pore", type=str, default="rna002", + help="krill pore model for eventalign (default: rna002)", ) - f5c.add_argument( + eventalign.add_argument( "--no-rna", action="store_true", default=False, - help="Disable RNA mode for f5c eventalign", + help="Disable RNA mode for eventalign", ) - f5c.add_argument( + eventalign.add_argument( "--kmer-model", type=str, default=None, - help="Custom kmer model for f5c eventalign", + help="Custom kmer model (reserved; unused by the krill engine)", ) # Misc @@ -303,20 +303,6 @@ def _cmd_run(args: argparse.Namespace) -> None: logger.info("Loaded HMM params: %s (%d-state %s)", args.hmm_params, hmm_params.n_states, hmm_params.mode) - # Auto-compute f5c threads: total_cores / pipeline_workers, clamped to [2, 16] - import os - f5c_threads = args.f5c_threads - if f5c_threads is None: - total_cores = os.cpu_count() or 4 - f5c_threads = max(2, min(16, total_cores // max(args.threads, 1))) - # Inject into extra_f5c_args (f5c uses -t for threads) - # NOTE: --iop is only for fast5, not slow5 — omit it to avoid f5c errors - extra_f5c_args = [] - if '-t' not in extra_f5c_args: - extra_f5c_args.extend(['-t', str(f5c_threads)]) - logger.info("f5c threads: -t %d (pipeline workers: %d)", - f5c_threads, args.threads) - # Run streaming pipeline (DTW → HMM → aggregation fused per contig) output_paths, metadata = run_pipeline_streaming( native_bam=args.native_bam, @@ -334,7 +320,7 @@ def _cmd_run(args: argparse.Namespace) -> None: cleanup_temp=not args.keep_temp, rna=not args.no_rna, kmer_model=args.kmer_model, - extra_f5c_args=extra_f5c_args, + pore=args.pore, min_mapq=args.min_mapq, primary_only=not args.no_primary_only, threads=args.threads, diff --git a/baleen/eventalign/_eventalign.py b/baleen/eventalign/_eventalign.py new file mode 100644 index 0000000..a7cadd2 --- /dev/null +++ b/baleen/eventalign/_eventalign.py @@ -0,0 +1,271 @@ +"""krill-backed eventalign — produces f5c-eventalign-format TSV. + +Replaces the external f5c CLI with krill's in-process aligner. For every +primary, forward-mapped read in the contig BAM, the read's raw signal is +aligned to its mapped reference subsequence (HMM confidence disabled: dense, +skip-free), and the result is written in f5c's 16-column ``--samples`` TSV +format so the rest of the pipeline (``group_signals_by_position`` -> DTW -> +V1/V2/V3 -> aggregation) runs unchanged. + +Coordinate convention (verified against real f5c output, RNA002 5-mer): + f5c ``position`` = 0-based index of the FIRST base of the k-mer. + krill ``position`` = central base (kmer_center=2 for a 5-mer). + => f5c_position = krill_position - aligner.kmer_center +``group_signals_by_position`` then applies its usual ``+ len(kmer)//2 + 1`` +shift to align predicted sites with the reference. + +samples column: f5c with ``--scale-events`` writes per-sample pA. We write +pA = (raw + offset) * range / digitisation (krill's convention), so units +match. A residual per-read scale/shift vs f5c is symmetric across native/IVT +and is absorbed by the per-position empirical-Bayes / mixture calibration +downstream. +""" +from __future__ import annotations + +import logging +import subprocess +import time +from pathlib import Path +from typing import Optional, Union, cast + +import numpy as np +import pyfastx +import pysam +import pyslow5 + +logger = logging.getLogger(__name__) + +PathLike = Union[str, Path] + +DEFAULT_PORE = "rna002" + +# f5c eventalign --samples header (16 cols). group_signals_by_position only +# reads: contig, position, reference_kmer, read_name, start_idx, samples. +_HEADER = ( + "contig\tposition\treference_kmer\tread_name\tstrand\tevent_index\t" + "event_level_mean\tevent_stdv\tevent_length\tmodel_kmer\tmodel_mean\t" + "model_stdv\tstandardized_level\tstart_idx\tend_idx\tsamples" +) + +_krill_version: Optional[str] = None +# Aligner construction loads the pore model; cache per (pore) within a process. +_ALIGNER_CACHE: dict[str, object] = {} +_REF_CACHE: dict[str, object] = {} + + +def check_krill() -> str: + """Verify krill is importable and return a version string. + + Returns + ------- + str + krill version (or ``"unknown"`` if krill exposes no ``__version__``). + + Raises + ------ + RuntimeError + If krill cannot be imported. + """ + global _krill_version + if _krill_version is not None: + return _krill_version + try: + import krill + except ModuleNotFoundError as exc: # pragma: no cover - install-time guard + raise RuntimeError( + "krill not found. baleen's eventalign engine requires the 'krill' " + "package (not on PyPI). Install it from the project index, e.g. " + "`pip install krill --no-deps --index-url " + "https://loganylchen.github.io/krill-dist/cu122/simple/` (GPU) or " + "the /simple/ index (CPU), or use a prebuilt baleen Docker image." + ) from exc + version = getattr(krill, "__version__", None) + if version is None: + try: + from importlib.metadata import version as _pkg_version + version = _pkg_version("krill") + except Exception: # noqa: BLE001 + version = "unknown" + _krill_version = str(version) + return _krill_version + + +def _get_aligner(pore: str): + import krill + + if pore not in _ALIGNER_CACHE: + _ALIGNER_CACHE[pore] = krill.Aligner( + pore=pore, + use_gpu=False, + hmm_confidence=False, + keep_kmer_skips=False, + ) + logger.info("krill Aligner(pore=%s) constructed", pore) + return _ALIGNER_CACHE[pore] + + +def _get_ref(ref_fasta: PathLike): + key = str(ref_fasta) + if key not in _REF_CACHE: + _REF_CACHE[key] = pyfastx.Fasta(key) + return _REF_CACHE[key] + + +def is_blow5_indexed(blow5: PathLike) -> bool: + """Check whether a SLOW5/BLOW5 ``.idx`` exists and is non-empty.""" + blow5_path = Path(blow5) + idx_path = blow5_path.with_name(f"{blow5_path.name}.idx") + return idx_path.exists() and idx_path.stat().st_size > 0 + + +def index_blow5(blow5: PathLike) -> None: + """Create a BLOW5 index using slow5tools (pyslow5 random access needs it). + + Raises + ------ + RuntimeError + If the slow5tools indexing command fails. + """ + blow5_path = Path(blow5) + if is_blow5_indexed(blow5_path): + logger.info("Skipping slow5tools index; BLOW5 already indexed: %s", blow5_path) + return + cmd = ["slow5tools", "index", str(blow5_path)] + logger.debug("Running command: %s", " ".join(cmd)) + try: + _ = subprocess.run(cmd, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as exc: + stderr = cast(Optional[str], exc.stderr) + raise RuntimeError(f"slow5tools index failed: {(stderr or '').strip()}") from exc + + +def run_eventalign( + bam: PathLike, + ref_fasta: PathLike, + fastq: PathLike, # unused (kept for signature parity with f5c) + blow5: PathLike, + output_tsv: PathLike, + *, + rna: bool = True, # unused; krill pore drives chemistry + kmer_model: Optional[str] = None, # unused + extra_args: Optional[list[str]] = None, # unused + min_mapq: int = 0, + primary_only: bool = True, + pore: str = DEFAULT_PORE, +) -> Path: + """Align every primary, forward-mapped read in *bam* with krill. + + Drop-in replacement for the former ``_f5c.run_eventalign``: writes an + f5c-format eventalign TSV with HMM confidence disabled (dense, skip-free). + + Returns + ------- + pathlib.Path + Output TSV path. + """ + bam_path = Path(bam) + out_path = Path(output_tsv) + aligner = _get_aligner(pore) + kmer_center = int(aligner.kmer_center) + ref = _get_ref(ref_fasta) + + s5 = pyslow5.Open(str(blow5), "r") + n_reads = n_rows = n_no_signal = n_failed = n_reverse = 0 + + tmp_path = out_path.with_suffix(".tmp") + t0 = time.perf_counter() + try: + with pysam.AlignmentFile(str(bam_path), "rb") as bamf, \ + tmp_path.open("w", encoding="utf-8") as out: + out.write(_HEADER + "\n") + for aln in bamf.fetch(until_eof=True): + if aln.is_unmapped: + continue + if primary_only and (aln.is_secondary or aln.is_supplementary): + continue + if aln.mapping_quality < min_mapq: + continue + if aln.is_reverse: + # Direct-RNA on a transcriptome ref should map forward; + # skip (and count) reverse alignments for safety. + n_reverse += 1 + continue + + rid = aln.query_name + contig = aln.reference_name + rs, re_ = aln.reference_start, aln.reference_end # 0-based half-open + ref_sub = str(ref[contig][rs:re_].seq) + if not ref_sub: + continue + + try: + rd = s5.get_read(rid, pA=False) + except Exception: # noqa: BLE001 + rd = None + if rd is None: + n_no_signal += 1 + continue + + raw = np.asarray(rd["signal"], dtype=np.float32) + digit = float(rd["digitisation"]) + offset = float(rd["offset"]) + rng = float(rd["range"]) + sr = float(rd["sampling_rate"]) + pA = (raw + offset) * (rng / digit) + + res = aligner.align({ + "read_id": rid, "sequence": ref_sub, "signal": raw, + "digitisation": digit, "offset": offset, "range": rng, + "sample_rate": sr, "start": rs, + })[0] + n_reads += 1 + if res["status"] != 0: + n_failed += 1 + continue + + P = res["position"] + RK = res["reference_kmer"] + EI = res["event_index"] + ELM = res["event_level_mean"] + ESD = res["event_stdv"] + ELN = res["event_length"] + MK = res["model_kmer"] + MM = res["model_mean"] + MSD = res["model_stdv"] + SL = res["standardized_level"] + SI = res["start_idx"] + END = res["end_idx"] + + for i in range(int(P.size)): + si = int(SI[i]) + ei = int(END[i]) + seg = pA[si:ei] + if seg.size == 0: + continue + f5c_pos = int(P[i]) - kmer_center + if f5c_pos < 0: + continue + samples = ",".join(np.char.mod("%.3f", seg)) + sl = SL[i] + sl_str = "" if (sl != sl) else f"{float(sl):.2f}" # NaN -> "" + out.write( + f"{contig}\t{f5c_pos}\t{RK[i]}\t{rid}\tt\t{int(EI[i])}\t" + f"{float(ELM[i]):.2f}\t{float(ESD[i]):.3f}\t{float(ELN[i]):.5f}\t" + f"{MK[i]}\t{float(MM[i]):.2f}\t{float(MSD[i]):.2f}\t{sl_str}\t" + f"{si}\t{ei}\t{samples}\n" + ) + n_rows += 1 + tmp_path.replace(out_path) + except BaseException: + tmp_path.unlink(missing_ok=True) + raise + finally: + s5.close() + + logger.debug( + "krill eventalign %s: %d reads, %d rows in %.1fs " + "(no_signal=%d, failed=%d, reverse=%d)", + bam_path.name, n_reads, n_rows, time.perf_counter() - t0, + n_no_signal, n_failed, n_reverse, + ) + return out_path diff --git a/baleen/eventalign/_f5c.py b/baleen/eventalign/_f5c.py deleted file mode 100644 index f7752d2..0000000 --- a/baleen/eventalign/_f5c.py +++ /dev/null @@ -1,305 +0,0 @@ -"""Utilities for invoking the f5c eventalign command-line tool. - -This module provides a small wrapper around f5c/slow5tools subprocess calls -used by the eventalign pipeline. -""" - -from __future__ import annotations - -import logging -import re -import subprocess -import time -from pathlib import Path -from typing import Optional, Union, cast - -logger = logging.getLogger(__name__) - -_f5c_version: Optional[str] = None -PathLike = Union[str, Path] -_VERSION_PATTERN = re.compile(r"\bf5c\s+v?(\d+(?:\.\d+)*)\b", re.IGNORECASE) - - -def check_f5c() -> str: - """Check f5c availability and cache its version. - - Returns - ------- - str - Installed f5c version string (for example ``"1.6"``). - - Raises - ------ - RuntimeError - If f5c is not available in ``PATH``. - RuntimeError - If version output cannot be parsed. - """ - global _f5c_version - - if _f5c_version is not None: - return _f5c_version - - try: - result = subprocess.run( - ["f5c", "--version"], - check=True, - capture_output=True, - text=True, - ) - except FileNotFoundError as exc: - raise RuntimeError("f5c not found in PATH. Please install f5c.") from exc - except subprocess.CalledProcessError as exc: - stderr = cast(Optional[str], exc.stderr) - stderr_text = (stderr or "").strip() - raise RuntimeError(f"f5c not found or failed to run: {stderr_text}") from exc - - version_output = (result.stdout or result.stderr or "").strip() - match = _VERSION_PATTERN.search(version_output) - if match is None: - raise RuntimeError(f"Could not parse f5c version from output: {version_output!r}") - - version = match.group(1) - _f5c_version = version - return version - - -def get_f5c_version() -> tuple[int, ...]: - """Return parsed f5c version components. - - Returns - ------- - tuple of int - Parsed version tuple (for example ``(1, 6)`` or ``(1, 6, 1)``). - - Raises - ------ - RuntimeError - If f5c cannot be detected or version cannot be parsed. - """ - version_str = _f5c_version if _f5c_version is not None else check_f5c() - return tuple(int(part) for part in version_str.split(".")) - - -def is_indexed(fastq: PathLike) -> bool: - """Check whether f5c FASTQ index exists and is non-empty. - - Parameters - ---------- - fastq : str or pathlib.Path - FASTQ file used for f5c indexing. - - Returns - ------- - bool - ``True`` when ``.index.readdb`` exists and has non-zero size, - otherwise ``False``. - """ - fastq_path = Path(fastq) - index_path = fastq_path.with_name(f"{fastq_path.name}.index.readdb") - return index_path.exists() and index_path.stat().st_size > 0 - - -def is_blow5_indexed(blow5: PathLike) -> bool: - """Check whether SLOW5/BLOW5 index exists and is non-empty. - - Parameters - ---------- - blow5 : str or pathlib.Path - BLOW5 file to check. - - Returns - ------- - bool - ``True`` when ``.idx`` exists and has non-zero size, - otherwise ``False``. - """ - blow5_path = Path(blow5) - idx_path = blow5_path.with_name(f"{blow5_path.name}.idx") - return idx_path.exists() and idx_path.stat().st_size > 0 - - -def _run_f5c_index(cmd: list[str], error_msg: str) -> None: - """Run an indexing command (f5c or slow5tools) with error handling. - - Parameters - ---------- - cmd : list of str - Command to execute. - error_msg : str - Prefix for error messages on failure. - - Raises - ------ - RuntimeError - If the indexing command fails. - """ - logger.debug("Running command: %s", " ".join(cmd)) - t0 = time.perf_counter() - try: - _ = subprocess.run(cmd, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as exc: - stderr = cast(Optional[str], exc.stderr) - stderr_text = (stderr or "").strip() - raise RuntimeError(f"{error_msg}: {stderr_text}") from exc - logger.debug("Indexing completed in %.1fs", time.perf_counter() - t0) - - -def index_fastq_blow5(fastq: PathLike, blow5: PathLike) -> None: - """Index FASTQ against BLOW5 using f5c. - - Parameters - ---------- - fastq : str or pathlib.Path - FASTQ file path. - blow5 : str or pathlib.Path - BLOW5 file path. - - Raises - ------ - RuntimeError - If f5c indexing command fails. - """ - fastq_path = Path(fastq) - blow5_path = Path(blow5) - - if is_indexed(fastq_path): - logger.info("Skipping f5c index; FASTQ already indexed: %s", fastq_path) - return - - cmd = ["f5c", "index", "--slow5", str(blow5_path), str(fastq_path)] - _run_f5c_index(cmd, "f5c index failed") - - -def index_blow5(blow5: PathLike) -> None: - """Create BLOW5 index using slow5tools. - - Parameters - ---------- - blow5 : str or pathlib.Path - BLOW5 file path. - - Raises - ------ - RuntimeError - If slow5tools indexing command fails. - """ - blow5_path = Path(blow5) - - if is_blow5_indexed(blow5_path): - logger.info("Skipping slow5tools index; BLOW5 already indexed: %s", blow5_path) - return - - cmd = ["slow5tools", "index", str(blow5_path)] - _run_f5c_index(cmd, "slow5tools index failed") - - -def run_eventalign( - bam: PathLike, - ref_fasta: PathLike, - fastq: PathLike, - blow5: PathLike, - output_tsv: PathLike, - *, - rna: bool = True, - kmer_model: Optional[str] = None, - extra_args: Optional[list[str]] = None, - min_mapq: int = 0, -) -> Path: - """Run ``f5c eventalign`` and write TSV output. - - Parameters - ---------- - bam : str or pathlib.Path - Input BAM file. - ref_fasta : str or pathlib.Path - Reference FASTA file. - fastq : str or pathlib.Path - Input FASTQ file. - blow5 : str or pathlib.Path - Input BLOW5 file. - output_tsv : str or pathlib.Path - Output eventalign TSV path. - rna : bool, optional - If ``True``, include the ``--rna`` flag. - kmer_model : str, optional - Optional k-mer model path/name passed via ``--kmer-model``. - extra_args : list of str, optional - Additional command-line arguments appended as-is. - min_mapq : int, optional - Minimum mapping quality passed to ``f5c --min-mapq``. Defaults to - ``0`` so that f5c does not apply its own MAPQ filter (the pipeline - already filters reads during BAM splitting). - - Returns - ------- - pathlib.Path - Output TSV path. - - Raises - ------ - RuntimeError - If the f5c eventalign command fails. - """ - bam_path = Path(bam) - ref_fasta_path = Path(ref_fasta) - fastq_path = Path(fastq) - blow5_path = Path(blow5) - output_path = Path(output_tsv) - - cmd = [ - "f5c", - "eventalign", - "-b", - str(bam_path), - "-g", - str(ref_fasta_path), - "-r", - str(fastq_path), - "--slow5", - str(blow5_path), - "--samples", - "--signal-index", - "--scale-events", - "--print-read-names", - ] - - if rna: - cmd.append("--rna") - if kmer_model is not None: - cmd.extend(["--kmer-model", kmer_model]) - cmd.extend(["--min-mapq", str(min_mapq)]) - if extra_args: - cmd.extend(extra_args) - - logger.debug("Running command: %s", " ".join(cmd)) - - # Write to a temporary file, then rename atomically on success - tmp_path = output_path.with_suffix(".tmp") - t0 = time.perf_counter() - try: - with tmp_path.open("w", encoding="utf-8") as output_file: - _ = subprocess.run( - cmd, - check=True, - stdout=output_file, - stderr=subprocess.PIPE, - text=True, - ) - # Atomic rename on success - tmp_path.replace(output_path) - except subprocess.CalledProcessError as exc: - stderr = cast(Optional[str], exc.stderr) - stderr_text = (stderr or "").strip() - # Clean up temporary file on failure - tmp_path.unlink(missing_ok=True) - raise RuntimeError(f"f5c eventalign failed: {stderr_text}") from exc - except BaseException: - # Clean up temporary file on any other exception - tmp_path.unlink(missing_ok=True) - raise - - elapsed = time.perf_counter() - t0 - size_kb = output_path.stat().st_size / 1024 - logger.debug("f5c eventalign completed in %.1fs (output: %.1f KB): %s", elapsed, size_kb, output_path) - return output_path diff --git a/baleen/eventalign/_pipeline.py b/baleen/eventalign/_pipeline.py index 92a070e..583d3e5 100644 --- a/baleen/eventalign/_pipeline.py +++ b/baleen/eventalign/_pipeline.py @@ -17,9 +17,9 @@ from numpy.typing import NDArray from tqdm.auto import tqdm -from baleen import _cuda_dtw +from baleen import _dtw from baleen.eventalign import _bam -from baleen.eventalign import _f5c +from baleen.eventalign import _eventalign from baleen.eventalign import _signal logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def _sanitize_contig_filename(name: str) -> str: # using different ``--min-depth``, modified BAMs, etc. _RESUME_PARAMS_FILENAME = ".run_params.json" -_RESUME_FINGERPRINT_SCHEMA = 1 +_RESUME_FINGERPRINT_SCHEMA = 2 def _file_fingerprint(path: Optional[PathLike]) -> Optional[dict]: @@ -99,6 +99,7 @@ def _compute_resume_fingerprint( run_hmm: bool, target_contigs: Optional[list[str]], read_intersection: bool, + pore: str, ) -> dict: """Build a JSON-serializable dict capturing everything that would invalidate a partial run. @@ -119,6 +120,7 @@ def _compute_resume_fingerprint( "depth_mode": str(depth_mode), "padding": int(padding), "min_mapq": int(min_mapq), + "pore": str(pore), "primary_only": bool(primary_only), "subsample": bool(subsample), "subsample_n": int(subsample_n), @@ -158,6 +160,19 @@ def _validate_resume_compatibility( f"Cannot resume: failed to read {fp_path}: {exc}" ) from exc + # A schema bump signals an incompatible fingerprint/output format (e.g. a + # new param or a change in how slices are written); reject outright so old + # partial outputs are never silently mixed with new ones. + prior_schema = prior.get("schema_version") + curr_schema = current.get("schema_version") + if prior_schema != curr_schema: + raise RuntimeError( + "Cannot resume: fingerprint schema mismatch " + f"(prior={prior_schema!r} now={curr_schema!r}). The run was " + f"produced by an incompatible baleen version. " + f"Delete {per_contig_dir} or run without --resume." + ) + mismatches: list[str] = [] for section in ("inputs", "params"): prior_section = prior.get(section, {}) or {} @@ -276,7 +291,7 @@ class ContigResult: @dataclass class PipelineMetadata: - f5c_version: str + eventalign_version: str min_depth: int use_cuda: Optional[bool] padding: int @@ -306,8 +321,8 @@ class _SerializedPayload(TypedDict): metadata: PipelineMetadata -_dtw_distance = cast(_DtwDistanceFn, _cuda_dtw.dtw_distance) -_dtw_pairwise_varlen = _cuda_dtw.dtw_pairwise_varlen +_dtw_distance = cast(_DtwDistanceFn, _dtw.dtw_distance) +_dtw_pairwise_varlen = _dtw.dtw_pairwise_varlen def _compute_pairwise_distances( @@ -324,7 +339,7 @@ def _compute_pairwise_distances( ) t0 = time.perf_counter() - want_cuda = use_cuda is True or (use_cuda is None and _cuda_dtw.CUDA_AVAILABLE) + want_cuda = use_cuda is True or (use_cuda is None and _dtw.CUDA_AVAILABLE) if want_cuda: matrix = _dtw_pairwise_varlen( @@ -342,17 +357,7 @@ def _compute_pairwise_distances( def _compute_pairwise_batch( signals: list[NDArray[np.float32]], ) -> NDArray[np.float64]: - from tslearn.metrics import dtw as _tslearn_dtw - - n = len(signals) - prepped = [s.reshape(-1, 1) for s in signals] - matrix = np.zeros((n, n), dtype=np.float64) - for i in range(n): - for j in range(i + 1, n): - d = float(_tslearn_dtw(prepped[i], prepped[j])) - matrix[i, j] = d - matrix[j, i] = d - return matrix + return _dtw_pairwise_varlen(signals, use_cuda=False) def _compute_pairwise_loop( @@ -453,7 +458,7 @@ def _get_gpu_memory(cuda_devices: Optional[list[int]] = None) -> list[int]: list[int] Memory in bytes per device. Falls back to ``[8 GB]`` on failure. """ - all_mems = _cuda_dtw.get_per_device_memory() + all_mems = _dtw.get_per_device_memory() if not all_mems: return [8 * 1024 ** 3] if cuda_devices is not None: @@ -505,7 +510,7 @@ def _process_contig( padding: int, rna: bool, kmer_model: Optional[str], - extra_f5c_args: Optional[list[str]], + pore: str, min_mapq: int, primary_only: bool, cleanup_temp: bool, @@ -599,9 +604,9 @@ def _process_contig( native_tsv = contig_tmp / "native.eventalign.tsv" ivt_tsv = contig_tmp / "ivt.eventalign.tsv" - logger.info(" Running f5c eventalign (native)...") + logger.info(" Running krill eventalign (native)...") ea_t0 = time.perf_counter() - _ = _f5c.run_eventalign( + _ = _eventalign.run_eventalign( native_contig_bam, ref_fasta, native_fastq, @@ -609,11 +614,12 @@ def _process_contig( native_tsv, rna=rna, kmer_model=kmer_model, - extra_args=extra_f5c_args, min_mapq=min_mapq, + primary_only=primary_only, + pore=pore, ) - logger.info(" Running f5c eventalign (IVT)...") - _ = _f5c.run_eventalign( + logger.info(" Running krill eventalign (IVT)...") + _ = _eventalign.run_eventalign( ivt_contig_bam, ref_fasta, ivt_fastq, @@ -621,8 +627,9 @@ def _process_contig( ivt_tsv, rna=rna, kmer_model=kmer_model, - extra_args=extra_f5c_args, min_mapq=min_mapq, + primary_only=primary_only, + pore=pore, ) logger.info(" Eventalign done (%s)", _fmt_elapsed(time.perf_counter() - ea_t0)) @@ -700,7 +707,7 @@ def _process_contig( current_estimate = 0 for i, sigs in enumerate(all_signal_lists): - pos_estimate = _cuda_dtw.estimate_gpu_memory([sigs]) + pos_estimate = _dtw.estimate_gpu_memory([sigs]) if current_chunk and current_estimate + pos_estimate > chunk_mem_limit: chunks.append(current_chunk) current_chunk = [i] @@ -716,9 +723,9 @@ def _process_contig( for chunk_idx, chunk_indices in enumerate(chunks): chunk_signals = [all_signal_lists[i] for i in chunk_indices] - estimated_bytes = _cuda_dtw.estimate_gpu_memory(chunk_signals) + estimated_bytes = _dtw.estimate_gpu_memory(chunk_signals) - chunk_matrices = _cuda_dtw.dtw_multi_position_pairwise( + chunk_matrices = _dtw.dtw_multi_position_pairwise( chunk_signals, use_cuda=use_cuda, num_streams=num_cuda_streams, @@ -796,7 +803,7 @@ def _process_contig_streaming( padding: int, rna: bool, kmer_model: Optional[str], - extra_f5c_args: Optional[list[str]], + pore: str, min_mapq: int, primary_only: bool, cleanup_temp: bool, @@ -878,7 +885,7 @@ def _process_contig_streaming( padding=padding, rna=rna, kmer_model=kmer_model, - extra_f5c_args=extra_f5c_args, + pore=pore, min_mapq=min_mapq, primary_only=primary_only, cleanup_temp=cleanup_temp, @@ -988,7 +995,7 @@ def run_pipeline( cleanup_temp: bool = True, rna: bool = True, kmer_model: Optional[str] = None, - extra_f5c_args: Optional[list[str]] = None, + pore: str = _eventalign.DEFAULT_PORE, min_mapq: int = 20, primary_only: bool = True, threads: int = 1, @@ -1013,11 +1020,10 @@ def run_pipeline( min_mapq, primary_only, num_cuda_streams) logger.info(" subsample=%s subsample_n=%d gpu_memory_limit=%s", subsample, subsample_n, gpu_memory_limit) - logger.info(" cleanup_temp=%s kmer_model=%s extra_f5c_args=%s", - cleanup_temp, kmer_model, extra_f5c_args) - logger.info(" DTW backend: %s (CUDA=%s, tslearn=%s)", - _cuda_dtw.backend(), _cuda_dtw.CUDA_AVAILABLE, - _cuda_dtw.TSLEARN_AVAILABLE) + logger.info(" cleanup_temp=%s kmer_model=%s pore=%s", + cleanup_temp, kmer_model, pore) + logger.info(" DTW backend: %s (GPU=%s)", + _dtw.backend(), _dtw.CUDA_AVAILABLE) logger.info("=" * 60) # Validate threads parameter @@ -1043,22 +1049,18 @@ def run_pipeline( ivt_blow5 = Path(ivt_blow5) ref_fasta = Path(ref_fasta) - # ---- Step 1: f5c version check ---- - logger.info("[Step 1/6] Checking f5c availability...") - f5c_version = _f5c.check_f5c() - logger.info("[Step 1/6] f5c version %s OK", f5c_version) + # ---- Step 1: krill engine check ---- + logger.info("[Step 1/6] Checking krill availability...") + eventalign_version = _eventalign.check_krill() + logger.info("[Step 1/6] krill version %s OK", eventalign_version) # ---- Step 2: Indexing ---- - logger.info("[Step 2/6] Indexing FASTQ and BLOW5 files...") + logger.info("[Step 2/6] Indexing BLOW5 files...") step_t0 = time.perf_counter() - logger.info(" Indexing native FASTQ against BLOW5...") - _f5c.index_fastq_blow5(native_fastq, native_blow5) - logger.info(" Indexing IVT FASTQ against BLOW5...") - _f5c.index_fastq_blow5(ivt_fastq, ivt_blow5) logger.info(" Indexing native BLOW5...") - _f5c.index_blow5(native_blow5) + _eventalign.index_blow5(native_blow5) logger.info(" Indexing IVT BLOW5...") - _f5c.index_blow5(ivt_blow5) + _eventalign.index_blow5(ivt_blow5) logger.info("[Step 2/6] Indexing complete (%s)", _fmt_elapsed(time.perf_counter() - step_t0)) # ---- Step 3: BAM validation & contig stats ---- @@ -1105,7 +1107,7 @@ def run_pipeline( logger.info(" SKIP: %s — %s", fr.contig, fr.reason.value) metadata = PipelineMetadata( - f5c_version=f5c_version, + eventalign_version=eventalign_version, min_depth=min_depth, use_cuda=use_cuda, padding=padding, @@ -1160,7 +1162,7 @@ def run_pipeline( padding=padding, rna=rna, kmer_model=kmer_model, - extra_f5c_args=extra_f5c_args, + pore=pore, min_mapq=min_mapq, primary_only=primary_only, cleanup_temp=cleanup_temp, @@ -1217,7 +1219,7 @@ def run_pipeline( padding=padding, rna=rna, kmer_model=kmer_model, - extra_f5c_args=extra_f5c_args, + pore=pore, min_mapq=min_mapq, primary_only=primary_only, cleanup_temp=cleanup_temp, @@ -1267,7 +1269,7 @@ def run_pipeline_streaming( cleanup_temp: bool = True, rna: bool = True, kmer_model: Optional[str] = None, - extra_f5c_args: Optional[list[str]] = None, + pore: str = _eventalign.DEFAULT_PORE, min_mapq: int = 20, primary_only: bool = True, threads: int = 1, @@ -1344,7 +1346,7 @@ def run_pipeline_streaming( run_hmm, legacy_scoring, mod_threshold) logger.info(" target_contigs=%s keep_intermediate=%s cleanup_temp=%s", target_contigs, keep_intermediate, cleanup_temp) - logger.info(" kmer_model=%s extra_f5c_args=%s", kmer_model, extra_f5c_args) + logger.info(" kmer_model=%s pore=%s", kmer_model, pore) logger.info("=" * 60) if threads < 1: @@ -1367,22 +1369,20 @@ def run_pipeline_streaming( ivt_blow5 = Path(ivt_blow5) ref_fasta = Path(ref_fasta) - # ---- Step 1: f5c version check ---- - logger.info("[Step 1/5] Checking f5c availability...") - f5c_version = _f5c.check_f5c() - logger.info("[Step 1/5] f5c version %s OK", f5c_version) + # ---- Step 1: krill engine check ---- + logger.info("[Step 1/5] Checking krill availability...") + eventalign_version = _eventalign.check_krill() + logger.info("[Step 1/5] krill version %s OK", eventalign_version) # ---- Step 2: Indexing ---- - logger.info("[Step 2/5] Indexing FASTQ and BLOW5 files...") + logger.info("[Step 2/5] Indexing BLOW5 files...") step_t0 = time.perf_counter() - _f5c.index_fastq_blow5(native_fastq, native_blow5) - _f5c.index_fastq_blow5(ivt_fastq, ivt_blow5) - _f5c.index_blow5(native_blow5) - _f5c.index_blow5(ivt_blow5) + _eventalign.index_blow5(native_blow5) + _eventalign.index_blow5(ivt_blow5) logger.info("[Step 2/5] Indexing complete (%s)", _fmt_elapsed(time.perf_counter() - step_t0)) # ---- Step 2.5: Read-ID intersection (BAM ∩ FASTQ ∩ BLOW5) ---- - # f5c eventalign silently drops BAM reads whose UUIDs are not in + # eventalign silently drops BAM reads whose UUIDs are not in # the BLOW5 signal file; computing the intersection up-front keeps # contig stats, ``min_depth`` filtering, and subsampling all in # sync with the read set that will actually produce signals. @@ -1473,7 +1473,7 @@ def run_pipeline_streaming( len(passed_contigs), _fmt_elapsed(time.perf_counter() - step_t0)) metadata = PipelineMetadata( - f5c_version=f5c_version, + eventalign_version=eventalign_version, min_depth=min_depth, use_cuda=use_cuda, padding=padding, @@ -1533,6 +1533,7 @@ def run_pipeline_streaming( run_hmm=run_hmm, target_contigs=target_contigs, read_intersection=read_intersection, + pore=pore, ) resumed_summaries: list[ContigSummary] = [] if resume: @@ -1630,7 +1631,7 @@ def run_pipeline_streaming( padding=padding, rna=rna, kmer_model=kmer_model, - extra_f5c_args=extra_f5c_args, + pore=pore, min_mapq=min_mapq, primary_only=primary_only, cleanup_temp=cleanup_temp, diff --git a/benchmarks/bench.py b/benchmarks/bench.py index ae9acf8..9885752 100644 --- a/benchmarks/bench.py +++ b/benchmarks/bench.py @@ -105,13 +105,13 @@ def _gpu_info() -> dict: def _env_snapshot() -> dict: - from baleen._cuda_dtw import CUDA_AVAILABLE + from baleen import _dtw env = { "timestamp": datetime.now(timezone.utc).isoformat(), "hostname": platform.node(), "python_version": platform.python_version(), - "cuda_available": CUDA_AVAILABLE, - "dtw_backend": "cuda" if CUDA_AVAILABLE else "tslearn", + "cuda_available": _dtw.CUDA_AVAILABLE, + "dtw_backend": _dtw.backend(), # 'gpu' or 'cpu' (krill) } env.update(_git_info()) env.update(_gpu_info()) @@ -264,17 +264,17 @@ def _summarize_per_contig(values: list[float]) -> dict: # Classification tiers. Order matters (first match wins). # Each tier is (bucket_name, filepath_fragments, funcname_fragments). # Funcname matching catches built-in C extensions where filepath is "~", -# e.g. "{built-in method baleen._cuda_dtw._cuda_dtw.dtw_multi_position_pairwise}". +# e.g. "{built-in method krill._krill.dtw_multi_position_pairwise}". _BUCKET_PATTERNS: list[tuple[str, tuple[str, ...], tuple[str, ...]]] = [ # Baleen project modules — match by filename fragment, also by funcname - # so built-ins like _cuda_dtw's C kernel land in the right bucket. - ("dtw", ("_cuda_dtw",), ("_cuda_dtw",)), + # so built-ins like krill's C DTW kernel land in the right bucket. + ("dtw", ("_dtw.py", "krill"), ("dtw_distance", "dtw_pairwise", "dtw_pairwise_varlen", "dtw_multi_position_pairwise")), ("hmm_training", ("_hmm_training.py",), ()), ("hmm", ("_hierarchical.py",), ()), ("probability", ("_probability.py",), ()), ("aggregation", ("_aggregation.py",), ()), ("signal", ("_signal.py",), ()), - ("f5c", ("_f5c.py",), ()), + ("eventalign", ("_eventalign.py",), ()), ("bam", ("_bam.py", "pysam/"), ()), ("pipeline", ("_pipeline.py",), ()), # IO / subprocess — typically f5c output reads and subprocess plumbing. diff --git a/benchmarks/krill_exp/gate_b.py b/benchmarks/krill_exp/gate_b.py new file mode 100644 index 0000000..1ed6f0a --- /dev/null +++ b/benchmarks/krill_exp/gate_b.py @@ -0,0 +1,117 @@ +"""Gate B: end-to-end AUROC/AUPRC for the current eventalign engine. + +Runs the production pipeline on testdata for the requested stoichiometry +levels (fixed params: padding=1, subsample_n=300, run_hmm=True) and scores +site calls against testdata/known_modifications.tsv. Emits one JSON line per +stoich to a results file so two engines (f5c vs krill) can be compared: + + AUROC_krill >= AUROC_f5c - 0.02 at every level. + +Run inside the GPU image (krill GPU DTW is identical to the legacy _cuda_dtw +GPU kernel, so this isolates the eventalign engine): + + docker run --rm --gpus all -v "$PWD":/work -w /work \ + --entrypoint python3 py-baleen-krill:gpu \ + benchmarks/krill_exp/gate_b.py --label krill --stoich 0.5,1.0 \ + --out benchmarks/krill_exp/gate_b.jsonl +""" +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path + +REPO = Path(__file__).resolve().parents[2] +TESTDATA = REPO / "testdata" + +# Ensure the mounted repo source wins over any installed baleen, and that +# bench.py (one level up) is importable for its scoring helpers. Without the +# REPO insert, `python3 benchmarks/krill_exp/gate_b.py` puts the *script dir* +# on sys.path[0] and `import baleen` silently resolves to site-packages. +sys.path.insert(0, str(REPO)) +sys.path.insert(0, str(REPO / "benchmarks")) +from bench import _compute_accuracy, _load_known_mods # noqa: E402 + + +def run_one(stoich: str) -> tuple[list, float]: + import tempfile + + from baleen.eventalign._pipeline import run_pipeline_streaming + + d = TESTDATA / stoich / "data" + nat = d / "native_1" + ctl = d / "control_1" + with tempfile.TemporaryDirectory(prefix="gateb_") as tmp: + t0 = time.perf_counter() + paths, _meta = run_pipeline_streaming( + native_bam=nat / "native_1.bam", + native_fastq=nat / "fastq" / "pass.fq.gz", + native_blow5=nat / "blow5" / "nanopore.blow5", + ivt_bam=ctl / "control_1.bam", + ivt_fastq=ctl / "fastq" / "pass.fq.gz", + ivt_blow5=ctl / "blow5" / "nanopore.blow5", + ref_fasta=TESTDATA / "ref.fa", + min_depth=10, + padding=1, + subsample_n=300, + run_hmm=True, + use_cuda=None, + threads=1, + output_dir=tmp, + write_bam=False, + cleanup_temp=True, + ) + wall = time.perf_counter() - t0 + sites = _load_sites(Path(paths["site_tsv"])) + return sites, wall + + +def _load_sites(tsv: Path) -> list: + import csv + from types import SimpleNamespace + rows = [] + with tsv.open(newline="") as f: + for row in csv.DictReader(f, delimiter="\t"): + rows.append(SimpleNamespace( + contig=row["contig"], position=int(row["position"]), + mod_ratio=float(row["mod_ratio"]), + mean_p_mod=float(row["mean_p_mod"]), + effect_size=float(row["effect_size"]), + stoichiometry=float(row["stoichiometry"]), + pvalue=float(row["pvalue"]), padj=float(row["padj"]), + )) + return rows + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--label", required=True, help="engine label, e.g. f5c or krill") + ap.add_argument("--stoich", default="0.5,1.0") + ap.add_argument("--out", default="benchmarks/krill_exp/gate_b.jsonl") + args = ap.parse_args() + + known = _load_known_mods() + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + + for stoich in [s.strip() for s in args.stoich.split(",")]: + sites, wall = run_one(stoich) + acc = _compute_accuracy(sites, known) + rec = {"label": args.label, "stoich": stoich, "wall_s": round(wall, 1), + "n_sites": len(sites), "accuracy": acc} + with out.open("a") as f: + f.write(json.dumps(rec) + "\n") + def _g(k): + v = acc.get(k) + return f"{v:.4f}" if v is not None else "n/a" + print(f"[{args.label}] stoich={stoich} sites={len(sites)} " + f"AUROC(mod_ratio)={_g('auroc_mod_ratio')} " + f"AUPRC(mod_ratio)={_g('auprc_mod_ratio')} " + f"AUROC(-log10p)={_g('auroc_nlog10_pvalue')} ({wall:.0f}s)") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/docs/changelog.md b/docs/changelog.md index d8008be..500db0b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,6 +5,31 @@ changes by theme; see the [full commit history](https://github.com/loganylchen/py-baleen/commits/dev) for detail. +## v0.4.0 + +### Changed + +- **krill engine replaces f5c + the in-tree CUDA DTW extension.** Both event + alignment and DTW now run through the + [krill](https://loganylchen.github.io/krill-dist/) package. GPU DTW is + bit-identical to the old in-tree kernel; eventalign is HMM-free forced-dense + and emits an f5c-format TSV, so every downstream stage is unchanged. krill + reads BLOW5 directly via pyslow5 — the old `f5c index` FASTQ-index step is + gone (you still need `slow5tools index` for the `.blow5.idx`). +- **baleen is now pure Python** — no C extension, no `nvcc` build. krill installs + from a project index (cu122 GPU wheel or plain CPU wheel), not PyPI. + +### Added + +- **`--pore`** — select the krill pore model for eventalign (default `rna002`). + +### Removed + +- **`--f5c-threads`** — krill eventalign runs in-process, not as a separate + multithreaded subprocess. +- **`BALEEN_NO_CUDA` / `BALEEN_CUDA_ARCHS`** build-time flags — the GPU/CPU split + is now decided by which krill wheel is installed. + ## Unreleased (dev) ### Features diff --git a/docs/contributing.md b/docs/contributing.md index dfc383b..d2f1ecc 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -9,10 +9,21 @@ commit conventions. git clone https://github.com/loganylchen/py-baleen.git cd py-baleen -# Editable install with test deps (CPU-only build is fastest for iterating) -BALEEN_NO_CUDA=1 pip install -e ".[test]" +# Editable install with test deps +pip install -e ".[test]" ``` +!!! note "krill is a required non-PyPI dependency" + The DTW + eventalign engine ships from a project index, not PyPI. Install + it separately (CPU or GPU `cu122` wheel): + + ```bash + # CPU + pip install krill --no-deps --index-url https://loganylchen.github.io/krill-dist/simple/ + # GPU (CUDA 12.2) + pip install krill --no-deps --index-url https://loganylchen.github.io/krill-dist/cu122/simple/ + ``` + For docs work, add the docs extra: ```bash @@ -31,8 +42,8 @@ pytest tests/test_dtw.py pytest tests/test_dtw.py::test_dtw_distance_basic -v ``` -CI runs the suite on Python 3.9, 3.10, and 3.11 with a CPU-only build -(`BALEEN_NO_CUDA=1`). Make sure `pytest` passes locally before opening a PR. +CI runs the suite on Python 3.10, 3.11, and 3.12. Make sure `pytest` passes +locally before opening a PR. ## Benchmarks @@ -66,19 +77,18 @@ Baleen uses [Conventional Commits](https://www.conventionalcommits.org/): A `!` after the type (e.g. `feat(filter)!:`) marks a breaking change. -## CUDA notes +## DTW engine -The CUDA kernel is **FP32-only** by design (FP16 cripples Pascal consumer GPUs). -If you touch the DTW kernel, verify any "skip work" optimisation actually reduces -thread count or diagonal count — setting cells to infinity in place is pure -overhead. See [Performance & Scaling](guide/performance.md#cuda-kernel-characteristics). +The DTW kernels (GPU + CPU) live in the external **krill** package, not in this +repo; `baleen/_dtw.py` is a thin shim over them. There is no in-tree CUDA code +to build or maintain. ## Project layout ``` baleen/ -├── _cuda_dtw/ # CUDA DTW + CPU fallback -└── eventalign/ # pipeline, BAM/signal/f5c IO, hierarchical model, HMM training +├── _dtw.py # DTW shim over krill +└── eventalign/ # pipeline, BAM/signal/eventalign IO, hierarchical model, HMM training tests/ # pytest suite benchmarks/ # bench.py harness docs/ # this site (MkDocs Material) diff --git a/docs/guide/cli.md b/docs/guide/cli.md index 404c55c..66c0295 100644 --- a/docs/guide/cli.md +++ b/docs/guide/cli.md @@ -4,7 +4,7 @@ Baleen exposes two sub-commands: | Command | Purpose | |---------|---------| -| [`baleen run`](#baleen-run) | Full pipeline: read-ID intersection → f5c eventalign → DTW → HMM → site aggregation. | +| [`baleen run`](#baleen-run) | Full pipeline: read-ID intersection → krill eventalign → DTW → HMM → site aggregation. | | [`baleen aggregate`](#baleen-aggregate) | Re-run HMM and/or site aggregation from a saved `.pkl`, skipping DTW. | ```bash @@ -55,7 +55,7 @@ baleen aggregate --help | Flag | Default | Description | |------|---------|-------------| | `--cuda [DEVICES]` | auto-detect | CUDA device(s): `0`, `0,1`, `0-3`, or `all`. | -| `--no-cuda` | off | Force the CPU (`tslearn`) backend. | +| `--no-cuda` | off | Force the CPU backend. | | `--gpu-memory-limit BYTES` | auto-detect | GPU memory budget for concurrent DTW workers. | ### HMM options @@ -65,13 +65,13 @@ baleen aggregate --help | `--hmm-params` | 3-state unsupervised | Path to a trained HMM parameters JSON. See [HMM Training Modes](hmm-training.md). | | `--no-hmm` | off | Skip HMM smoothing; output V2 scores only. | -### f5c options +### eventalign options | Flag | Default | Description | |------|---------|-------------| -| `--f5c-threads` | auto (`cores / threads`) | CPU threads per `f5c eventalign` call. | -| `--no-rna` | off | Disable RNA mode for f5c. | -| `--kmer-model` | — | Custom k-mer model for f5c. | +| `--pore` | `rna002` | krill pore model for eventalign. | +| `--no-rna` | off | Disable RNA mode for eventalign. | +| `--kmer-model` | — | Reserved; currently unused by the krill engine. | ### Miscellaneous diff --git a/docs/guide/docker.md b/docs/guide/docker.md index 6956e37..bed06f8 100644 --- a/docs/guide/docker.md +++ b/docs/guide/docker.md @@ -5,12 +5,13 @@ images to Docker Hub on every push to `main`/`dev`: | Dockerfile | Image | Base | |------------|-------|------| -| `Dockerfile.cpu` | `/py-baleen-cpu` | CPU-only (`tslearn` DTW backend). | -| `Dockerfile.gpu` | `/py-baleen-gpu` | `nvidia/cuda:12.6.3-runtime-ubuntu22.04`, CUDA DTW backend. | +| `Dockerfile.cpu` | `/py-baleen-cpu` | `python:3.11-slim`, krill CPU wheel. | +| `Dockerfile.gpu` | `/py-baleen-gpu` | `nvidia/cuda:12.2.2-runtime-ubuntu22.04`, krill cu122 GPU wheel. | The `latest` tag is published only from `main`; branch and long-SHA tags are -published for every build. Both images bundle **f5c v1.6** and set -`ENTRYPOINT ["baleen"]` with a `/data` working directory. +published for every build. Both images bundle the **krill** engine and +**slow5tools**, and set `ENTRYPOINT ["baleen"]` with a `/data` working +directory. ## Pull a published image @@ -31,13 +32,14 @@ Dockerfile directly: # CPU docker build -f Dockerfile.cpu -t py-baleen-cpu . -# GPU (needs nvcc/CUDA toolkit during build) +# GPU docker build -f Dockerfile.gpu -t py-baleen-gpu . ``` -The GPU build **fails loudly** if the `_cuda_dtw` C extension did not compile, so -a successful image is guaranteed to have a working CUDA backend rather than a -silent CPU fallback. +Both builds are pure Python (no C-extension compilation): they `pip install` +baleen, then install the appropriate krill wheel (CPU vs cu122) from the +project index. The GPU image's krill is GPU-capable only at run time when a +device is visible — see the verification step below. ## Run the pipeline in a container @@ -71,10 +73,10 @@ docker run --rm --gpus all \ ## Verify the GPU image sees the device ```bash -docker run --rm --gpus all py-baleen-gpu \ - python3 -c "from baleen._cuda_dtw import backend, is_available; \ -print('backend:', backend(), 'cuda:', is_available())" -# Expected: backend: cuda cuda: True +docker run --rm --gpus all --entrypoint python3 py-baleen-gpu \ + -c "from baleen._dtw import backend, is_available; \ +print('backend:', backend(), 'gpu:', is_available())" +# Expected: backend: gpu gpu: True ``` If it prints `backend: cpu`, the container cannot see the GPU — check the diff --git a/docs/guide/inputs.md b/docs/guide/inputs.md index 1f6899b..bb64ee4 100644 --- a/docs/guide/inputs.md +++ b/docs/guide/inputs.md @@ -28,26 +28,23 @@ samtools index native.bam # Reference FASTA samtools faidx ref.fa -# BLOW5 signal index (slow5tools) +# BLOW5 signal index (slow5tools) — produces nanopore.blow5.idx slow5tools index native.blow5 - -# FASTQ read index for f5c (produces .fq.gz.index, .index.fai, .index.gzi, .index.readdb) -f5c index --slow5 native.blow5 native.fq.gz ``` -!!! note "f5c read database" - `f5c index` writes a `.readdb` mapping read IDs to signal records. When - present, Baleen reads read IDs from this cheap index instead of decompressing - the whole FASTQ — see below. +!!! note "No event-alignment index step" + The krill engine reads the BLOW5 signal directly via pyslow5, so it only + needs the `slow5tools index` above. There is no separate FASTQ read-index + step — FASTQ read IDs are parsed straight from the FASTQ headers. ## Read-ID intersection -`f5c eventalign` **silently drops** any BAM read whose UUID is absent from the +eventalign **silently drops** any BAM read whose UUID is absent from the BLOW5 signal file. If your BAM contains reads that have no corresponding raw signal (a common result of separate basecalling/alignment and signal-export steps), those reads survive BAM parsing but vanish during event alignment. -Without correction this biases everything computed before `f5c` runs: +Without correction this biases everything computed before eventalign runs: - **depth statistics** count reads that will never produce signal, - **`--min-depth` filtering** keeps or drops contigs against the wrong count, @@ -63,14 +60,14 @@ reads(BAM) ∩ reads(FASTQ) ∩ reads(BLOW5) Every downstream stage — contig statistics, the `--min-depth` filter, subsampling, and the per-contig BAM split — is gated on this intersection, so -the read set Baleen reasons about is exactly the one `f5c` will align. +the read set Baleen reasons about is exactly the one eventalign will align. ### How read IDs are enumerated | Source | Method | |--------|--------| | BAM | Iterate alignments, collect `query_name`. | -| FASTQ | Prefer the f5c `.index.readdb` (read-id column) when present; otherwise parse FASTQ headers. | +| FASTQ | Parse read IDs from the FASTQ headers. (If a legacy `.index.readdb` file is present it is used instead — krill no longer creates one, so this only affects directories left over from an old f5c run.) | | BLOW5 | `pyslow5.Open(path).get_read_ids()`. | The intersection runs by default. Disable it with `--no-read-intersection` if diff --git a/docs/guide/overview.md b/docs/guide/overview.md index 1671824..d00c60d 100644 --- a/docs/guide/overview.md +++ b/docs/guide/overview.md @@ -15,7 +15,7 @@ flowchart TD end N --> X[Read-ID intersection
BAM ∩ FASTQ ∩ BLOW5] I --> X - X --> EA[f5c eventalign] + X --> EA[krill eventalign] R --> EA EA --> SG[Signal grouping by
genomic position] SG --> DTW[Pairwise DTW distance
matrices per position] @@ -33,18 +33,23 @@ flowchart TD Before any signal work, Baleen computes `reads(BAM) ∩ reads(FASTQ) ∩ reads(BLOW5)` independently for each condition. -`f5c` silently drops BAM reads whose UUIDs are absent from the BLOW5 signal +eventalign silently drops BAM reads whose UUIDs are absent from the BLOW5 signal file; without the intersection, depth statistics, subsampling, and the `--min-depth` filter would all be computed against a read set larger than the one that actually yields signals. Every downstream stage is gated on this intersection. Disable with `--no-read-intersection`. See [Inputs › Read-ID intersection](inputs.md#read-id-intersection). -### 1. Event alignment (`f5c eventalign`) +### 1. Event alignment (`krill eventalign`) -Each read's raw signal is aligned to its reference sequence, producing a table -that maps reference positions to segments of the current signal. Baleen invokes -the external `f5c` binary (RNA mode by default). +Each read's raw signal is aligned to its mapped reference subsequence, producing +a table that maps reference positions to segments of the current signal. Baleen +uses the [krill](https://loganylchen.github.io/krill-dist/) engine (RNA mode by +default). The alignment is HMM-free and forced-dense — every signal sample is +assigned to a reference position with no read-vs-reference skips — and krill +emits an f5c-format TSV, so every downstream stage is unchanged. krill reads the +BLOW5 signal directly via pyslow5, so no separate event-alignment index step is +required. ### 2. Signal grouping @@ -58,8 +63,8 @@ more context. For every retained position, Baleen computes a **pairwise DTW distance matrix** between native and IVT signal segments. DTW (Dynamic Time Warping) is robust to the local time-warping inherent in nanopore translocation. The computation runs -on a [CUDA backend](performance.md#dtw-backend) when available, with an -automatic `tslearn` CPU fallback. +on krill's [GPU backend](performance.md#dtw-backend) when available, with an +automatic CPU fallback. ### 4. Three-stage hierarchical modification calling diff --git a/docs/guide/performance.md b/docs/guide/performance.md index 35d56b2..91cd476 100644 --- a/docs/guide/performance.md +++ b/docs/guide/performance.md @@ -6,17 +6,20 @@ behaviour, and the knobs that control throughput. ## DTW backend -The `_cuda_dtw` module selects a backend **at import time**: +The `baleen._dtw` module is a thin shim over the +[krill](https://loganylchen.github.io/krill-dist/) engine. The backend is +selected by **which krill wheel is installed** plus device presence — not by a +compile-time flag: -- **CUDA (GPU)** if the `_cuda_dtw` C extension compiled with CUDA support. -- **CPU (`tslearn`)** fallback otherwise. +- **GPU** when krill's `cu122` wheel is installed and a CUDA device is present. +- **CPU** otherwise (krill's plain wheel, or no device). Check which one is active: ```python -from baleen._cuda_dtw import backend, is_available -print("DTW backend:", backend()) # "cuda" or "cpu" -print("CUDA available:", is_available()) +from baleen._dtw import backend, is_available +print("DTW backend:", backend()) # "gpu" or "cpu" +print("GPU available:", is_available()) ``` Force a backend per run: @@ -54,18 +57,16 @@ without OOM. | Flag | Effect on performance | |------|-----------------------| -| `--threads N` | Parallel contig workers (`ProcessPoolExecutor`). More workers = more concurrency, but each f5c call then gets fewer CPU threads. | -| `--f5c-threads N` | CPU threads per `f5c eventalign` call. Default auto = `total_cores / threads`. | +| `--threads N` | Parallel contig workers (`ProcessPoolExecutor`). More workers = more concurrency. | | `--subsample-n N` | Caps reads per condition per contig (default 300). Fewer reads → fewer DTW pairs → faster, at some statistical cost. | | `--no-subsample` | Disables the cap — slower, more memory, on deep data. | | `--min-depth` / `--depth-mode` | Skip shallow contigs entirely. | | `--target` | Restrict to specific contigs. | -!!! tip "Balancing `--threads` and `--f5c-threads`" - `f5c` is itself multithreaded. If you set `--threads 16` on a 16-core - machine, the auto rule gives each f5c call only 1 thread. For - f5c-bound workloads, fewer pipeline workers with more f5c threads each can - be faster — profile both. +!!! tip "`--threads` controls contig parallelism" + krill eventalign runs in-process, not as a separate multithreaded + subprocess, so there is no per-call thread budget to balance. `--threads` + simply sets how many contig workers run in parallel. ## Resuming long runs diff --git a/docs/index.md b/docs/index.md index e3cf050..584ac51 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,7 +32,8 @@ per-site modification probabilities. - **CUDA-accelerated DTW** — a batched multi-position GPU kernel processes all positions of a contig in a single launch with concurrent CUDA streams. - Automatic CPU fallback via `tslearn` when no GPU is available. + Automatic CPU fallback when no GPU is available. DTW is provided by the + [krill](https://loganylchen.github.io/krill-dist/) engine. - **Three-stage hierarchical modification calling** - **V1** — robust IVT null estimation with coverage-adaptive three-level shrinkage (position → local window → global). @@ -46,8 +47,9 @@ per-site modification probabilities. - **Streaming architecture** — DTW → HMM → aggregation are fused per contig and flushed to disk, so peak memory stays bounded regardless of transcriptome size. - **Read-ID intersection** — every stage is gated on - `reads(BAM) ∩ reads(FASTQ) ∩ reads(BLOW5)` per condition, so `f5c` silently - dropping reads absent from the signal file never biases subsampling. + `reads(BAM) ∩ reads(FASTQ) ∩ reads(BLOW5)` per condition, so eventalign + silently dropping BAM reads whose UUIDs are absent from the BLOW5 signal file + never biases subsampling. - **Resumable** — interrupted runs can be continued with `--resume`, reusing per-contig slices already on disk. @@ -55,7 +57,7 @@ per-site modification probabilities. ```mermaid flowchart LR - NB[Native BAM/FASTQ/BLOW5] --> F5C[f5c eventalign] + NB[Native BAM/FASTQ/BLOW5] --> F5C[krill eventalign] IB[IVT BAM/FASTQ/BLOW5] --> F5C REF[Reference FASTA] --> F5C F5C --> SG[Signal grouping
by position] diff --git a/docs/installation.md b/docs/installation.md index e22ed7d..6ab5709 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -4,14 +4,16 @@ | Requirement | Notes | |-------------|-------| -| Python ≥ 3.9 | 3.9 – 3.11 are tested. | -| [f5c](https://github.com/hasindu2008/f5c) ≥ 1.4 | Must be on `PATH`. Used for nanopore event alignment. | -| CUDA toolkit | **Optional.** Enables GPU-accelerated DTW. Without it, Baleen falls back to a CPU (`tslearn`) backend automatically. | +| Python ≥ 3.10 | 3.10 – 3.12 are tested. (krill ships cp310+ wheels.) | +| [krill](https://loganylchen.github.io/krill-dist/) | DTW + eventalign engine. **Required.** Not on PyPI — install from the project index (see below). | +| [slow5tools](https://github.com/hasindu2008/slow5tools) | Must be on `PATH`. Used to index BLOW5 signal files (`slow5tools index`). | +| NVIDIA GPU + driver | **Optional.** The krill cu122 wheel runs DTW on the GPU. Without a GPU, install the plain krill wheel for a CPU backend. | -!!! note "f5c is an external tool" - Baleen shells out to the `f5c` binary; it is not installed by `pip`. - Install it separately and make sure `f5c --version` works from your shell - before running the pipeline. +!!! note "krill is not on PyPI" + Baleen's DTW and eventalign run through the `krill` package, which is + published on a project index rather than PyPI. A plain `pip install baleen` + will **not** pull it — install krill explicitly (below) or use a Docker + image, which bundles krill and slow5tools for you. ## Install from source @@ -19,37 +21,26 @@ git clone https://github.com/loganylchen/py-baleen.git cd py-baleen -# With CUDA (auto-detected if `nvcc` is available) +# baleen is pure Python — no C extension to build. pip install . ``` -### CPU-only build - -Skip CUDA compilation entirely: - -```bash -BALEEN_NO_CUDA=1 pip install . -``` - -### Targeting specific GPU architectures - -By default the build compiles for a broad set of compute capabilities. To -restrict (faster builds) or target a specific GPU, set `BALEEN_CUDA_ARCHS` to a -comma-separated list of compute capabilities **without dots**: +Then install the krill engine from the project index: ```bash -# Ampere (8.6) + Hopper (9.0) -BALEEN_CUDA_ARCHS=86,90 pip install . +# GPU (CUDA 12.2 wheel) — recommended when a GPU is available +pip install krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/cu122/simple/ -# Auto-detect the GPU currently installed -BALEEN_CUDA_ARCHS=native pip install . +# CPU-only +pip install krill --no-deps \ + --index-url https://loganylchen.github.io/krill-dist/simple/ ``` -| Environment variable | Effect | -|----------------------|--------| -| `BALEEN_NO_CUDA=1` | Skip CUDA compilation; CPU backend only. | -| `BALEEN_CUDA_ARCHS=86,90` | Compile only for the listed compute capabilities. | -| `BALEEN_CUDA_ARCHS=native` | Auto-detect and target the installed GPU. | +!!! warning "krill install rules" + Install krill's runtime deps (`numpy scipy pyslow5 pyfastx`) from PyPI + first, then install krill itself with `--no-deps` from the project index. + Do **not** use `krill[...]` extras or `--extra-index-url`. ## Install with extras @@ -63,7 +54,8 @@ pip install ".[docs]" ## Docker -Pre-built images are published on Docker Hub: +Pre-built images bundle baleen + krill + slow5tools and are published on +Docker Hub: ```bash # CPU @@ -87,7 +79,7 @@ python -c "import baleen; print('baleen', baleen.__name__, 'import OK')" To confirm which DTW backend was selected: ```python -from baleen._cuda_dtw import backend, is_available -print("DTW backend:", backend()) # "cuda" or "cpu" -print("CUDA available:", is_available()) +from baleen._dtw import backend, is_available +print("DTW backend:", backend()) # "gpu" or "cpu" +print("GPU available:", is_available()) ``` diff --git a/pyproject.toml b/pyproject.toml index 3fe0fde..31e2fa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,12 @@ [build-system] -requires = ["setuptools>=64", "wheel", "numpy"] +requires = ["setuptools>=64", "wheel"] build-backend = "setuptools.build_meta" [project] name = "baleen" -version = "0.3.0" +version = "0.4.0" description = "Hierarchical Bayesian framework for RNA modification detection from nanopore direct RNA sequencing" -requires-python = ">=3.9" +requires-python = ">=3.10" license = {file = "LICENSE"} authors = [ {name = "Logan Chen"}, @@ -17,19 +17,27 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Bio-Informatics", ] +# NOTE: the DTW + eventalign engine 'krill' is a required runtime dependency +# but is NOT on PyPI — install it from the project index, e.g. +# pip install krill --no-deps \ +# --index-url https://loganylchen.github.io/krill-dist/cu122/simple/ (GPU) +# pip install krill --no-deps \ +# --index-url https://loganylchen.github.io/krill-dist/simple/ (CPU) +# or use a prebuilt baleen Docker image. It is intentionally omitted below so a +# plain `pip install .` does not fail looking for it on PyPI. dependencies = [ "numpy>=1.24", "scipy>=1.9", - "tslearn>=0.6", "pysam>=0.21", "tqdm>=4.60", "pandas>=1.5", "pyslow5", + "pyfastx", ] [project.urls] diff --git a/scripts/profile_pipeline.py b/scripts/profile_pipeline.py index be2ee0b..c0daab6 100644 --- a/scripts/profile_pipeline.py +++ b/scripts/profile_pipeline.py @@ -79,8 +79,8 @@ class ContigProfile: # Per-stage wall-clock seconds bam_split_native: float = 0.0 bam_split_ivt: float = 0.0 - f5c_native: float = 0.0 - f5c_ivt: float = 0.0 + eventalign_native: float = 0.0 + eventalign_ivt: float = 0.0 signal_parse_native: float = 0.0 signal_parse_ivt: float = 0.0 signal_extract: float = 0.0 @@ -116,7 +116,7 @@ class ProfileReport: # Aggregated total_wall_clock: float = 0.0 total_bam_split: float = 0.0 - total_f5c: float = 0.0 + total_eventalign: float = 0.0 total_signal_parse: float = 0.0 total_signal_extract: float = 0.0 total_dtw: float = 0.0 @@ -154,8 +154,8 @@ def profile_contig( num_cuda_streams: int, run_hmm: bool, ) -> ContigProfile: - from baleen.eventalign import _bam, _f5c, _signal - from baleen import _cuda_dtw + from baleen.eventalign import _bam, _eventalign, _signal + from baleen import _dtw prof = ContigProfile( contig=contig, @@ -191,28 +191,28 @@ def profile_contig( _fmt(prof.bam_split_native), _fmt(prof.bam_split_ivt), ) - # --- f5c eventalign --- + # --- krill eventalign --- native_tsv = contig_tmp / "native.eventalign.tsv" ivt_tsv = contig_tmp / "ivt.eventalign.tsv" - with StageTimer("f5c_native") as t: - _f5c.run_eventalign( + with StageTimer("eventalign_native") as t: + _eventalign.run_eventalign( native_contig_bam, ref_fasta, native_fastq, native_blow5, native_tsv, rna=True, ) - prof.f5c_native = t.elapsed + prof.eventalign_native = t.elapsed - with StageTimer("f5c_ivt") as t: - _f5c.run_eventalign( + with StageTimer("eventalign_ivt") as t: + _eventalign.run_eventalign( ivt_contig_bam, ref_fasta, ivt_fastq, ivt_blow5, ivt_tsv, rna=True, ) - prof.f5c_ivt = t.elapsed + prof.eventalign_ivt = t.elapsed logger.info( - " [%d/%d] %s f5c: native=%s ivt=%s", + " [%d/%d] %s eventalign: native=%s ivt=%s", contig_idx, total_contigs, contig, - _fmt(prof.f5c_native), _fmt(prof.f5c_ivt), + _fmt(prof.eventalign_native), _fmt(prof.eventalign_ivt), ) # --- Signal parsing --- @@ -285,7 +285,7 @@ def profile_contig( all_signal_lists = [d[4] for d in position_data] # Use the same chunking as the real pipeline - total_gpu = _cuda_dtw.estimate_gpu_memory(all_signal_lists) if use_cuda else 0 + total_gpu = _dtw.estimate_gpu_memory(all_signal_lists) if use_cuda else 0 gpu_mem = 80 * 1024**3 # assume 80GB for chunk sizing chunk_mem_limit = int(gpu_mem * 0.8) @@ -293,7 +293,7 @@ def profile_contig( current_chunk: list[int] = [] current_estimate = 0 for i, sigs in enumerate(all_signal_lists): - pos_estimate = _cuda_dtw.estimate_gpu_memory([sigs]) + pos_estimate = _dtw.estimate_gpu_memory([sigs]) if current_chunk and current_estimate + pos_estimate > chunk_mem_limit: chunks.append(current_chunk) current_chunk = [i] @@ -311,10 +311,8 @@ def profile_contig( for chunk_idx, chunk_indices in enumerate(chunks): chunk_signals = [all_signal_lists[i] for i in chunk_indices] chunk_t0 = time.perf_counter() - chunk_matrices = _cuda_dtw.dtw_multi_position_pairwise( + chunk_matrices = _dtw.dtw_multi_position_pairwise( chunk_signals, - use_open_start=False, - use_open_end=False, use_cuda=use_cuda, num_streams=num_cuda_streams, ) @@ -409,16 +407,16 @@ def main(): help="Output JSON path (default: profile_report.json)") args = parser.parse_args() - from baleen.eventalign import _bam, _f5c - from baleen import _cuda_dtw + from baleen.eventalign import _bam, _eventalign + from baleen import _dtw import tempfile from datetime import datetime report = ProfileReport( timestamp=datetime.now().isoformat(), - cuda_available=_cuda_dtw.CUDA_AVAILABLE, + cuda_available=_dtw.CUDA_AVAILABLE, cuda_used=args.use_cuda, - dtw_backend=_cuda_dtw.backend(), + dtw_backend=_dtw.backend(), threads=args.threads, subsample_n=args.subsample_n, padding=args.padding, @@ -434,10 +432,8 @@ def main(): # Index (idempotent) logger.info("Indexing...") - _f5c.index_fastq_blow5(native_fastq, native_blow5) - _f5c.index_fastq_blow5(ivt_fastq, ivt_blow5) - _f5c.index_blow5(native_blow5) - _f5c.index_blow5(ivt_blow5) + _eventalign.index_blow5(native_blow5) + _eventalign.index_blow5(ivt_blow5) # BAM stats logger.info("Computing BAM stats...") @@ -507,7 +503,7 @@ def main(): if c.error: continue report.total_bam_split += c.bam_split_native + c.bam_split_ivt - report.total_f5c += c.f5c_native + c.f5c_ivt + report.total_eventalign += c.eventalign_native + c.eventalign_ivt report.total_signal_parse += c.signal_parse_native + c.signal_parse_ivt report.total_signal_extract += c.signal_extract report.total_dtw += c.dtw_total @@ -525,7 +521,7 @@ def main(): stages = [ ("BAM split", report.total_bam_split), - ("f5c eventalign", report.total_f5c), + ("eventalign", report.total_eventalign), ("Signal parsing", report.total_signal_parse), ("Signal extract", report.total_signal_extract), ("DTW computation", report.total_dtw), @@ -553,7 +549,7 @@ def main(): print(f" ERROR: {c.error}") else: print(f" bam_split={_fmt(c.bam_split_native + c.bam_split_ivt)} " - f"f5c={_fmt(c.f5c_native + c.f5c_ivt)} " + f"eventalign={_fmt(c.eventalign_native + c.eventalign_ivt)} " f"signals={_fmt(c.signal_parse_native + c.signal_parse_ivt + c.signal_extract)} " f"dtw={_fmt(c.dtw_total)} " f"hier={_fmt(c.hierarchical_total)} " @@ -570,8 +566,8 @@ def _to_dict(prof: ContigProfile) -> dict: "n_positions_computed": prof.n_positions_computed, "bam_split_native_s": round(prof.bam_split_native, 3), "bam_split_ivt_s": round(prof.bam_split_ivt, 3), - "f5c_native_s": round(prof.f5c_native, 3), - "f5c_ivt_s": round(prof.f5c_ivt, 3), + "eventalign_native_s": round(prof.eventalign_native, 3), + "eventalign_ivt_s": round(prof.eventalign_ivt, 3), "signal_parse_native_s": round(prof.signal_parse_native, 3), "signal_parse_ivt_s": round(prof.signal_parse_ivt, 3), "signal_extract_s": round(prof.signal_extract, 3), @@ -609,7 +605,7 @@ def _to_dict(prof: ContigProfile) -> dict: "total_wall_clock_s": round(report.total_wall_clock, 3), "stage_totals_s": { "bam_split": round(report.total_bam_split, 3), - "f5c": round(report.total_f5c, 3), + "eventalign": round(report.total_eventalign, 3), "signal_parse": round(report.total_signal_parse, 3), "signal_extract": round(report.total_signal_extract, 3), "dtw": round(report.total_dtw, 3), diff --git a/setup.py b/setup.py deleted file mode 100644 index d86df00..0000000 --- a/setup.py +++ /dev/null @@ -1,328 +0,0 @@ -""" -Baleen - CUDA-accelerated Dynamic Time Warping - -Build modes: - - With CUDA (GPU): Automatically detected if nvcc is on PATH. - Set CUDA_HOME or CUDA_PATH to override CUDA toolkit location. - Force skip with: BALEEN_NO_CUDA=1 pip install . - - Without CUDA: Pure-Python install. The _cuda_dtw module will still import - but dtw_distance()/dtw_pairwise() raise RuntimeError. -""" - -import os -import platform -import shutil -import subprocess -import sys -import tempfile - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - - -# --------------------------------------------------------------------------- -# CUDA detection helpers -# --------------------------------------------------------------------------- - -def _find_cuda_home(): - """Find CUDA toolkit installation directory. - - Search order: - 1. CUDA_HOME environment variable - 2. CUDA_PATH environment variable - 3. nvcc location on PATH (derive from its parent) - 4. Common installation paths - """ - for env_var in ("CUDA_HOME", "CUDA_PATH"): - val = os.environ.get(env_var) - if val and os.path.isdir(val): - return val - - nvcc = shutil.which("nvcc") - if nvcc: - cuda_home = os.path.dirname(os.path.dirname(os.path.realpath(nvcc))) - if os.path.isdir(os.path.join(cuda_home, "include")): - return cuda_home - - common_paths = [ - "/usr/local/cuda", - "/usr/local/cuda-12", - "/usr/local/cuda-11", - "/opt/cuda", - ] - for path in common_paths: - if os.path.isdir(path): - return path - - return None - - -def _nvcc_is_available(): - """Quick check: is nvcc on PATH or in CUDA_HOME?""" - if shutil.which("nvcc"): - return True - cuda_home = _find_cuda_home() - if cuda_home: - nvcc_path = os.path.join(cuda_home, "bin", "nvcc") - return os.path.isfile(nvcc_path) - return False - - -def _get_nvcc(): - """Return the full path to nvcc, or None.""" - nvcc = shutil.which("nvcc") - if nvcc: - return nvcc - cuda_home = _find_cuda_home() - if cuda_home: - nvcc_path = os.path.join(cuda_home, "bin", "nvcc") - if os.path.isfile(nvcc_path): - return nvcc_path - return None - - -# Default arch list covers Pascal (1050 Ti / 1080 Ti) through Hopper (H100). -# Override with BALEEN_CUDA_ARCHS=61,75,86 (comma-separated compute capabilities -# without the dot: e.g. "61" = sm_61 = compute capability 6.1). The highest arch -# also gets PTX embedded for forward-compat with future GPUs. -_DEFAULT_CUDA_ARCHS = ["75", "80", "86", "89", "90"] - - -def _get_cuda_archs(): - """Return list of compute capability codes to target (e.g. ['61', '86']).""" - env = os.environ.get("BALEEN_CUDA_ARCHS", "").strip() - if not env: - return list(_DEFAULT_CUDA_ARCHS) - if env.lower() == "native": - # Auto-detect via nvidia-smi; fall back to defaults on failure - try: - out = subprocess.check_output( - ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"], - text=True, stderr=subprocess.DEVNULL, timeout=5, - ) - caps = sorted({line.strip().replace(".", "") - for line in out.splitlines() if line.strip()}) - if caps: - return caps - except (FileNotFoundError, subprocess.CalledProcessError, - subprocess.TimeoutExpired): - pass - return list(_DEFAULT_CUDA_ARCHS) - return [a.strip() for a in env.replace(";", ",").split(",") if a.strip()] - - -def _gencode_flags(): - """Build -gencode flag list for nvcc. Highest arch also embeds PTX.""" - archs = _get_cuda_archs() - if not archs: - return [] - flags = [] - for arch in archs: - flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) - highest = max(archs, key=lambda a: int(a)) - flags.extend(["-gencode", f"arch=compute_{highest},code=compute_{highest}"]) - return flags - - -# --------------------------------------------------------------------------- -# Custom build_ext that compiles CUDA extensions with nvcc -# --------------------------------------------------------------------------- - -class CUDABuildExt(build_ext): - """Custom build_ext that: - 1. Detects CUDA availability at build time (not setup time) - 2. Copies .cpp CUDA source to .cu so nvcc treats it as CUDA - 3. Compiles .cu sources with nvcc - 4. Falls back gracefully if CUDA is missing or compilation fails - """ - - def build_extensions(self): - # ── Gate: should we even try? ── - cuda_exts = [e for e in self.extensions if getattr(e, "_is_cuda", False)] - other_exts = [e for e in self.extensions if not getattr(e, "_is_cuda", False)] - - if os.environ.get("BALEEN_NO_CUDA", "").strip() in ("1", "true", "yes"): - print("\nℹ️ CUDA build disabled by BALEEN_NO_CUDA env var.") - self.extensions = other_exts - if self.extensions: - super().build_extensions() - return - - if not _nvcc_is_available(): - print("\nℹ️ nvcc not found. Skipping CUDA extension (CPU-only install).") - self.extensions = other_exts - if self.extensions: - super().build_extensions() - return - - # ── Prepare .cpp → .cu copies (native .cu files kept as-is) ── - for ext in cuda_exts: - new_sources = [] - for src in ext.sources: - if src.endswith(".cu"): - new_sources.append(src) # already a .cu file - elif src.endswith(".cpp") and "_cuda_dtw" in src: - os.makedirs(self.build_temp, exist_ok=True) - cu_path = os.path.join(self.build_temp, os.path.basename( - src.rsplit(".cpp", 1)[0] + ".cu" - )) - shutil.copy2(src, cu_path) - new_sources.append(cu_path) - else: - new_sources.append(src) - ext.sources = new_sources - - # ── Build, with graceful fallback ── - try: - super().build_extensions() - except Exception as e: - print(f"\n⚠️ CUDA extension build failed: {e}") - print(" Falling back to CPU-only installation.\n") - self.extensions = other_exts - if self.extensions: - super().build_extensions() - - def build_extension(self, ext): - if not getattr(ext, "_is_cuda", False): - super().build_extension(ext) - return - - nvcc = _get_nvcc() - if not nvcc: - raise RuntimeError("nvcc not found") - - cuda_home = _find_cuda_home() - - # ── Include / library directories ── - include_dirs = list(ext.include_dirs or []) - library_dirs = list(ext.library_dirs or []) - - if cuda_home: - include_dirs.append(os.path.join(cuda_home, "include")) - for lib_subdir in ("lib64", "lib"): - ldir = os.path.join(cuda_home, lib_subdir) - if os.path.isdir(ldir): - library_dirs.append(ldir) - break - - # Local CUDA headers (dtw.hpp, cuda_utils.hpp, etc.) - cuda_src_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "baleen", "_cuda_dtw", - ) - include_dirs.append(cuda_src_dir) - - # Python + NumPy includes - from sysconfig import get_paths as _get_paths - include_dirs.append(_get_paths()["include"]) - try: - import numpy - include_dirs.append(numpy.get_include()) - except ImportError: - raise RuntimeError("numpy is required to build the CUDA extension") - - # ── Output path ── - ext_path = self.get_ext_fullpath(ext.name) - os.makedirs(os.path.dirname(ext_path), exist_ok=True) - - # ── Compile each source ── - gencode = _gencode_flags() - if gencode: - archs_disp = ",".join(_get_cuda_archs()) - print(f" [baleen] nvcc targets: sm_{{{archs_disp}}} (+PTX)") - objects = [] - for src in ext.sources: - obj = src + ".o" - compile_cmd = [ - nvcc, - "-std=c++17", - "-O3", - "-Xcompiler", "-fPIC", - *gencode, - "-c", src, - "-o", obj, - ] - for inc in include_dirs: - compile_cmd.extend(["-I", inc]) - - print(f" [baleen] nvcc compile: {os.path.basename(src)}") - subprocess.check_call(compile_cmd) - objects.append(obj) - - # ── Link ── - link_cmd = [nvcc, "--shared", "-o", ext_path] + objects - - for ldir in library_dirs: - link_cmd.extend(["-L", ldir]) - for lib in (ext.libraries or []): - link_cmd.extend(["-l", lib]) - - # Platform-specific linker flags - if platform.system() == "Darwin": - link_cmd.extend(["-Xcompiler", "-undefined,dynamic_lookup"]) - - print(f" [baleen] nvcc link: {os.path.basename(ext_path)}") - subprocess.check_call(link_cmd) - - -# --------------------------------------------------------------------------- -# Extension definition — always declared, build_ext decides whether to build -# --------------------------------------------------------------------------- - -def _make_cuda_extension(): - """Create the CUDA extension module definition. - - By default uses cuDTW++ warp-shuffle kernels (cudtw_wrapper.cu). - Set BALEEN_USE_CUDTW=0 to fall back to the legacy OpenDBA wavefront kernel. - """ - cuda_src_dir = os.path.join("baleen", "_cuda_dtw") - use_cudtw = os.environ.get("BALEEN_USE_CUDTW", "1").strip() - - if use_cudtw in ("0", "false", "no"): - # Legacy kernel - sources = [ - os.path.join(cuda_src_dir, "dtw_api.cpp"), - os.path.join(cuda_src_dir, "multithreading.cpp"), - ] - else: - # cuDTW++ warp-shuffle kernel (native .cu — no copy needed) - sources = [ - os.path.join(cuda_src_dir, "cudtw_wrapper.cu"), - ] - - ext = Extension( - name="baleen._cuda_dtw._cuda_dtw", - sources=sources, - libraries=["cudart"], - language="c++", - ) - ext._is_cuda = True # type: ignore[attr-defined] - return ext - - -# --------------------------------------------------------------------------- -# Always include the CUDA extension; CUDABuildExt skips it if nvcc is absent -# --------------------------------------------------------------------------- - -setup( - name="baleen", - version="0.3.0", - description="CUDA-accelerated DTW and nanopore signal analysis pipeline", - author="Logan", - python_requires=">=3.9", - packages=find_packages(), - install_requires=[ - "numpy", - "tslearn", - "pysam", - "scipy", - "tqdm", - "pandas", - ], - entry_points={ - "console_scripts": ["baleen=baleen.cli:main"], - }, - ext_modules=[_make_cuda_extension()], - cmdclass={"build_ext": CUDABuildExt}, - zip_safe=False, -) diff --git a/tests/test_cli.py b/tests/test_cli.py index 7fe62e4..d2a3153 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -46,7 +46,7 @@ def sample_args_run(): no_read_bam=True, target=None, keep_intermediate=False, - f5c_threads=None, + pore="rna002", gpu_memory_limit=None, no_subsample=False, subsample_n=300, diff --git a/tests/test_cudtw_migration.py b/tests/test_cudtw_migration.py deleted file mode 100644 index 44ff110..0000000 --- a/tests/test_cudtw_migration.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -Migration verification tests for cuDTW++ warp-shuffle kernel. - -These tests validate that the cuDTW++ backend produces correct DTW distances -by comparing against the CPU (tslearn) reference implementation. - -Run with: pytest tests/test_cudtw_migration.py -v -""" - -import numpy as np -import pytest - -from baleen._cuda_dtw import ( - CUDA_AVAILABLE, - _CUDTW_ACTIVE, - backend, - dtw_distance, - dtw_pairwise, - dtw_pairwise_varlen, - dtw_multi_position_pairwise, -) - -pytestmark = pytest.mark.skipif( - not CUDA_AVAILABLE, reason="CUDA not available" -) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _random_signals(n, length, seed=42): - rng = np.random.RandomState(seed) - return [rng.randn(length).astype(np.float32) for _ in range(n)] - - -def _pairwise_cpu(signals): - """Reference pairwise DTW via CPU backend.""" - n = len(signals) - mat = np.zeros((n, n), dtype=np.float64) - for i in range(n): - for j in range(i + 1, n): - d = dtw_distance(signals[i], signals[j], use_cuda=False) - mat[i, j] = d - mat[j, i] = d - return mat - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - -class TestBackendDetection: - def test_backend_is_cuda(self): - assert backend() == "cuda" - - def test_cudtw_active(self): - """If built with default settings, cuDTW++ should be active.""" - # This test documents which kernel is active; it's informational. - print(f"cuDTW++ active: {_CUDTW_ACTIVE}") - - -class TestSinglePairDTW: - """Single-pair DTW distance correctness.""" - - @pytest.mark.parametrize("length", [50, 127, 200, 255, 500, 511, 1023]) - def test_equal_length(self, length): - rng = np.random.RandomState(123) - s1 = rng.randn(length).astype(np.float32) - s2 = rng.randn(length).astype(np.float32) - - d_gpu = dtw_distance(s1, s2, use_cuda=True) - d_cpu = dtw_distance(s1, s2, use_cuda=False) - - # cuDTW++ pads to bucket boundaries, so distances differ slightly - # for non-bucket lengths. Use generous tolerance. - assert d_gpu >= 0 - assert d_cpu >= 0 - # Both should be in the same ballpark - if d_cpu > 0: - rel_err = abs(d_gpu - d_cpu) / d_cpu - print(f"len={length}: gpu={d_gpu:.4f} cpu={d_cpu:.4f} rel_err={rel_err:.4f}") - - @pytest.mark.parametrize("length", [127, 255, 511, 1023]) - def test_bucket_length_close_match(self, length): - """At exact bucket lengths, GPU and CPU should agree closely.""" - rng = np.random.RandomState(456) - s1 = rng.randn(length).astype(np.float32) - s2 = rng.randn(length).astype(np.float32) - - d_gpu = dtw_distance(s1, s2, use_cuda=True) - d_cpu = dtw_distance(s1, s2, use_cuda=False) - - # At bucket boundaries, zero-padding is minimal → close match - if d_cpu > 0: - rel_err = abs(d_gpu - d_cpu) / d_cpu - assert rel_err < 0.15, ( - f"Bucket {length}: gpu={d_gpu:.4f} cpu={d_cpu:.4f} rel_err={rel_err:.4f}" - ) - - def test_identical_sequences(self): - s = np.ones(127, dtype=np.float32) - d = dtw_distance(s, s, use_cuda=True) - assert d == pytest.approx(0.0, abs=1e-5) - - -class TestPairwiseDTW: - """Pairwise distance matrix correctness.""" - - def test_pairwise_symmetric(self): - sigs = _random_signals(5, 127) - seqs = np.array(sigs) - mat = dtw_pairwise(seqs, use_cuda=True) - np.testing.assert_allclose(mat, mat.T, atol=1e-5) - - def test_pairwise_diagonal_zero(self): - sigs = _random_signals(5, 127) - seqs = np.array(sigs) - mat = dtw_pairwise(seqs, use_cuda=True) - np.testing.assert_allclose(np.diag(mat), 0.0, atol=1e-5) - - def test_pairwise_nonnegative(self): - sigs = _random_signals(8, 255) - seqs = np.array(sigs) - mat = dtw_pairwise(seqs, use_cuda=True) - assert np.all(mat >= -1e-6) - - -class TestPairwiseVarlen: - """Variable-length pairwise DTW.""" - - def test_varlen_runs(self): - rng = np.random.RandomState(789) - sigs = [rng.randn(l).astype(np.float32) for l in [80, 100, 120, 90]] - mat = dtw_pairwise_varlen(sigs, use_cuda=True) - assert mat.shape == (4, 4) - np.testing.assert_allclose(mat, mat.T, atol=1e-5) - np.testing.assert_allclose(np.diag(mat), 0.0, atol=1e-5) - - -class TestMultiPosition: - """Multi-position batched pairwise DTW.""" - - def test_multi_position_runs(self): - rng = np.random.RandomState(321) - pos1 = [rng.randn(100).astype(np.float32) for _ in range(3)] - pos2 = [rng.randn(80).astype(np.float32) for _ in range(4)] - results = dtw_multi_position_pairwise([pos1, pos2], use_cuda=True) - assert len(results) == 2 - assert results[0].shape == (3, 3) - assert results[1].shape == (4, 4) - - def test_multi_position_symmetric(self): - rng = np.random.RandomState(654) - pos = [rng.randn(127).astype(np.float32) for _ in range(5)] - results = dtw_multi_position_pairwise([pos], use_cuda=True) - mat = results[0] - np.testing.assert_allclose(mat, mat.T, atol=1e-5) - - diff --git a/tests/test_dtw.py b/tests/test_dtw.py index 4ff6051..d2f9f1c 100644 --- a/tests/test_dtw.py +++ b/tests/test_dtw.py @@ -1,529 +1,151 @@ -""" -Tests for DTW computation with CPU and GPU backends. +"""Tests for the DTW shim (baleen._dtw), which delegates to krill. -These tests verify: -1. dtw_distance() and dtw_pairwise() work on CPU via tslearn/numpy -2. Backend reporting is accurate -3. Numerical correctness against known DTW properties -4. Input validation works on both paths -5. use_cuda parameter controls backend selection +Covers the contract the pipeline relies on: distance/pairwise correctness, +symmetry, zero diagonal, batch consistency, input validation, backend +reporting, and (when a GPU is present) GPU/CPU agreement on signals within the +no-resample length cap. """ -import sys -import warnings - import numpy as np import pytest +from baleen import _dtw -class TestDTWDistanceCPU: - """dtw_distance() must work without CUDA by falling back to CPU.""" +class TestDTWDistance: def test_identical_sequences_distance_zero(self): - from baleen._cuda_dtw import dtw_distance - seq = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) - dist = dtw_distance(seq, seq, use_cuda=False) - assert dist == pytest.approx(0.0, abs=1e-6) + assert _dtw.dtw_distance(seq, seq, use_cuda=False) == pytest.approx(0.0, abs=1e-6) def test_different_sequences_positive_distance(self): - from baleen._cuda_dtw import dtw_distance - seq1 = np.array([1.0, 2.0, 3.0], dtype=np.float32) seq2 = np.array([4.0, 5.0, 6.0], dtype=np.float32) - dist = dtw_distance(seq1, seq2, use_cuda=False) - assert dist > 0.0 + assert _dtw.dtw_distance(seq1, seq2, use_cuda=False) > 0.0 def test_symmetry(self): - from baleen._cuda_dtw import dtw_distance - rng = np.random.default_rng(42) seq1 = rng.standard_normal(50).astype(np.float32) seq2 = rng.standard_normal(50).astype(np.float32) - d1 = dtw_distance(seq1, seq2, use_cuda=False) - d2 = dtw_distance(seq2, seq1, use_cuda=False) + d1 = _dtw.dtw_distance(seq1, seq2, use_cuda=False) + d2 = _dtw.dtw_distance(seq2, seq1, use_cuda=False) assert d1 == pytest.approx(d2, abs=1e-5) - def test_different_lengths(self): - from baleen._cuda_dtw import dtw_distance - - seq1 = np.array([1.0, 2.0, 3.0], dtype=np.float32) - seq2 = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) - dist = dtw_distance(seq1, seq2, use_cuda=False) - assert isinstance(dist, float) - assert dist >= 0.0 - - def test_list_input(self): - from baleen._cuda_dtw import dtw_distance - - dist = dtw_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], use_cuda=False) - assert dist == pytest.approx(0.0, abs=1e-6) - def test_return_type_is_float(self): - from baleen._cuda_dtw import dtw_distance - seq1 = np.array([1.0, 2.0], dtype=np.float32) seq2 = np.array([3.0, 4.0], dtype=np.float32) - dist = dtw_distance(seq1, seq2, use_cuda=False) - assert isinstance(dist, float) - - def test_triangle_inequality(self): - from baleen._cuda_dtw import dtw_distance - - rng = np.random.default_rng(123) - a = rng.standard_normal(30).astype(np.float32) - b = rng.standard_normal(30).astype(np.float32) - c = rng.standard_normal(30).astype(np.float32) - d_ac = dtw_distance(a, c, use_cuda=False) - d_ab = dtw_distance(a, b, use_cuda=False) - d_bc = dtw_distance(b, c, use_cuda=False) - assert d_ac <= d_ab + d_bc + 1e-5 - - -class TestDTWPairwiseCPU: - """dtw_pairwise() must work without CUDA by falling back to CPU.""" - - def test_pairwise_shape(self): - from baleen._cuda_dtw import dtw_pairwise - - rng = np.random.default_rng(42) - sequences = rng.standard_normal((5, 20)).astype(np.float32) - result = dtw_pairwise(sequences, use_cuda=False) - assert result.shape == (5, 5) - - def test_pairwise_diagonal_zero(self): - from baleen._cuda_dtw import dtw_pairwise - - rng = np.random.default_rng(42) - sequences = rng.standard_normal((4, 15)).astype(np.float32) - result = dtw_pairwise(sequences, use_cuda=False) - np.testing.assert_allclose(np.diag(result), 0.0, atol=1e-6) - - def test_pairwise_symmetric(self): - from baleen._cuda_dtw import dtw_pairwise - - rng = np.random.default_rng(42) - sequences = rng.standard_normal((4, 15)).astype(np.float32) - result = dtw_pairwise(sequences, use_cuda=False) - np.testing.assert_allclose(result, result.T, atol=1e-5) - - def test_pairwise_nonnegative(self): - from baleen._cuda_dtw import dtw_pairwise - - rng = np.random.default_rng(42) - sequences = rng.standard_normal((3, 10)).astype(np.float32) - result = dtw_pairwise(sequences, use_cuda=False) - assert np.all(result >= -1e-6) - - def test_pairwise_return_type(self): - from baleen._cuda_dtw import dtw_pairwise - - sequences = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - result = dtw_pairwise(sequences, use_cuda=False) - assert isinstance(result, np.ndarray) - - def test_pairwise_consistent_with_distance(self): - from baleen._cuda_dtw import dtw_distance, dtw_pairwise - - rng = np.random.default_rng(42) - sequences = rng.standard_normal((3, 10)).astype(np.float32) - matrix = dtw_pairwise(sequences, use_cuda=False) - - for i in range(3): - for j in range(3): - expected = dtw_distance(sequences[i], sequences[j], use_cuda=False) - assert matrix[i, j] == pytest.approx(expected, abs=1e-4), ( - f"matrix[{i},{j}]={matrix[i,j]} != dtw_distance={expected}" - ) - - -class TestInputValidationCPU: - """Input validation must still work on the CPU path.""" + assert isinstance(_dtw.dtw_distance(seq1, seq2, use_cuda=False), float) def test_empty_sequence_raises(self): - from baleen._cuda_dtw import dtw_distance - - with pytest.raises(ValueError, match="empty"): - dtw_distance(np.array([], dtype=np.float32), np.array([1.0], dtype=np.float32)) - - def test_2d_sequence_raises(self): - from baleen._cuda_dtw import dtw_distance - - with pytest.raises(ValueError, match="1-dimensional"): - dtw_distance( - np.array([[1.0, 2.0]], dtype=np.float32), - np.array([1.0, 2.0], dtype=np.float32), + with pytest.raises(ValueError): + _dtw.dtw_distance( + np.array([], dtype=np.float32), + np.array([1.0], dtype=np.float32), + use_cuda=False, ) - def test_pairwise_1d_raises(self): - from baleen._cuda_dtw import dtw_pairwise - - with pytest.raises(ValueError, match="2D"): - dtw_pairwise(np.array([1.0, 2.0, 3.0], dtype=np.float32)) - - def test_pairwise_single_sequence_raises(self): - from baleen._cuda_dtw import dtw_pairwise - - with pytest.raises(ValueError, match="at least 2"): - dtw_pairwise(np.array([[1.0, 2.0, 3.0]], dtype=np.float32)) - - def test_pairwise_zero_length_raises(self): - from baleen._cuda_dtw import dtw_pairwise - - with pytest.raises(ValueError, match="0"): - dtw_pairwise(np.zeros((3, 0), dtype=np.float32)) - - -class TestBackendReporting: - """Backend selection and reporting.""" - - def test_backend_function_exists(self): - from baleen._cuda_dtw import backend - - result = backend() - assert isinstance(result, str) - - def test_backend_is_cuda_or_cpu(self): - from baleen._cuda_dtw import backend - - assert backend() in ("cuda", "cpu") - - def test_backend_matches_cuda_available(self): - from baleen._cuda_dtw import CUDA_AVAILABLE, backend - - if CUDA_AVAILABLE: - assert backend() == "cuda" - else: - assert backend() == "cpu" - - def test_is_available_still_reports_cuda(self): - from baleen._cuda_dtw import CUDA_AVAILABLE, is_available - - assert is_available() == CUDA_AVAILABLE - - -# --------------------------------------------------------------------------- -# use_cuda parameter -# --------------------------------------------------------------------------- - -class TestUseCudaParameter: - """use_cuda parameter must control backend dispatch.""" - - def test_use_cuda_false_forces_cpu(self): - """use_cuda=False must use CPU even if CUDA is available.""" - from baleen._cuda_dtw import dtw_distance - - seq1 = np.array([1.0, 2.0, 3.0], dtype=np.float32) - seq2 = np.array([1.0, 2.0, 3.0], dtype=np.float32) - dist = dtw_distance(seq1, seq2, use_cuda=False) - assert dist == pytest.approx(0.0, abs=1e-6) - - def test_use_cuda_none_auto_selects(self): - """use_cuda=None (default) must auto-select based on availability.""" - from baleen._cuda_dtw import dtw_distance - - seq1 = np.array([1.0, 2.0, 3.0], dtype=np.float32) - seq2 = np.array([1.0, 2.0, 3.0], dtype=np.float32) - try: - dist = dtw_distance(seq1, seq2, use_cuda=None) - except RuntimeError as e: - if "CUDA" in str(e): - pytest.skip("CUDA kernel execution failed (driver/arch mismatch)") - raise - assert dist == pytest.approx(0.0, abs=1e-6) - - def test_use_cuda_true_raises_without_gpu(self): - """use_cuda=True must raise RuntimeError if CUDA is not available.""" - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_distance - - if not CUDA_AVAILABLE: - with pytest.raises(RuntimeError, match="CUDA"): - dtw_distance( - np.array([1.0, 2.0], dtype=np.float32), - np.array([1.0, 2.0], dtype=np.float32), - use_cuda=True, - ) - - def test_pairwise_use_cuda_false_forces_cpu(self): - """use_cuda=False on dtw_pairwise must use CPU.""" - from baleen._cuda_dtw import dtw_pairwise - - sequences = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - result = dtw_pairwise(sequences, use_cuda=False) - assert result.shape == (3, 3) - np.testing.assert_allclose(np.diag(result), 0.0, atol=1e-6) - - def test_pairwise_use_cuda_true_raises_without_gpu(self): - """use_cuda=True on dtw_pairwise must raise if no CUDA.""" - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_pairwise - - if not CUDA_AVAILABLE: - with pytest.raises(RuntimeError, match="CUDA"): - dtw_pairwise( - np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), - use_cuda=True, - ) - - def test_cpu_and_default_produce_same_result(self): - """use_cuda=False and use_cuda=None must produce identical results on CPU.""" - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_distance - - if not CUDA_AVAILABLE: - rng = np.random.default_rng(99) - seq1 = rng.standard_normal(30).astype(np.float32) - seq2 = rng.standard_normal(30).astype(np.float32) - d_auto = dtw_distance(seq1, seq2, use_cuda=None) - d_cpu = dtw_distance(seq1, seq2, use_cuda=False) - assert d_auto == pytest.approx(d_cpu, abs=1e-10) - - -# --------------------------------------------------------------------------- -# dtw_pairwise_varlen on CPU -# --------------------------------------------------------------------------- - -class TestDTWPairwiseVarlenCPU: - - def test_basic_variable_lengths(self): - from baleen._cuda_dtw import dtw_pairwise_varlen +class TestDTWPairwiseVarlen: + def test_shape_symmetry_diagonal(self): signals = [ np.array([1.0, 2.0], dtype=np.float32), np.array([1.0, 2.0, 3.0], dtype=np.float32), np.array([2.0], dtype=np.float32), ] - result = dtw_pairwise_varlen(signals, use_cuda=False) + result = _dtw.dtw_pairwise_varlen(signals, use_cuda=False) assert result.shape == (3, 3) np.testing.assert_allclose(np.diag(result), 0.0, atol=1e-6) - assert np.allclose(result, result.T) + np.testing.assert_allclose(result, result.T, atol=1e-5) def test_consistent_with_dtw_distance(self): - from baleen._cuda_dtw import dtw_distance, dtw_pairwise_varlen - signals = [ np.array([1.0, 2.0, 3.0], dtype=np.float32), np.array([4.0, 5.0], dtype=np.float32), np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32), ] - matrix = dtw_pairwise_varlen(signals, use_cuda=False) + matrix = _dtw.dtw_pairwise_varlen(signals, use_cuda=False) for i in range(3): for j in range(i + 1, 3): - expected = dtw_distance(signals[i], signals[j], use_cuda=False) - np.testing.assert_allclose( - matrix[i, j], expected, rtol=1e-5, - err_msg=f"Mismatch at ({i},{j})", - ) - - def test_many_variable_length_signals(self): - from baleen._cuda_dtw import dtw_pairwise_varlen - - rng = np.random.RandomState(42) - signals = [ - rng.randn(rng.randint(5, 30)).astype(np.float32) - for _ in range(15) - ] - matrix = dtw_pairwise_varlen(signals, use_cuda=False) - assert matrix.shape == (15, 15) - np.testing.assert_allclose(np.diag(matrix), 0.0, atol=1e-6) - assert np.allclose(matrix, matrix.T) - off_diag = matrix[np.triu_indices(15, k=1)] - assert np.all(off_diag > 0.0) + expected = _dtw.dtw_distance(signals[i], signals[j], use_cuda=False) + np.testing.assert_allclose(matrix[i, j], expected, rtol=1e-5) def test_single_signal_raises(self): - from baleen._cuda_dtw import dtw_pairwise_varlen - - with pytest.raises(ValueError, match="at least 2"): - dtw_pairwise_varlen( - [np.array([1.0, 2.0], dtype=np.float32)], - use_cuda=False, - ) - - def test_empty_signal_raises(self): - from baleen._cuda_dtw import dtw_pairwise_varlen - - with pytest.raises(ValueError, match="non-empty"): - dtw_pairwise_varlen( - [ - np.array([], dtype=np.float32), - np.array([1.0], dtype=np.float32), - ], - use_cuda=False, + with pytest.raises(ValueError): + _dtw.dtw_pairwise_varlen( + [np.array([1.0, 2.0], dtype=np.float32)], use_cuda=False ) - def test_use_cuda_true_raises_without_gpu(self): - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_pairwise_varlen - - if not CUDA_AVAILABLE: - with pytest.raises(RuntimeError, match="CUDA"): - dtw_pairwise_varlen( - [ - np.array([1.0, 2.0], dtype=np.float32), - np.array([3.0, 4.0], dtype=np.float32), - ], - use_cuda=True, - ) - - def test_identical_signals_zero_distance(self): - from baleen._cuda_dtw import dtw_pairwise_varlen - - sig = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) - result = dtw_pairwise_varlen([sig, sig.copy()], use_cuda=False) - assert result.shape == (2, 2) - np.testing.assert_allclose(result, 0.0, atol=1e-6) - - -# --------------------------------------------------------------------------- -# dtw_multi_position_pairwise on CPU -# --------------------------------------------------------------------------- - -class TestMultiPositionBatchCPU: - """dtw_multi_position_pairwise must produce same results as per-position calls.""" +class TestMultiPositionBatch: def test_batch_matches_individual(self): - """Batch result must match individual dtw_pairwise_varlen calls.""" - from baleen._cuda_dtw import dtw_multi_position_pairwise, dtw_pairwise_varlen - rng = np.random.default_rng(42) position_signals = [ - [rng.standard_normal(rng.integers(5, 20)).astype(np.float32) for _ in range(4)], - [rng.standard_normal(rng.integers(5, 20)).astype(np.float32) for _ in range(6)], - [rng.standard_normal(rng.integers(5, 20)).astype(np.float32) for _ in range(3)], + [rng.standard_normal(int(rng.integers(5, 20))).astype(np.float32) for _ in range(4)], + [rng.standard_normal(int(rng.integers(5, 20))).astype(np.float32) for _ in range(6)], + [rng.standard_normal(int(rng.integers(5, 20))).astype(np.float32) for _ in range(3)], ] - - batch_results = dtw_multi_position_pairwise(position_signals, use_cuda=False) - - assert len(batch_results) == 3 + batch = _dtw.dtw_multi_position_pairwise(position_signals, use_cuda=False) + assert len(batch) == 3 for p, signals in enumerate(position_signals): - expected = dtw_pairwise_varlen(signals, use_cuda=False) - np.testing.assert_allclose( - batch_results[p], expected, rtol=1e-5, - err_msg=f"Position {p} mismatch", - ) - - def test_single_position(self): - from baleen._cuda_dtw import dtw_multi_position_pairwise, dtw_pairwise_varlen - - signals = [ - np.array([1.0, 2.0, 3.0], dtype=np.float32), - np.array([4.0, 5.0], dtype=np.float32), - ] - batch_results = dtw_multi_position_pairwise([signals], use_cuda=False) - expected = dtw_pairwise_varlen(signals, use_cuda=False) - np.testing.assert_allclose(batch_results[0], expected, rtol=1e-5) - - def test_position_with_one_signal(self): - """Position with n=1 should produce a 1x1 zero matrix.""" - from baleen._cuda_dtw import dtw_multi_position_pairwise - - position_signals = [ - [np.array([1.0, 2.0], dtype=np.float32)], # n=1 - [np.array([1.0, 2.0], dtype=np.float32), - np.array([3.0, 4.0], dtype=np.float32)], # n=2 - ] - results = dtw_multi_position_pairwise(position_signals, use_cuda=False) - assert results[0].shape == (1, 1) - assert results[0][0, 0] == 0.0 - assert results[1].shape == (2, 2) - - def test_many_positions(self): - from baleen._cuda_dtw import dtw_multi_position_pairwise + expected = _dtw.dtw_pairwise_varlen(signals, use_cuda=False) + np.testing.assert_allclose(batch[p], expected, rtol=1e-5, + err_msg=f"position {p}") + def test_shapes_symmetry_diagonal(self): rng = np.random.default_rng(99) position_signals = [ - [rng.standard_normal(rng.integers(3, 15)).astype(np.float32) - for _ in range(rng.integers(2, 8))] - for _ in range(20) + [rng.standard_normal(int(rng.integers(3, 15))).astype(np.float32) + for _ in range(int(rng.integers(2, 8)))] + for _ in range(12) ] - results = dtw_multi_position_pairwise(position_signals, use_cuda=False) - assert len(results) == 20 - for p, (matrix, signals) in enumerate(zip(results, position_signals)): + results = _dtw.dtw_multi_position_pairwise(position_signals, use_cuda=False) + assert len(results) == 12 + for matrix, signals in zip(results, position_signals): n = len(signals) assert matrix.shape == (n, n) np.testing.assert_allclose(np.diag(matrix), 0.0, atol=1e-6) - assert np.allclose(matrix, matrix.T) + np.testing.assert_allclose(matrix, matrix.T, atol=1e-5) def test_empty_list_raises(self): - from baleen._cuda_dtw import dtw_multi_position_pairwise - - with pytest.raises(ValueError, match="at least 1"): - dtw_multi_position_pairwise([], use_cuda=False) - - def test_use_cuda_true_raises_without_gpu(self): - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_multi_position_pairwise + with pytest.raises(ValueError): + _dtw.dtw_multi_position_pairwise([], use_cuda=False) - if not CUDA_AVAILABLE: - signals = [[np.array([1.0, 2.0], dtype=np.float32), - np.array([3.0, 4.0], dtype=np.float32)]] - with pytest.raises(RuntimeError, match="CUDA"): - dtw_multi_position_pairwise(signals, use_cuda=True) +class TestBackendReporting: + def test_backend_is_gpu_or_cpu(self): + assert _dtw.backend() in ("gpu", "cpu") -# --------------------------------------------------------------------------- -# dtw_multi_position_pairwise on GPU (conditional) -# --------------------------------------------------------------------------- - -@pytest.mark.skipif( - not __import__("baleen._cuda_dtw", fromlist=["CUDA_AVAILABLE"]).CUDA_AVAILABLE, - reason="CUDA not available", -) -class TestMultiPositionBatchGPU: - """GPU batch DTW must produce same results as CPU.""" - - def test_gpu_matches_cpu(self): - from baleen._cuda_dtw import dtw_multi_position_pairwise - - rng = np.random.default_rng(42) - position_signals = [ - [rng.standard_normal(rng.integers(5, 20)).astype(np.float32) for _ in range(4)], - [rng.standard_normal(rng.integers(5, 20)).astype(np.float32) for _ in range(6)], - [rng.standard_normal(rng.integers(5, 20)).astype(np.float32) for _ in range(3)], - ] - - cpu_results = dtw_multi_position_pairwise(position_signals, use_cuda=False) - gpu_results = dtw_multi_position_pairwise(position_signals, use_cuda=True) + def test_backend_matches_availability(self): + if _dtw.CUDA_AVAILABLE: + assert _dtw.backend() == "gpu" + else: + assert _dtw.backend() == "cpu" - for p in range(3): - np.testing.assert_allclose( - gpu_results[p], cpu_results[p], rtol=1e-4, - err_msg=f"Position {p}: GPU != CPU", - ) + def test_is_available_matches_flag(self): + assert _dtw.is_available() == _dtw.CUDA_AVAILABLE - def test_gpu_many_positions(self): - """Stress test: 50 positions to verify stream concurrency works.""" - from baleen._cuda_dtw import dtw_multi_position_pairwise - rng = np.random.default_rng(123) - position_signals = [ - [rng.standard_normal(rng.integers(5, 30)).astype(np.float32) - for _ in range(rng.integers(3, 15))] - for _ in range(50) - ] +class TestMemoryHelpers: + def test_estimate_gpu_memory_positive(self): + rng = np.random.default_rng(7) + ps = [[rng.standard_normal(50).astype(np.float32) for _ in range(3)]] + assert _dtw.estimate_gpu_memory(ps) > 0 - gpu_results = dtw_multi_position_pairwise(position_signals, use_cuda=True, num_streams=16) - cpu_results = dtw_multi_position_pairwise(position_signals, use_cuda=False) + def test_per_device_memory_list(self): + assert isinstance(_dtw.get_per_device_memory(), list) - for p in range(50): - np.testing.assert_allclose( - gpu_results[p], cpu_results[p], rtol=1e-4, - err_msg=f"Position {p}: GPU != CPU", - ) - def test_gpu_different_stream_counts(self): - """Verify correctness with different numbers of streams.""" - from baleen._cuda_dtw import dtw_multi_position_pairwise +@pytest.mark.skipif(not _dtw.CUDA_AVAILABLE, reason="GPU DTW not available") +class TestGPUMatchesCPU: + """On signals within the no-resample cap, GPU and CPU must agree.""" - rng = np.random.default_rng(77) + def test_gpu_matches_cpu_short(self): + rng = np.random.default_rng(42) position_signals = [ - [rng.standard_normal(rng.integers(5, 15)).astype(np.float32) for _ in range(5)] - for _ in range(10) + [rng.standard_normal(int(rng.integers(5, 200))).astype(np.float32) for _ in range(4)], + [rng.standard_normal(int(rng.integers(5, 200))).astype(np.float32) for _ in range(6)], ] - - cpu_results = dtw_multi_position_pairwise(position_signals, use_cuda=False) - for nstreams in [1, 4, 8, 16, 32]: - gpu_results = dtw_multi_position_pairwise( - position_signals, use_cuda=True, num_streams=nstreams, - ) - for p in range(10): - np.testing.assert_allclose( - gpu_results[p], cpu_results[p], rtol=1e-4, - err_msg=f"Position {p}, streams={nstreams}: GPU != CPU", - ) + cpu = _dtw.dtw_multi_position_pairwise(position_signals, use_cuda=False) + gpu = _dtw.dtw_multi_position_pairwise(position_signals, use_cuda=True) + for p in range(len(position_signals)): + np.testing.assert_allclose(gpu[p], cpu[p], rtol=1e-4, atol=1e-4, + err_msg=f"position {p}") diff --git a/tests/test_f5c.py b/tests/test_f5c.py deleted file mode 100644 index e4dbdfa..0000000 --- a/tests/test_f5c.py +++ /dev/null @@ -1,270 +0,0 @@ -from pathlib import Path -import subprocess -from typing import cast -from unittest.mock import Mock, patch - -import pytest - -import baleen.eventalign._f5c as f5c_mod - - -@pytest.fixture(autouse=True) -def reset_f5c_cache(): - setattr(f5c_mod, "_f5c_version", None) - yield - setattr(f5c_mod, "_f5c_version", None) - - -class TestCheckF5c: - def test_check_f5c_success(self): - mock_result = Mock(stdout="f5c v1.6\n", stderr="") - with patch("baleen.eventalign._f5c.subprocess.run", return_value=mock_result) as mock_run: - assert f5c_mod.check_f5c() == "1.6" - mock_run.assert_called_once_with( - ["f5c", "--version"], - check=True, - capture_output=True, - text=True, - ) - - def test_check_f5c_not_found(self) -> None: - with patch("baleen.eventalign._f5c.subprocess.run", side_effect=FileNotFoundError): - with pytest.raises(RuntimeError, match="f5c not found"): - _ = f5c_mod.check_f5c() - - @pytest.mark.parametrize( - "version_output,expected", - [ - ("f5c v1.6", "1.6"), - ("f5c v1.6.1", "1.6.1"), - ("f5c 1.6", "1.6"), - ("f5c V2.0.3", "2.0.3"), - ], - ) - def test_check_f5c_version_parse_variants(self, version_output: str, expected: str) -> None: - mock_result = Mock(stdout=version_output, stderr="") - with patch("baleen.eventalign._f5c.subprocess.run", return_value=mock_result): - assert f5c_mod.check_f5c() == expected - - def test_check_f5c_caching(self): - mock_result = Mock(stdout="f5c v1.6\n", stderr="") - with patch("baleen.eventalign._f5c.subprocess.run", return_value=mock_result) as mock_run: - assert f5c_mod.check_f5c() == "1.6" - assert f5c_mod.check_f5c() == "1.6" - assert mock_run.call_count == 1 - - -class TestGetF5cVersion: - def test_version_tuple(self) -> None: - with patch("baleen.eventalign._f5c.check_f5c", return_value="1.6"): - assert f5c_mod.get_f5c_version() == (1, 6) - - def test_version_tuple_three_parts(self) -> None: - with patch("baleen.eventalign._f5c.check_f5c", return_value="1.6.1"): - assert f5c_mod.get_f5c_version() == (1, 6, 1) - - -class TestIsIndexed: - def test_indexed_true(self, tmp_path: Path): - fastq = tmp_path / "reads.fastq" - _ = fastq.write_text("@r\nACGT\n+\n####\n", encoding="utf-8") - readdb = tmp_path / "reads.fastq.index.readdb" - _ = readdb.write_text("indexed", encoding="utf-8") - assert f5c_mod.is_indexed(fastq) is True - - def test_indexed_false_missing(self, tmp_path: Path): - fastq = tmp_path / "reads.fastq" - _ = fastq.write_text("@r\nACGT\n+\n####\n", encoding="utf-8") - assert f5c_mod.is_indexed(fastq) is False - - def test_indexed_false_empty(self, tmp_path: Path): - fastq = tmp_path / "reads.fastq" - _ = fastq.write_text("@r\nACGT\n+\n####\n", encoding="utf-8") - readdb = tmp_path / "reads.fastq.index.readdb" - _ = readdb.write_text("", encoding="utf-8") - assert f5c_mod.is_indexed(fastq) is False - - -class TestIsBlow5Indexed: - def test_blow5_indexed_true(self, tmp_path: Path): - blow5 = tmp_path / "reads.blow5" - _ = blow5.write_text("x", encoding="utf-8") - idx = tmp_path / "reads.blow5.idx" - _ = idx.write_text("ok", encoding="utf-8") - assert f5c_mod.is_blow5_indexed(blow5) is True - - def test_blow5_indexed_false_missing(self, tmp_path: Path): - blow5 = tmp_path / "reads.blow5" - _ = blow5.write_text("x", encoding="utf-8") - assert f5c_mod.is_blow5_indexed(blow5) is False - - def test_blow5_indexed_false_empty(self, tmp_path: Path): - blow5 = tmp_path / "reads.blow5" - _ = blow5.write_text("x", encoding="utf-8") - idx = tmp_path / "reads.blow5.idx" - _ = idx.write_text("", encoding="utf-8") - assert f5c_mod.is_blow5_indexed(blow5) is False - - -class TestIndexFastqBlow5: - def test_index_runs_f5c(self, tmp_path: Path): - fastq = tmp_path / "reads.fastq" - blow5 = tmp_path / "reads.blow5" - _ = fastq.write_text("x", encoding="utf-8") - _ = blow5.write_text("x", encoding="utf-8") - - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - f5c_mod.index_fastq_blow5(fastq, blow5) - - mock_run.assert_called_once_with( - ["f5c", "index", "--slow5", str(blow5), str(fastq)], - check=True, - capture_output=True, - text=True, - ) - - def test_index_skips_if_already_indexed(self, tmp_path: Path): - fastq = tmp_path / "reads.fastq" - blow5 = tmp_path / "reads.blow5" - _ = fastq.write_text("x", encoding="utf-8") - _ = blow5.write_text("x", encoding="utf-8") - readdb = tmp_path / "reads.fastq.index.readdb" - _ = readdb.write_text("ok", encoding="utf-8") - - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - f5c_mod.index_fastq_blow5(fastq, blow5) - - mock_run.assert_not_called() - - def test_index_failure(self, tmp_path: Path): - fastq = tmp_path / "reads.fastq" - blow5 = tmp_path / "reads.blow5" - _ = fastq.write_text("x", encoding="utf-8") - _ = blow5.write_text("x", encoding="utf-8") - err = subprocess.CalledProcessError( - returncode=1, - cmd=["f5c", "index"], - stderr="index failed", - ) - - with patch("baleen.eventalign._f5c.subprocess.run", side_effect=err): - with pytest.raises(RuntimeError, match="index failed"): - _ = f5c_mod.index_fastq_blow5(fastq, blow5) - - -class TestIndexBlow5: - def test_index_blow5_runs_slow5tools(self, tmp_path: Path): - blow5 = tmp_path / "reads.blow5" - _ = blow5.write_text("x", encoding="utf-8") - - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - f5c_mod.index_blow5(blow5) - - mock_run.assert_called_once_with( - ["slow5tools", "index", str(blow5)], - check=True, - capture_output=True, - text=True, - ) - - def test_index_blow5_skips_when_indexed(self, tmp_path: Path): - blow5 = tmp_path / "reads.blow5" - _ = blow5.write_text("x", encoding="utf-8") - idx = tmp_path / "reads.blow5.idx" - _ = idx.write_text("ok", encoding="utf-8") - - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - f5c_mod.index_blow5(blow5) - - mock_run.assert_not_called() - - def test_index_blow5_failure(self, tmp_path: Path): - blow5 = tmp_path / "reads.blow5" - _ = blow5.write_text("x", encoding="utf-8") - err = subprocess.CalledProcessError( - returncode=1, - cmd=["slow5tools", "index"], - stderr="slow5tools failed", - ) - - with patch("baleen.eventalign._f5c.subprocess.run", side_effect=err): - with pytest.raises(RuntimeError, match="slow5tools failed"): - _ = f5c_mod.index_blow5(blow5) - - -class TestRunEventalign: - def test_eventalign_basic_command(self, tmp_path: Path): - bam = tmp_path / "in.bam" - ref = tmp_path / "ref.fa" - fastq = tmp_path / "reads.fastq" - blow5 = tmp_path / "reads.blow5" - out = tmp_path / "out.tsv" - - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - result = f5c_mod.run_eventalign(bam, ref, fastq, blow5, out) - - assert result == out - called_cmd = cast(list[str], mock_run.call_args.args[0]) - assert called_cmd[:2] == ["f5c", "eventalign"] - assert "--rna" in called_cmd - assert "--samples" in called_cmd - assert "--signal-index" in called_cmd - assert "--scale-events" in called_cmd - assert "--print-read-names" in called_cmd - - def test_eventalign_no_rna(self, tmp_path: Path): - out = tmp_path / "out.tsv" - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - _ = f5c_mod.run_eventalign("a.bam", "r.fa", "q.fastq", "s.blow5", out, rna=False) - called_cmd = cast(list[str], mock_run.call_args.args[0]) - assert "--rna" not in called_cmd - - def test_eventalign_with_kmer_model(self, tmp_path: Path): - out = tmp_path / "out.tsv" - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - _ = f5c_mod.run_eventalign( - "a.bam", - "r.fa", - "q.fastq", - "s.blow5", - out, - kmer_model="rna004.model", - ) - called_cmd = cast(list[str], mock_run.call_args.args[0]) - assert "--kmer-model" in called_cmd - idx = called_cmd.index("--kmer-model") - assert called_cmd[idx + 1] == "rna004.model" - - def test_eventalign_with_extra_args(self, tmp_path: Path): - out = tmp_path / "out.tsv" - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - _ = f5c_mod.run_eventalign( - "a.bam", - "r.fa", - "q.fastq", - "s.blow5", - out, - extra_args=["--threads", "4", "--iop", "32"], - ) - called_cmd = cast(list[str], mock_run.call_args.args[0]) - assert called_cmd[-4:] == ["--threads", "4", "--iop", "32"] - - def test_eventalign_failure(self, tmp_path: Path): - out = tmp_path / "out.tsv" - err = subprocess.CalledProcessError( - returncode=1, - cmd=["f5c", "eventalign"], - stderr="eventalign failed", - ) - - with patch("baleen.eventalign._f5c.subprocess.run", side_effect=err): - with pytest.raises(RuntimeError, match="eventalign failed"): - _ = f5c_mod.run_eventalign("a.bam", "r.fa", "q.fastq", "s.blow5", out) - - def test_eventalign_output_file_created(self, tmp_path: Path): - out = tmp_path / "out.tsv" - with patch("baleen.eventalign._f5c.subprocess.run") as mock_run: - _ = f5c_mod.run_eventalign("a.bam", "r.fa", "q.fastq", "s.blow5", out) - - assert out.exists() - assert "stdout" in mock_run.call_args.kwargs diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 34d69fb..b8037f1 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -122,7 +122,7 @@ def test_creation(self) -> None: class TestPipelineMetadata: def test_creation(self) -> None: metadata = PipelineMetadata( - f5c_version="1.6", + eventalign_version="1.6", min_depth=15, use_cuda=None, padding=0, @@ -131,7 +131,7 @@ def test_creation(self) -> None: n_contigs_skipped=3, filter_results=[], ) - assert metadata.f5c_version == "1.6" + assert metadata.eventalign_version == "1.6" assert metadata.n_contigs_skipped == 3 @@ -167,13 +167,10 @@ def test_standard_dtw_variable_length_matches_pairwise_loop(self) -> None: assert matrix.shape == (3, 3) assert np.allclose(np.diag(matrix), 0.0) assert np.allclose(matrix, matrix.T) - from tslearn.metrics import dtw as tslearn_dtw + from baleen import _dtw for i in range(3): for j in range(i + 1, 3): - expected = tslearn_dtw( - signals[i].reshape(-1, 1), - signals[j].reshape(-1, 1), - ) + expected = _dtw.dtw_distance(signals[i], signals[j], use_cuda=False) np.testing.assert_allclose( matrix[i, j], expected, rtol=1e-5, err_msg=f"Mismatch at ({i},{j})", @@ -283,10 +280,9 @@ def fake_dtw(a: NDArray[np.float32], b: NDArray[np.float32], **kwargs: object) - return float(abs(len(a) - len(b)) + np.mean(np.abs(a[0] - b[0]))) with ( - patch("baleen.eventalign._pipeline._f5c.check_f5c", return_value="1.6"), - patch("baleen.eventalign._pipeline._f5c.index_fastq_blow5"), - patch("baleen.eventalign._pipeline._f5c.index_blow5"), - patch("baleen.eventalign._pipeline._f5c.run_eventalign", side_effect=fake_run_eventalign), + patch("baleen.eventalign._pipeline._eventalign.check_krill", return_value="1.6"), + patch("baleen.eventalign._pipeline._eventalign.index_blow5"), + patch("baleen.eventalign._pipeline._eventalign.run_eventalign", side_effect=fake_run_eventalign), patch("baleen.eventalign._pipeline._dtw_distance", side_effect=fake_dtw), ): results, metadata = run_pipeline( @@ -301,7 +297,7 @@ def fake_dtw(a: NDArray[np.float32], b: NDArray[np.float32], **kwargs: object) - use_cuda=False, ) - assert metadata.f5c_version == "1.6" + assert metadata.eventalign_version == "1.6" assert metadata.n_contigs_passed_filter == 1 assert set(results.keys()) == {"ctg1"} contig_result = results["ctg1"] @@ -326,10 +322,9 @@ def test_empty_results_when_no_contig_passes_filter(self, tmp_path: Path) -> Non ) with ( - patch("baleen.eventalign._pipeline._f5c.check_f5c", return_value="1.6"), - patch("baleen.eventalign._pipeline._f5c.index_fastq_blow5"), - patch("baleen.eventalign._pipeline._f5c.index_blow5"), - patch("baleen.eventalign._pipeline._f5c.run_eventalign") as mock_run_eventalign, + patch("baleen.eventalign._pipeline._eventalign.check_krill", return_value="1.6"), + patch("baleen.eventalign._pipeline._eventalign.index_blow5"), + patch("baleen.eventalign._pipeline._eventalign.run_eventalign") as mock_run_eventalign, ): results, metadata = run_pipeline( native_bam=native_bam, @@ -346,7 +341,7 @@ def test_empty_results_when_no_contig_passes_filter(self, tmp_path: Path) -> Non assert metadata.n_contigs_passed_filter == 0 mock_run_eventalign.assert_not_called() - def test_f5c_not_found_raises_runtime_error(self, tmp_path: Path) -> None: + def test_krill_not_found_raises_runtime_error(self, tmp_path: Path) -> None: native_bam = _create_test_bam( tmp_path, "native", @@ -360,8 +355,8 @@ def test_f5c_not_found_raises_runtime_error(self, tmp_path: Path) -> None: [("ctg1", 20)], ) - with patch("baleen.eventalign._pipeline._f5c.check_f5c", side_effect=RuntimeError("f5c not found")): - with pytest.raises(RuntimeError, match="f5c not found"): + with patch("baleen.eventalign._pipeline._eventalign.check_krill", side_effect=RuntimeError("krill not found")): + with pytest.raises(RuntimeError, match="krill not found"): _ = run_pipeline( native_bam=native_bam, native_fastq=tmp_path / "native.fastq", @@ -407,10 +402,9 @@ def fake_run_eventalign( out_dir = tmp_path / "results" with ( - patch("baleen.eventalign._pipeline._f5c.check_f5c", return_value="1.6"), - patch("baleen.eventalign._pipeline._f5c.index_fastq_blow5"), - patch("baleen.eventalign._pipeline._f5c.index_blow5"), - patch("baleen.eventalign._pipeline._f5c.run_eventalign", side_effect=fake_run_eventalign), + patch("baleen.eventalign._pipeline._eventalign.check_krill", return_value="1.6"), + patch("baleen.eventalign._pipeline._eventalign.index_blow5"), + patch("baleen.eventalign._pipeline._eventalign.run_eventalign", side_effect=fake_run_eventalign), patch("baleen.eventalign._pipeline._dtw_distance", return_value=1.23), ): _ = run_pipeline( @@ -451,7 +445,7 @@ def test_round_trip_pickle(self, tmp_path: Path) -> None: ) } metadata = PipelineMetadata( - f5c_version="1.6", + eventalign_version="1.6", min_depth=1, use_cuda=None, padding=0, @@ -470,4 +464,4 @@ def test_round_trip_pickle(self, tmp_path: Path) -> None: loaded_results["ctg1"].positions[10].distance_matrix, matrix, ) - assert loaded_meta.f5c_version == "1.6" + assert loaded_meta.eventalign_version == "1.6" diff --git a/tests/test_read_ids.py b/tests/test_read_ids.py index 7160de0..751ea98 100644 --- a/tests/test_read_ids.py +++ b/tests/test_read_ids.py @@ -204,6 +204,7 @@ def test_fingerprint_field_present(self): padding=10, min_mapq=0, primary_only=True, subsample=True, subsample_n=300, legacy_scoring=False, mod_threshold=0.9, write_bam=True, run_hmm=True, target_contigs=None, + pore="rna002", ) fp_on = _compute_resume_fingerprint(read_intersection=True, **kwargs) fp_off = _compute_resume_fingerprint(read_intersection=False, **kwargs) diff --git a/tests/test_resume.py b/tests/test_resume.py index aefaab0..97595e7 100644 --- a/tests/test_resume.py +++ b/tests/test_resume.py @@ -64,6 +64,7 @@ def _baseline_fingerprint(paths: dict[str, Path]) -> dict: run_hmm=True, target_contigs=None, read_intersection=True, + pore="rna002", ) @@ -73,7 +74,7 @@ def test_round_trip_json(self, tmp_path: Path) -> None: # Must be json-serializable. s = json.dumps(fp, sort_keys=True) assert json.loads(s) == fp - assert fp["schema_version"] == 1 + assert fp["schema_version"] == 2 assert set(fp["inputs"]) == { "native_bam", "native_fastq", "native_blow5", "ivt_bam", "ivt_fastq", "ivt_blow5", "ref_fasta", @@ -104,6 +105,7 @@ def test_changing_param_changes_fingerprint(self, tmp_path: Path) -> None: run_hmm=True, target_contigs=None, read_intersection=True, + pore="rna002", ) assert fp1 != fp2 @@ -160,10 +162,21 @@ def test_mismatch_lists_diffs(self, tmp_path: Path) -> None: run_hmm=True, target_contigs=None, read_intersection=True, + pore="rna002", ) with pytest.raises(RuntimeError, match="mod_threshold"): _validate_resume_compatibility(per, fp_new) + def test_schema_version_mismatch_rejected(self, tmp_path: Path) -> None: + paths = _make_inputs(tmp_path) + per = tmp_path / "per_contig" + per.mkdir() + fp_old = _baseline_fingerprint(paths) + fp_old["schema_version"] = fp_old["schema_version"] - 1 # simulate older run + _write_resume_fingerprint(per, fp_old) + with pytest.raises(RuntimeError, match="schema"): + _validate_resume_compatibility(per, _baseline_fingerprint(paths)) + def test_write_is_atomic(self, tmp_path: Path) -> None: paths = _make_inputs(tmp_path) per = tmp_path / "per_contig" diff --git a/tests/test_setup.py b/tests/test_setup.py deleted file mode 100644 index 2b2240d..0000000 --- a/tests/test_setup.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Tests for baleen package setup and installation. - -These tests verify: -1. The package is importable after installation -2. The CUDA extension is optional (graceful fallback on non-GPU machines) -3. Package metadata is correct -4. The public API surface is as expected -""" - -import importlib -import subprocess -import sys - -import pytest - - -class TestPackageImport: - """Test that the baleen package can be imported.""" - - def test_import_baleen(self): - """The top-level 'baleen' package must be importable.""" - import baleen - - assert baleen is not None - - def test_import_cuda_dtw_subpackage(self): - """The _cuda_dtw subpackage must be importable (Python wrapper layer).""" - from baleen import _cuda_dtw - - assert _cuda_dtw is not None - - def test_cuda_dtw_has_public_api(self): - """The _cuda_dtw module must expose the expected public symbols.""" - from baleen import _cuda_dtw - - expected_names = ["dtw_distance", "dtw_pairwise", "cleanup", "is_available", "CUDA_AVAILABLE"] - for name in expected_names: - assert hasattr(_cuda_dtw, name), f"Missing public API: {name}" - - def test_is_available_returns_bool(self): - """is_available() must return a boolean.""" - from baleen._cuda_dtw import is_available - - result = is_available() - assert isinstance(result, bool) - - def test_cuda_available_is_bool(self): - """CUDA_AVAILABLE must be a boolean.""" - from baleen._cuda_dtw import CUDA_AVAILABLE - - assert isinstance(CUDA_AVAILABLE, bool) - - -class TestCUDAGracefulFallback: - """On machines without CUDA, the package must still import without error.""" - - def test_no_crash_on_import(self): - """Importing baleen._cuda_dtw must not raise on non-CUDA machines.""" - # This test always passes if we get here — import already succeeded. - # The real test is that the import in the test above didn't crash. - from baleen._cuda_dtw import CUDA_AVAILABLE - - # On this macOS dev machine, CUDA won't be available - if sys.platform == "darwin": - assert CUDA_AVAILABLE is False - - def test_dtw_distance_works_without_cuda(self): - """dtw_distance() must work via tslearn when CUDA is not available.""" - import numpy as np - - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_distance - - if not CUDA_AVAILABLE: - # Should NOT raise — tslearn CPU fallback handles it - seq1 = np.array([1.0, 2.0], dtype=np.float32) - seq2 = np.array([1.0, 2.0], dtype=np.float32) - dist = dtw_distance(seq1, seq2) - assert isinstance(dist, float) - assert dist >= 0.0 - - def test_dtw_pairwise_works_without_cuda(self): - """dtw_pairwise() must work via tslearn when CUDA is not available.""" - import numpy as np - - from baleen._cuda_dtw import CUDA_AVAILABLE, dtw_pairwise - - if not CUDA_AVAILABLE: - # Should NOT raise — tslearn CPU fallback handles it - sequences = np.random.randn(3, 10).astype(np.float32) - result = dtw_pairwise(sequences) - assert result.shape == (3, 3) - - def test_cleanup_noop_without_cuda(self): - """cleanup() must be a no-op (not raise) when CUDA is not available.""" - from baleen._cuda_dtw import cleanup - - # Should not raise even without CUDA - cleanup() - - -class TestPackageMetadata: - """Test that package metadata is correctly configured.""" - - def test_pip_show_baleen(self): - """pip show must recognize the installed package.""" - result = subprocess.run( - [sys.executable, "-m", "pip", "show", "baleen"], - capture_output=True, - text=True, - ) - assert result.returncode == 0, f"pip show failed: {result.stderr}" - assert "Name: baleen" in result.stdout - - def test_numpy_is_dependency(self): - """numpy must be listed as a dependency.""" - result = subprocess.run( - [sys.executable, "-m", "pip", "show", "baleen"], - capture_output=True, - text=True, - ) - # Check Requires line includes numpy - for line in result.stdout.splitlines(): - if line.startswith("Requires:"): - assert "numpy" in line, f"numpy not in Requires: {line}" - break - else: - pytest.fail("No 'Requires:' line found in pip show output") - - def test_tslearn_is_dependency(self): - """tslearn must be listed as a dependency.""" - result = subprocess.run( - [sys.executable, "-m", "pip", "show", "baleen"], - capture_output=True, - text=True, - ) - for line in result.stdout.splitlines(): - if line.startswith("Requires:"): - assert "tslearn" in line, f"tslearn not in Requires: {line}" - break - else: - pytest.fail("No 'Requires:' line found in pip show output")