Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions causalml/inference/meta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ def estimate_ate(
):
pass

def bootstrap(self, X, treatment, y, p=None, size=10000):
def bootstrap(self, X, treatment, y, p=None, size=10000, rng=None):
"""Runs a single bootstrap. Fits on bootstrapped sample, then predicts on whole population."""
idxs = np.random.choice(np.arange(0, X.shape[0]), size=size)
if rng is not None:
idxs = rng.choice(np.arange(0, X.shape[0]), size=size)
else:
idxs = np.random.choice(np.arange(0, X.shape[0]), size=size)
X_b = X[idxs]

if p is not None:
Expand Down
50 changes: 48 additions & 2 deletions causalml/inference/meta/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,24 @@ def fit(self, X, treatment, y, p=None, seed=None):
)
self.models_tau[group][ifold].fit(X_filt, dr)

def bootstrap(self, X, treatment, y, p=None, size=10000, rng=None, seed=None):
    """Run a single bootstrap iteration.

    Fits on a resampled (with replacement) subset of the data, then
    predicts treatment effects for the whole population.

    Args:
        X (np.matrix or np.array): a feature matrix
        treatment (np.array): a treatment vector
        y (np.array): an outcome vector
        p (dict, optional): mapping of treatment group to propensity scores
        size (int, optional): the bootstrap sample size
        rng (np.random.Generator, optional): generator used for resampling;
            when None, falls back to the global NumPy random state
        seed (int, optional): deterministic seed forwarded to the
            cross-fitting logic in fit()

    Returns:
        (np.ndarray): predicted treatment effects for the full population
    """
    n_rows = X.shape[0]
    # Prefer the caller-supplied generator so bootstrap draws are reproducible.
    sampler = rng.choice if rng is not None else np.random.choice
    sample_idx = sampler(np.arange(0, n_rows), size=size)

    p_resampled = None
    if p is not None:
        p_resampled = {group: scores[sample_idx] for group, scores in p.items()}

    self.fit(
        X=X[sample_idx],
        treatment=treatment[sample_idx],
        y=y[sample_idx],
        p=p_resampled,
        seed=seed,
    )
    return self.predict(X=X, p=p)

def predict(
self, X, treatment=None, y=None, p=None, return_components=False, verbose=True
):
Expand Down Expand Up @@ -312,10 +330,25 @@ def fit_predict(
te_bootstraps = np.zeros(
shape=(X.shape[0], self.t_groups.shape[0], n_bootstraps)
)
# seed controls both bootstrap resampling and cross-fit randomness.
rng = np.random.default_rng(seed) if seed is not None else None

logger.info("Bootstrap Confidence Intervals")
for i in tqdm(range(n_bootstraps)):
te_b = self.bootstrap(X, treatment, y, p, size=bootstrap_size)
bootstrap_seed = (
int(rng.integers(np.iinfo(np.int32).max))
if rng is not None
else None
)
te_b = self.bootstrap(
X,
treatment,
y,
p,
size=bootstrap_size,
rng=rng,
seed=bootstrap_seed,
)
te_bootstraps[:, :, i] = te_b

te_lower = np.percentile(te_bootstraps, (self.ate_alpha / 2) * 100, axis=2)
Expand Down Expand Up @@ -428,10 +461,23 @@ def estimate_ate(

logger.info("Bootstrap Confidence Intervals for ATE")
ate_bootstraps = np.zeros(shape=(self.t_groups.shape[0], n_bootstraps))
# seed controls both bootstrap resampling and cross-fit randomness.
rng = np.random.default_rng(seed) if seed is not None else None

for n in tqdm(range(n_bootstraps)):
bootstrap_seed = (
int(rng.integers(np.iinfo(np.int32).max))
if rng is not None
else None
)
cate_b = self.bootstrap(
X, treatment, y, p, size=bootstrap_size, seed=seed
X,
treatment,
y,
p,
size=bootstrap_size,
rng=rng,
seed=bootstrap_seed,
)
ate_bootstraps[:, n] = cate_b.mean(axis=0)

Expand Down
101 changes: 101 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,107 @@ def test_BaseDRLearner(generate_regression_data):
assert auuc["cate_p"] > 0.5


def test_BaseDRLearner_estimate_ate_bootstrap(generate_regression_data):
    """Regression test for issue #857: estimate_ate with bootstrap_ci=True
    raised TypeError due to stray seed argument passed to bootstrap()."""
    y, X, treatment, tau, b, e = generate_regression_data()

    # Shared keyword arguments for every estimate_ate / fit_predict call below.
    common = dict(
        X=X,
        treatment=treatment,
        y=y,
        p=e,
        n_bootstraps=10,
        bootstrap_size=200,
    )

    def new_learner():
        # Fresh learner per call so each run starts from identical state.
        return BaseDRRegressor(learner=LinearRegression(), control_name=0)

    # This call raised TypeError before the fix.
    ate, lb, ub = new_learner().estimate_ate(
        bootstrap_ci=True, seed=RANDOM_SEED, **common
    )
    for arr in (ate, lb, ub):
        assert np.all(np.isfinite(arr))

    # The same seed must reproduce the bootstrap CI bounds exactly.
    _, lb_repeat, ub_repeat = new_learner().estimate_ate(
        bootstrap_ci=True, seed=RANDOM_SEED, **common
    )
    np.testing.assert_array_equal(lb, lb_repeat)
    np.testing.assert_array_equal(ub, ub_repeat)

    # fit_predict() should also honor seed for bootstrap reproducibility.
    _, te_lb1, te_ub1 = new_learner().fit_predict(
        return_ci=True, seed=RANDOM_SEED, **common
    )
    _, te_lb2, te_ub2 = new_learner().fit_predict(
        return_ci=True, seed=RANDOM_SEED, **common
    )
    np.testing.assert_array_equal(te_lb1, te_lb2)
    np.testing.assert_array_equal(te_ub1, te_ub2)

    # seed=None must still yield valid (finite) estimates.
    ate_unseeded, lb_unseeded, ub_unseeded = new_learner().estimate_ate(
        bootstrap_ci=True, **common
    )
    for arr in (ate_unseeded, lb_unseeded, ub_unseeded):
        assert np.all(np.isfinite(arr))

    # A seeded bootstrap must not leak into the global NumPy RNG state.
    np.random.seed(99)
    _ = np.random.random()
    state_before = np.random.get_state()
    new_learner().estimate_ate(bootstrap_ci=True, seed=RANDOM_SEED, **common)
    state_after = np.random.get_state()
    assert state_before[0] == state_after[0]
    np.testing.assert_array_equal(state_before[1], state_after[1])
    assert state_before[2:] == state_after[2:]


def test_BaseDRClassifier(generate_classification_data):
np.random.seed(RANDOM_SEED)

Expand Down