Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ repos:
args: [--application-directories=python,
--unclassifiable-application-module=_tskit]
- repo: https://github.com/asottile/pyupgrade
rev: v3.20.0
rev: v3.21.2
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus]
args: [--py310-plus]
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
Expand Down
4 changes: 1 addition & 3 deletions python/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def capture_output(func, *args, **kwargs):
Runs the specified function and arguments, and returns the
tuple (stdout, stderr) as strings.
"""
buffer_class = io.BytesIO
if sys.version_info[0] == 3:
buffer_class = io.StringIO
buffer_class = io.StringIO
stdout = sys.stdout
sys.stdout = buffer_class()
stderr = sys.stderr
Expand Down
60 changes: 28 additions & 32 deletions python/tests/test_ld_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,13 @@
"""
import contextlib
import io
from collections.abc import Callable
from collections.abc import Generator
from dataclasses import dataclass
from itertools import combinations_with_replacement
from itertools import permutations
from itertools import product
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import List
from typing import Tuple
from typing import Union

import msprime
import numpy as np
Expand Down Expand Up @@ -224,7 +220,7 @@ def norm_hap_weighted(
n_a: int,
n_b: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
"""Create a vector of normalizing coefficients, length of the number of
sample sets. In this normalization strategy, we weight each allele's
Expand All @@ -250,7 +246,7 @@ def norm_hap_weighted_ij(
n_a: int,
n_b: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
"""
Create a vector of normalizing coefficients, length of the number of
Expand Down Expand Up @@ -286,7 +282,7 @@ def norm_total_weighted(
n_a: int,
n_b: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
"""Create a vector of normalizing coefficients, length of the number of
sample sets. In this normalization strategy, we weight each allele's
Expand Down Expand Up @@ -332,7 +328,7 @@ def check_order_bounds_dups(values, max_value):

def get_site_row_col_indices(
row_sites: np.ndarray, col_sites: np.ndarray
) -> Tuple[List[int], List[int], List[int]]:
) -> tuple[list[int], list[int], list[int]]:
"""Co-iterate over the row and column sites, keeping a sorted union of
site values and an index into the unique list of sites for both the row
and column sites. This function produces a list of sites of interest and
Expand Down Expand Up @@ -448,8 +444,8 @@ def get_allele_samples(


def get_mutation_samples(
ts: tskit.TreeSequence, sites: List[int], sample_index_map: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, BitSet]:
ts: tskit.TreeSequence, sites: list[int], sample_index_map: np.ndarray
) -> tuple[np.ndarray, np.ndarray, BitSet]:
"""For a given set of sites, generate a BitSet of all samples posessing
each allelic state for each site. This includes the ancestral state, along
with any mutations contained in the site.
Expand Down Expand Up @@ -507,8 +503,8 @@ def get_mutation_samples(
return num_alleles, site_offsets, allele_samples


SummaryFunc = Callable[[int, np.ndarray, int, np.ndarray, Dict[str, Any]], None]
NormFunc = Callable[[int, np.ndarray, int, int, np.ndarray, Dict[str, Any]], None]
SummaryFunc = Callable[[int, np.ndarray, int, np.ndarray, dict[str, Any]], None]
NormFunc = Callable[[int, np.ndarray, int, int, np.ndarray, dict[str, Any]], None]


def compute_general_two_site_stat_result(
Expand All @@ -523,7 +519,7 @@ def compute_general_two_site_stat_result(
result_dim: int,
func: SummaryFunc,
norm_func: NormFunc,
params: Dict[str, Any],
params: dict[str, Any],
polarised: bool,
result: np.ndarray,
) -> None:
Expand Down Expand Up @@ -777,8 +773,8 @@ def two_branch_count_stat(


def sample_sets_to_bit_array(
ts: tskit.TreeSequence, sample_sets: Union[List[List[int]], List[np.ndarray]]
) -> Tuple[np.ndarray, np.ndarray, BitSet]:
ts: tskit.TreeSequence, sample_sets: list[list[int]] | list[np.ndarray]
) -> tuple[np.ndarray, np.ndarray, BitSet]:
"""Convert the list of sample ids to a bit array. This function takes
sample identifiers and maps them to their enumerated integer values, then
stores these values in a bit array. We produce a BitArray and a numpy
Expand Down Expand Up @@ -994,7 +990,7 @@ def r2_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
"""Summary function for the r2 statistic. We first compute the proportion of
AB, A, and B haplotypes, then we compute the r2 statistic, storing the outputs
Expand Down Expand Up @@ -1028,7 +1024,7 @@ def r2_ij_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
set_indexes = params["set_indexes"]
Expand Down Expand Up @@ -1062,7 +1058,7 @@ def D_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1082,7 +1078,7 @@ def D2_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1103,7 +1099,7 @@ def D_prime_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1128,7 +1124,7 @@ def r_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1152,7 +1148,7 @@ def Dz_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1174,7 +1170,7 @@ def pi2_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
) -> None:
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand Down Expand Up @@ -1205,7 +1201,7 @@ def pi2_unbiased_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
):
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1227,7 +1223,7 @@ def Dz_unbiased_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
):
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1253,7 +1249,7 @@ def D2_unbiased_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
):
sample_set_sizes = params["sample_set_sizes"]
for k in range(state_dim):
Expand All @@ -1275,7 +1271,7 @@ def D2_ij_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
):
sample_set_sizes = params["sample_set_sizes"]
set_indexes = params["set_indexes"]
Expand Down Expand Up @@ -1307,7 +1303,7 @@ def D2_ij_unbiased_summary_func(
state: np.ndarray,
result_dim: int,
result: np.ndarray,
params: Dict[str, Any],
params: dict[str, Any],
):
sample_set_sizes = params["sample_set_sizes"]
set_indexes = params["set_indexes"]
Expand Down Expand Up @@ -1831,8 +1827,8 @@ class TreeState:
# 0 1
# 1 0
# 1 1
edges_out: List[int] # list of edges removed during iteration
edges_in: List[int] # list of edges added during iteration
edges_out: list[int] # list of edges removed during iteration
edges_in: list[int] # list of edges added during iteration

def __init__(self, ts, sample_sets, num_sample_sets, sample_index_map):
self.pos = tsutil.TreeIndexes(ts)
Expand Down
5 changes: 2 additions & 3 deletions python/tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import random
import string
import struct
import typing

import msprime
import numpy as np
Expand Down Expand Up @@ -1718,7 +1717,7 @@ def __iter__(self):
class EdgeRange:
start: int
stop: int
order: typing.List
order: list


class TreeIndexes:
Expand Down Expand Up @@ -2164,7 +2163,7 @@ def metadata_map(tables):
return out


@functools.lru_cache(maxsize=None)
@functools.cache
def all_trees_ts(n):
"""
Generate a tree sequence that corresponds to the lexicographic listing
Expand Down
6 changes: 2 additions & 4 deletions python/tskit/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
import operator
import warnings
import xml.dom.minidom
from collections.abc import Mapping
from dataclasses import dataclass
from typing import List
from typing import Mapping
from typing import Union

import numpy as np

Expand Down Expand Up @@ -538,7 +536,7 @@ def clip_ts(ts, x_min, x_max, max_num_trees=None):
return ts, tree_status, offsets


def check_y_ticks(ticks: Union[List, Mapping, None]) -> Mapping:
def check_y_ticks(ticks: list | Mapping | None) -> Mapping:
"""
Later we might want to implement a tick locator function, such that e.g. ticks=5
selects ~5 nicely spaced tick locations (with sensible behaviour for log scales)
Expand Down
2 changes: 1 addition & 1 deletion python/tskit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
import pprint
import struct
import types
from collections.abc import Mapping
from itertools import islice
from typing import Any
from typing import Mapping

import jsonschema
import numpy as np
Expand Down
21 changes: 9 additions & 12 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
import operator
import warnings
from dataclasses import dataclass
from typing import Dict
from typing import Optional
from typing import Union

import numpy as np

Expand Down Expand Up @@ -84,7 +81,7 @@ class IndividualTableRow(util.Dataclass):
"""
See :attr:`Individual.parents`
"""
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Individual.metadata`
"""
Expand Down Expand Up @@ -124,7 +121,7 @@ class NodeTableRow(util.Dataclass):
"""
See :attr:`Node.individual`
"""
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Node.metadata`
"""
Expand Down Expand Up @@ -154,7 +151,7 @@ class EdgeTableRow(util.Dataclass):
"""
See :attr:`Edge.child`
"""
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Edge.metadata`
"""
Expand Down Expand Up @@ -192,7 +189,7 @@ class MigrationTableRow(util.Dataclass):
"""
See :attr:`Migration.time`
"""
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Migration.metadata`
"""
Expand All @@ -214,7 +211,7 @@ class SiteTableRow(util.Dataclass):
"""
See :attr:`Site.ancestral_state`
"""
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Site.metadata`
"""
Expand Down Expand Up @@ -244,7 +241,7 @@ class MutationTableRow(util.Dataclass):
"""
See :attr:`Mutation.parent`
"""
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Mutation.metadata`
"""
Expand Down Expand Up @@ -279,7 +276,7 @@ class PopulationTableRow(util.Dataclass):
"""

__slots__ = ["metadata"]
metadata: Optional[Union[bytes, dict]]
metadata: bytes | dict | None
"""
See :attr:`Population.metadata`
"""
Expand Down Expand Up @@ -3247,7 +3244,7 @@ def asdict(self, force_offset_64=False):
return self._ll_tables.asdict(force_offset_64)

@property
def table_name_map(self) -> Dict:
def table_name_map(self) -> dict:
"""
Returns a dictionary mapping table names to the corresponding
table instances. For example, the returned dictionary will contain the
Expand All @@ -3265,7 +3262,7 @@ def table_name_map(self) -> Dict:
}

@property
def name_map(self) -> Dict:
def name_map(self) -> dict:
# Deprecated in 0.4.1
warnings.warn(
"name_map is deprecated; use table_name_map instead",
Expand Down
Loading