Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions causalml/inference/meta/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,14 @@ def predict(
for group in self.t_groups:
model = self.models[group]

# set the treatment column to zero (the control group)
X_new = np.hstack((np.zeros((X.shape[0], 1)), X))
yhat_cs[group] = model.predict(X_new)
# Build separate arrays for control and treatment to avoid in-place
# mutation, which fails when learners like CatBoost set the
# writeable flag to False on arrays passed to predict().
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
yhat_cs[group] = model.predict(X_new_c)

# set the treatment column to one (the treatment group)
X_new[:, 0] = 1
yhat_ts[group] = model.predict(X_new)
X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))
yhat_ts[group] = model.predict(X_new_t)

if (y is not None) and (treatment is not None) and verbose:
mask = (treatment == group) | (treatment == self.control_name)
Expand Down Expand Up @@ -346,13 +347,14 @@ def predict(
for group in self.t_groups:
model = self.models[group]

# set the treatment column to zero (the control group)
X_new = np.hstack((np.zeros((X.shape[0], 1)), X))
yhat_cs[group] = model.predict_proba(X_new)[:, 1]
# Build separate arrays for control and treatment to avoid in-place
# mutation, which fails when learners like CatBoost set the
# writeable flag to False on arrays passed to predict().
X_new_c = np.hstack((np.zeros((X.shape[0], 1)), X))
yhat_cs[group] = model.predict_proba(X_new_c)[:, 1]

# set the treatment column to one (the treatment group)
X_new[:, 0] = 1
yhat_ts[group] = model.predict_proba(X_new)[:, 1]
X_new_t = np.hstack((np.ones((X.shape[0], 1)), X))
yhat_ts[group] = model.predict_proba(X_new_t)[:, 1]

if y is not None and (treatment is not None) and verbose:
mask = (treatment == group) | (treatment == self.control_name)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@
from .const import RANDOM_SEED, N_SAMPLE, ERROR_THRESHOLD, CONTROL_NAME, CONVERSION


class ReadOnlyLinearRegression:
"""Minimal regressor that marks input arrays read-only like CatBoost."""

def __init__(self):
self.model = LinearRegression()

def fit(self, X, y):
self.model.fit(X, y)
X.flags.writeable = False
return self

def predict(self, X):
result = self.model.predict(X)
X.flags.writeable = False
return result


def test_synthetic_data():
y, X, treatment, tau, b, e = synthetic_data(mode=1, n=N_SAMPLE, p=8, sigma=0.1)

Expand Down Expand Up @@ -97,6 +114,21 @@ def test_BaseSLearner(generate_regression_data):
assert (ate_p_pt == ate_p) and (lb_pt == lb) and (ub_pt == ub)


def test_BaseSLearner_predict_with_readonly_arrays(generate_regression_data):
y, X, treatment, _, _, _ = generate_regression_data()
X_readonly = np.array(X, copy=True)
X_readonly.flags.writeable = False

learner = BaseSLearner(learner=ReadOnlyLinearRegression())

# Exercise both fit() and predict() with read-only array behavior.
learner.fit(X=X_readonly, treatment=treatment, y=y)
cate = learner.predict(X=X_readonly)

assert cate.shape == (X.shape[0], 1)
assert not X_readonly.flags.writeable


def test_BaseSRegressor(generate_regression_data):
y, X, treatment, tau, b, e = generate_regression_data()

Expand Down