Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 20 additions & 5 deletions python/genvarloader/_dataset/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,29 @@ def bed_to_regions(
pl.col("chromStart", "chromEnd").cast(pl.Int32),
]

if bed.schema.get("strand", None) in (pl.Utf8, pl.Categorical):
strand_dtype = bed.schema.get("strand", None)
if strand_dtype is None:
cols.append(pl.lit(1).cast(pl.Int32).alias("strand"))
elif strand_dtype == pl.Utf8 or strand_dtype == pl.Categorical:
# Cast Categorical -> Utf8 first. The ``in (pl.Utf8, pl.Categorical)``
# check that already lives here picks up the right branch, but
# ``replace_strict({"+": 1, "-": -1}, ...)`` won't reliably accept
# Categorical keys across all supported polars versions -- on the
# versions where it doesn't, the strand column survives the
# ``select(...)`` call as Categorical, and ``to_numpy()`` on a frame
# mixing ``Int32`` + ``Categorical`` collapses to ``dtype=object``,
# which downstream numba kernels reject with
# ``non-precise type array(pyobject)``. Casting to Utf8 first keeps
# the strand column numeric and the regions array stays ``int32``.
cols.append(
pl.col("strand").replace_strict({"+": 1, "-": -1}, return_dtype=pl.Int32)
pl.col("strand")
.cast(pl.Utf8)
.replace_strict({"+": 1, "-": -1}, return_dtype=pl.Int32)
)
elif "strand" not in bed.schema:
cols.append(pl.lit(1).cast(pl.Int32).alias("strand"))
else:
cols.append(pl.col("strand"))
# An already-numeric strand column is allowed; force Int32 so the
# final array doesn't widen to int64 / object on int8 / int16 input.
cols.append(pl.col("strand").cast(pl.Int32))

return bed.select(cols).to_numpy()

Expand Down
58 changes: 57 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,65 @@
import numpy as np
from genvarloader._dataset._utils import splits_sum_le_value
import polars as pl
from genoray._utils import ContigNormalizer
from genvarloader._dataset._utils import bed_to_regions, splits_sum_le_value
from genvarloader._utils import normalize_contig_name
from pytest_cases import parametrize_with_cases


def test_bed_to_regions_categorical_strand_returns_int32() -> None:
"""Regression: BEDs whose strand column is Categorical (e.g. from
`polars.bed.sort` or any pipeline that round-trips a strand category)
must produce an int32 regions array, not dtype=object.

Without the Categorical-aware branch in `bed_to_regions`, the strand
column survived the `select(...)` call as Categorical, polars' mixed
Int32 + Categorical `to_numpy()` collapsed to dtype=object, and
downstream numba kernels (`get_diffs_sparse`) refused to compile with
`non-precise type array(pyobject, 1d, A)`. See PR for the chr19 ADNI
cohort reproducer.
"""
bed = pl.DataFrame(
{
"chrom": ["chr19", "chr19"],
"chromStart": [44906624, 44907759],
"chromEnd": [44906667, 44907952],
"strand": ["+", "+"],
}
).with_columns(pl.col("strand").cast(pl.Categorical))
assert bed.schema["strand"] == pl.Categorical
regions = bed_to_regions(bed, ContigNormalizer(["chr19"]))
assert regions.dtype == np.int32, f"want int32, got {regions.dtype}"
assert regions.shape == (2, 4)
np.testing.assert_array_equal(
regions,
np.array([[0, 44906624, 44906667, 1], [0, 44907759, 44907952, 1]], np.int32),
)


def test_bed_to_regions_utf8_strand_still_works() -> None:
"""Sanity: the existing Utf8-strand path still produces int32."""
bed = pl.DataFrame(
{
"chrom": ["chr1"],
"chromStart": [100],
"chromEnd": [200],
"strand": ["-"],
}
)
assert bed.schema["strand"] == pl.Utf8
regions = bed_to_regions(bed, ContigNormalizer(["chr1"]))
assert regions.dtype == np.int32
np.testing.assert_array_equal(regions, np.array([[0, 100, 200, -1]], np.int32))


def test_bed_to_regions_no_strand_defaults_to_plus() -> None:
"""BEDs without a strand column get strand=1 (existing behaviour)."""
bed = pl.DataFrame({"chrom": ["chr1"], "chromStart": [100], "chromEnd": [200]})
regions = bed_to_regions(bed, ContigNormalizer(["chr1"]))
assert regions.dtype == np.int32
np.testing.assert_array_equal(regions, np.array([[0, 100, 200, 1]], np.int32))


def test_splits_sum_le_value():
max_size = 10
sizes = np.array([3, 5, 2, 4, 7, 5, 2], np.int32)
Expand Down
Loading