diff --git a/.gitignore b/.gitignore index 2d446e8a..584da151 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ uv.lock # Claude Code artifacts CLAUDE.md -.claude/ \ No newline at end of file +.claude/.worktrees/ +.worktrees/ +.worktrees/ diff --git a/causalml/metrics/sensitivity.py b/causalml/metrics/sensitivity.py index 3bd186d2..c8f90092 100644 --- a/causalml/metrics/sensitivity.py +++ b/causalml/metrics/sensitivity.py @@ -278,11 +278,10 @@ def sensitivity_estimate(self): Returns: (pd.DataFrame): a summary dataframe """ - num_rows = self.df.shape[0] - X = self.df[self.inference_features].values p = self.df[self.p_col].values - treatment_new = np.random.randint(2, size=num_rows) + treatment = self.df[self.treatment_col].values + treatment_new = np.random.permutation(treatment) y = self.df[self.outcome_col].values ate_new, ate_new_lower, ate_new_upper = self.get_ate_ci(X, p, treatment_new, y) diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py index 7c0b0248..4062758e 100644 --- a/tests/test_sensitivity.py +++ b/tests/test_sensitivity.py @@ -104,6 +104,34 @@ def test_SensitivityPlaceboTreatment(): print(sens_summary) +def test_SensitivityPlaceboTreatment_string_labels(): + y, X, treatment, tau, b, e = synthetic_data( + mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 + ) + + # Convert binary treatment to string labels + treatment_str = np.where(treatment == 1, "treatment1", "control") + + INFERENCE_FEATURES = ["feature_" + str(i) for i in range(NUM_FEATURES)] + df = pd.DataFrame(X, columns=INFERENCE_FEATURES) + df[TREATMENT_COL] = treatment_str + df[OUTCOME_COL] = y + df[SCORE_COL] = e + + learner = BaseXLearner(LinearRegression(), control_name="control") + sens = SensitivityPlaceboTreatment( + df=df, + inference_features=INFERENCE_FEATURES, + p_col=SCORE_COL, + treatment_col=TREATMENT_COL, + outcome_col=OUTCOME_COL, + learner=learner, + ) + + sens_summary = sens.summary(method="Placebo Treatment") + print(sens_summary) + + def test_SensitivityRandomCause(): y, X, treatment, tau, b, e = synthetic_data( mode=1, n=100000, p=NUM_FEATURES, sigma=1.0