From 86ea00b8e482c6f9801716d06b9601c4a89ab4db Mon Sep 17 00:00:00 2001 From: okiner-3 Date: Fri, 12 Jun 2026 15:33:07 +0900 Subject: [PATCH] auto-infer locations for predict_dte/pte/ldte/lpte (#68) Make the locations argument optional in predict_dte, predict_pte, predict_ldte, and predict_lpte. When omitted, locations are generated automatically from the observed outcomes using np.linspace with a bin count derived from np.histogram_bin_edges(outcomes, bins='auto'). For the interval-based methods (PTE/LPTE), the left endpoint is shifted just below outcomes.min() so the smallest observation falls inside the first interval. The actual locations array used is stored on estimator.last_locations for plotting and downstream use. --- dte_adj/base.py | 32 +++++++++++++--- dte_adj/local.py | 68 ++++++++++++++++++++++++++++------ dte_adj/util.py | 36 ++++++++++++++++++ tests/test_local_estimators.py | 55 +++++++++++++++++++++++++++ tests/test_simple_estimator.py | 55 +++++++++++++++++++++++++++ tests/test_utils.py | 39 ++++++++++++++++++- 6 files changed, 267 insertions(+), 18 deletions(-) diff --git a/dte_adj/base.py b/dte_adj/base.py index a3338e1..fc0a02f 100644 --- a/dte_adj/base.py +++ b/dte_adj/base.py @@ -4,6 +4,7 @@ from abc import ABC from tqdm.auto import tqdm import dte_adj +from dte_adj.util import _infer_default_locations class DistributionEstimatorBase(ABC): @@ -19,12 +20,13 @@ def __init__(self): self.covariates = None self.outcomes = None self.treatment_arms = None + self.last_locations = None def predict_dte( self, target_treatment_arm: int, control_treatment_arm: int, - locations: np.ndarray, + locations: Optional[np.ndarray] = None, alpha: float = 0.05, variance_type="moment", n_bootstrap=500, @@ -40,7 +42,11 @@ def predict_dte( Args: target_treatment_arm (int): The index of the treatment arm of the treatment group. control_treatment_arm (int): The index of the treatment arm of the control group. - locations (np.ndarray): Scalar values to be used for computing the cumulative distribution. + locations (np.ndarray, optional): Scalar values to be used for computing the cumulative + distribution. If None, evenly-spaced locations spanning the observed outcome range + are generated automatically. The number of points is determined from data size and + distribution via ``np.histogram_bin_edges(outcomes, bins='auto')``. The actual array + used is stored on ``self.last_locations``. alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are "moment", "simple", and "uniform". Defaults to "moment". @@ -80,6 +86,11 @@ def predict_dte( print(f"DTE shape: {dte.shape}") # Should match locations.shape print(f"Average DTE: {dte.mean():.3f}") """ + if locations is None: + locations = _infer_default_locations( + self.outcomes, for_intervals=False + ) + self.last_locations = locations return self._compute_dtes( target_treatment_arm, control_treatment_arm, @@ -94,7 +105,7 @@ def predict_pte( self, target_treatment_arm: int, control_treatment_arm: int, - locations: np.ndarray, + locations: Optional[np.ndarray] = None, alpha: float = 0.05, variance_type="moment", n_bootstrap=500, @@ -110,8 +121,14 @@ def predict_pte( Args: target_treatment_arm (int): The index of the treatment arm of the treatment group. control_treatment_arm (int): The index of the treatment arm of the control group. - locations (np.ndarray): Scalar values defining interval boundaries for probability computation. - For each interval (locations[i], locations[i+1]], the PTE is computed. + locations (np.ndarray, optional): Scalar values defining interval boundaries for + probability computation. For each interval (locations[i], locations[i+1]], the PTE + is computed. If None, boundaries spanning the observed outcome range are generated + automatically with the left endpoint placed just below ``outcomes.min()`` so that + minimum-valued samples fall inside the first interval. The number of boundaries is + determined from data size and distribution via + ``np.histogram_bin_edges(outcomes, bins='auto')``. The actual array used is stored + on ``self.last_locations``. alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. variance_type (str, optional): Variance type to be used to compute confidence intervals. Available values are "moment", "simple", and "uniform". Defaults to "moment". @@ -154,6 +171,11 @@ def predict_pte( print(f"PTE shape: {pte.shape}") # Should be (4,) for 4 intervals print(f"Interval effects: {pte}") """ + if locations is None: + locations = _infer_default_locations( + self.outcomes, for_intervals=True + ) + self.last_locations = locations return self._compute_ptes( target_treatment_arm, control_treatment_arm, diff --git a/dte_adj/local.py b/dte_adj/local.py index cb80457..5d1fdd2 100644 --- a/dte_adj/local.py +++ b/dte_adj/local.py @@ -1,12 +1,18 @@ from __future__ import annotations import numpy as np -from typing import Tuple +from typing import Optional, Tuple from dte_adj.stratified import ( SimpleStratifiedDistributionEstimator, AdjustedStratifiedDistributionEstimator, ) -from dte_adj.util import ArrayLike, compute_ldte, compute_lpte, _convert_to_ndarray +from dte_adj.util import ( + ArrayLike, + compute_ldte, + compute_lpte, + _convert_to_ndarray, + _infer_default_locations, +) class SimpleLocalDistributionEstimator(SimpleStratifiedDistributionEstimator): @@ -59,7 +65,7 @@ def predict_ldte( self, target_treatment_arm: int, control_treatment_arm: int, - locations: np.ndarray, + locations: Optional[np.ndarray] = None, alpha: float = 0.05, display_progress: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -73,7 +79,11 @@ def predict_ldte( Args: target_treatment_arm (int): The index of the treatment arm of the treatment group. control_treatment_arm (int): The index of the treatment arm of the control group. - locations (np.ndarray): Scalar values to be used for computing the cumulative distribution. + locations (np.ndarray, optional): Scalar values to be used for computing the cumulative + distribution. If None, evenly-spaced locations spanning the observed outcome range + are generated automatically. The number of points is determined from data size and + distribution via ``np.histogram_bin_edges(outcomes, bins='auto')``. The actual + array used is stored on ``self.last_locations``. alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. display_progress (bool, optional): Whether to display a progress bar. Defaults to True. @@ -113,6 +123,11 @@ def predict_ldte( print(f"LDTE shape: {ldte.shape}") # Should match locations.shape print(f"Average LDTE: {ldte.mean():.3f}") """ + if locations is None: + locations = _infer_default_locations( + self.outcomes, for_intervals=False + ) + self.last_locations = locations return compute_ldte( self, target_treatment_arm, @@ -126,7 +141,7 @@ def predict_lpte( self, target_treatment_arm: int, control_treatment_arm: int, - locations: np.ndarray, + locations: Optional[np.ndarray] = None, alpha: float = 0.05, display_progress: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -140,8 +155,13 @@ def predict_lpte( Args: target_treatment_arm (int): The index of the treatment arm of the treatment group. control_treatment_arm (int): The index of the treatment arm of the control group. - locations (np.ndarray): Scalar values defining interval boundaries for probability computation. - For each interval (locations[i], locations[i+1]], the LPTE is computed. + locations (np.ndarray, optional): Scalar values defining interval boundaries for + probability computation. For each interval (locations[i], locations[i+1]], the LPTE + is computed. If None, boundaries spanning the observed outcome range are generated + automatically with the left endpoint placed just below ``outcomes.min()``. The + number of boundaries is determined from data size and distribution via + ``np.histogram_bin_edges(outcomes, bins='auto')``. The actual array used is stored + on ``self.last_locations``. alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. display_progress (bool, optional): Whether to display a progress bar. Defaults to True. @@ -183,6 +203,11 @@ def predict_lpte( print(f"LPTE shape: {lpte.shape}") # Should be (4,) for 4 intervals print(f"Interval effects: {lpte}") """ + if locations is None: + locations = _infer_default_locations( + self.outcomes, for_intervals=True + ) + self.last_locations = locations return compute_lpte( self, target_treatment_arm, @@ -234,7 +259,7 @@ def predict_ldte( self, target_treatment_arm: int, control_treatment_arm: int, - locations: np.ndarray, + locations: Optional[np.ndarray] = None, alpha: float = 0.05, display_progress: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -247,7 +272,11 @@ def predict_ldte( Args: target_treatment_arm (int): The index of the treatment arm of the treatment group. control_treatment_arm (int): The index of the treatment arm of the control group. - locations (np.ndarray): Scalar values to be used for computing the cumulative distribution. + locations (np.ndarray, optional): Scalar values to be used for computing the cumulative + distribution. If None, evenly-spaced locations spanning the observed outcome range + are generated automatically. The number of points is determined from data size and + distribution via ``np.histogram_bin_edges(outcomes, bins='auto')``. The actual + array used is stored on ``self.last_locations``. alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. display_progress (bool, optional): Whether to display a progress bar. Defaults to True. @@ -289,6 +318,11 @@ def predict_ldte( print(f"Adjusted LDTE: {ldte.mean():.3f}") """ + if locations is None: + locations = _infer_default_locations( + self.outcomes, for_intervals=False + ) + self.last_locations = locations return compute_ldte( self, target_treatment_arm, @@ -302,7 +336,7 @@ def predict_lpte( self, target_treatment_arm: int, control_treatment_arm: int, - locations: np.ndarray, + locations: Optional[np.ndarray] = None, alpha: float = 0.05, display_progress: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -315,8 +349,13 @@ def predict_lpte( Args: target_treatment_arm (int): The index of the treatment arm of the treatment group. control_treatment_arm (int): The index of the treatment arm of the control group. - locations (np.ndarray): Scalar values defining interval boundaries for probability computation. - For each interval (locations[i], locations[i+1]], the LPTE is computed. + locations (np.ndarray, optional): Scalar values defining interval boundaries for + probability computation. For each interval (locations[i], locations[i+1]], the LPTE + is computed. If None, boundaries spanning the observed outcome range are generated + automatically with the left endpoint placed just below ``outcomes.min()``. The + number of boundaries is determined from data size and distribution via + ``np.histogram_bin_edges(outcomes, bins='auto')``. The actual array used is stored + on ``self.last_locations``. alpha (float, optional): Significance level of the confidence bound. Defaults to 0.05. display_progress (bool, optional): Whether to display a progress bar. Defaults to True. @@ -361,6 +400,11 @@ def predict_lpte( print(f"Adjusted LPTE: {lpte}") """ + if locations is None: + locations = _infer_default_locations( + self.outcomes, for_intervals=True + ) + self.last_locations = locations return compute_lpte( self, target_treatment_arm, diff --git a/dte_adj/util.py b/dte_adj/util.py index 3c5dd13..0b90803 100644 --- a/dte_adj/util.py +++ b/dte_adj/util.py @@ -32,6 +32,42 @@ def _convert_to_ndarray(data: ArrayLike) -> np.ndarray: return np.asarray(data) +def _infer_default_locations( + outcomes: np.ndarray, + for_intervals: bool = False, +) -> np.ndarray: + """Generate evenly-spaced default locations from observed outcomes. + + The number of points is determined from data size and distribution using + ``np.histogram_bin_edges(outcomes, bins='auto')`` (which combines the + Sturges and Freedman-Diaconis rules). + + Args: + outcomes (np.ndarray): Observed outcomes used to determine the range. + for_intervals (bool, optional): If True, the left endpoint is placed + slightly below ``outcomes.min()`` so that observations equal to the + minimum fall inside the first interval ``(loc[0], loc[1]]``. Set + this for PTE/LPTE estimation. Defaults to False. + + Returns: + np.ndarray: Evenly-spaced locations array. + """ + n_locations = len(np.histogram_bin_edges(outcomes, bins="auto")) + + y_min = float(outcomes.min()) + y_max = float(outcomes.max()) + + if for_intervals: + # Place the left endpoint strictly below y_min so that the smallest + # observation falls inside the first interval (loc[0], loc[1]]. The + # offset scales with the magnitude of the data so that ``y_min - eps`` + # is representable even when the outcome range is zero. + scale = max(y_max - y_min, abs(y_min), abs(y_max), 1.0) + eps = scale * 1e-9 + return np.linspace(y_min - eps, y_max, n_locations) + return np.linspace(y_min, y_max, n_locations) + + def compute_confidence_intervals( vec_y: np.ndarray, vec_d: np.ndarray, diff --git a/tests/test_local_estimators.py b/tests/test_local_estimators.py index 58c7233..5d5d96e 100644 --- a/tests/test_local_estimators.py +++ b/tests/test_local_estimators.py @@ -265,6 +265,61 @@ def test_simple_local_estimator_predict_lpte(self): self.assertTrue(np.all(lower_bound <= beta)) self.assertTrue(np.all(beta <= upper_bound)) + def test_simple_local_estimator_predict_ldte_without_locations(self): + """LDTE auto-infers locations from outcomes when none are passed.""" + estimator = SimpleLocalDistributionEstimator() + estimator.fit( + self.covariates, + self.treatment_arms, + self.treatment_indicator, + self.outcomes, + self.strata, + ) + + beta, lower, upper = estimator.predict_ldte( + target_treatment_arm=1, + control_treatment_arm=0, + alpha=0.05, + ) + + n = estimator.last_locations.shape[0] + self.assertGreater(n, 1) + self.assertEqual(beta.shape, (n,)) + self.assertEqual(lower.shape, (n,)) + self.assertEqual(upper.shape, (n,)) + self.assertAlmostEqual( + estimator.last_locations[0], float(self.outcomes.min()) + ) + self.assertAlmostEqual( + estimator.last_locations[-1], float(self.outcomes.max()) + ) + + def test_simple_local_estimator_predict_lpte_without_locations(self): + """LPTE auto-infers interval boundaries, with left endpoint below min.""" + estimator = SimpleLocalDistributionEstimator() + estimator.fit( + self.covariates, + self.treatment_arms, + self.treatment_indicator, + self.outcomes, + self.strata, + ) + + beta, lower, upper = estimator.predict_lpte( + target_treatment_arm=1, + control_treatment_arm=0, + alpha=0.05, + ) + + n = estimator.last_locations.shape[0] + # LPTE output length is len(locations) - 1 + self.assertEqual(beta.shape, (n - 1,)) + self.assertEqual(lower.shape, (n - 1,)) + self.assertEqual(upper.shape, (n - 1,)) + self.assertLess( + estimator.last_locations[0], float(self.outcomes.min()) + ) + def test_adjusted_local_estimator_predict_lpte(self): """Test that AdjustedLocalDistributionEstimator can predict LPTE.""" base_model = LogisticRegression(random_state=42) diff --git a/tests/test_simple_estimator.py b/tests/test_simple_estimator.py index 928ef3e..abcd862 100644 --- a/tests/test_simple_estimator.py +++ b/tests/test_simple_estimator.py @@ -193,6 +193,61 @@ def test_compute_cumulative_distribution(self): ) +class TestAutoLocations(unittest.TestCase): + """Test auto-inference of locations when not provided to predict_dte/predict_pte.""" + + def setUp(self): + np.random.seed(0) + n = 200 + self.covariates = np.random.randn(n, 3) + self.treatment_arms = np.random.binomial(1, 0.5, n) + self.outcomes = np.random.randn(n) + 0.5 * self.treatment_arms + self.estimator = SimpleDistributionEstimator().fit( + self.covariates, self.treatment_arms, self.outcomes + ) + + def test_predict_dte_without_locations(self): + dte, lower, upper = self.estimator.predict_dte( + 1, 0, variance_type="simple", display_progress=False + ) + self.assertIsNotNone(self.estimator.last_locations) + self.assertEqual(dte.shape, self.estimator.last_locations.shape) + self.assertEqual(lower.shape, dte.shape) + self.assertEqual(upper.shape, dte.shape) + self.assertAlmostEqual( + self.estimator.last_locations[0], float(self.outcomes.min()) + ) + self.assertAlmostEqual( + self.estimator.last_locations[-1], float(self.outcomes.max()) + ) + + def test_predict_pte_without_locations(self): + pte, lower, upper = self.estimator.predict_pte( + 1, 0, variance_type="simple", display_progress=False + ) + # PTE returns len(locations)-1 outputs + self.assertEqual(pte.shape, (self.estimator.last_locations.shape[0] - 1,)) + self.assertEqual(lower.shape, pte.shape) + self.assertEqual(upper.shape, pte.shape) + # Left endpoint should be strictly below outcomes.min() so the smallest + # observation is captured by the first interval. + self.assertLess( + self.estimator.last_locations[0], float(self.outcomes.min()) + ) + + def test_explicit_locations_stored_on_last_locations(self): + locations = np.linspace(-2, 2, 7) + dte, _, _ = self.estimator.predict_dte( + 1, + 0, + locations=locations, + variance_type="simple", + display_progress=False, + ) + self.assertEqual(dte.shape, (7,)) + np.testing.assert_array_equal(self.estimator.last_locations, locations) + + class TestE2E(unittest.TestCase): def test_e2e(self): # Arrange diff --git a/tests/test_utils.py b/tests/test_utils.py index 7dafb3a..aace50f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd import polars as pl -from dte_adj.util import _convert_to_ndarray +from dte_adj.util import _convert_to_ndarray, _infer_default_locations class TestConvertToNdarray(unittest.TestCase): @@ -55,3 +55,40 @@ def test_tuple(self): result = _convert_to_ndarray(data) self.assertIsInstance(result, np.ndarray) np.testing.assert_array_equal(result, np.array([1, 2, 3])) + + +class TestInferDefaultLocations(unittest.TestCase): + """Test _infer_default_locations for generating default DTE/PTE locations.""" + + def test_evenly_spaced_spanning_outcome_range(self): + rng = np.random.default_rng(0) + outcomes = rng.normal(size=500) + result = _infer_default_locations(outcomes) + self.assertAlmostEqual(result[0], float(outcomes.min())) + self.assertAlmostEqual(result[-1], float(outcomes.max())) + diffs = np.diff(result) + np.testing.assert_allclose(diffs, diffs[0]) + + def test_for_intervals_left_endpoint_below_min(self): + outcomes = np.linspace(2.0, 5.0, 50) + result = _infer_default_locations(outcomes, for_intervals=True) + self.assertLess(result[0], outcomes.min()) + self.assertAlmostEqual(result[-1], 5.0) + + def test_auto_n_locations_matches_histogram_bin_edges(self): + rng = np.random.default_rng(0) + outcomes = rng.normal(size=1000) + expected_n = len(np.histogram_bin_edges(outcomes, bins="auto")) + result = _infer_default_locations(outcomes) + self.assertEqual(result.shape, (expected_n,)) + + def test_auto_scales_with_sample_size(self): + rng = np.random.default_rng(1) + small = _infer_default_locations(rng.normal(size=100)) + large = _infer_default_locations(rng.normal(size=10000)) + self.assertGreater(large.shape[0], small.shape[0]) + + def test_for_intervals_constant_outcomes(self): + outcomes = np.full(50, 3.0) + result = _infer_default_locations(outcomes, for_intervals=True) + self.assertLess(result[0], 3.0)