Quantifying aleatoric uncertainty of the treatment effect with an AU-learner based on conditional normalizing flows (CNFs)
The project is built with the following Python libraries:
- Pyro - deep learning and probabilistic models (MDNs, NFs)
- Hydra - simplified command-line argument management
- MlFlow - experiment tracking
First, create a virtual environment and install all the requirements:

```console
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
```

To start an experiments server, run:

```console
mlflow server --port=5000
```
To access the MlFlow web UI with all the experiments, connect via ssh:

```console
ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>
```

Then, open http://localhost:5000 in your local browser.
Before running semi-synthetic experiments, place datasets in the corresponding folders:
- IHDP100 dataset: `ihdp_npci_1-100.train.npz` and `ihdp_npci_1-100.test.npz` to `data/ihdp100/`
We use multi-country data from Banholzer et al. (2021). Access data here.
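As a quick sanity check after placing the IHDP files, they can be inspected as standard NumPy `.npz` archives. The sketch below builds a tiny stand-in archive so it is self-contained; the key names (`x` for covariates, `t` for treatment, `yf` for factual outcome) and the shapes (747 units, 25 covariates, 100 replications) follow the common IHDP release format and are an assumption here — inspect `data.files` on your actual copy first:

```python
import numpy as np

# Build a tiny stand-in archive with the key names commonly used in the
# IHDP releases ('x' covariates, 't' treatment, 'yf' factual outcome).
# These names/shapes are an assumption; check `data.files` on your file.
np.savez(
    "ihdp_demo.npz",
    x=np.zeros((747, 25, 100)),  # units x covariates x replications
    t=np.zeros((747, 100)),
    yf=np.zeros((747, 100)),
)

# For the real data, load e.g. data/ihdp100/ihdp_npci_1-100.train.npz instead.
data = np.load("ihdp_demo.npz")
print(sorted(data.files))   # lists the arrays stored in the archive
print(data["x"].shape)      # covariate tensor shape
```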
The main training script is universal across methods and datasets. For details on the mandatory arguments, see the main configuration file `config/config.yaml` and the other files in the `config/` folder.
The generic script with logging and a fixed random seed is the following:

```console
PYTHONPATH=. python3 runnables/train.py +dataset=<dataset> +model=<model> exp.seed=10
```

One needs to choose a model and then fill in the specific hyperparameters (they are left blank in the configs):
- AU-CNFs (= AU-learner with CNFs, this paper): `+model=dr_cnfs`, with two variants:
  - CRPS: `model.target_mode=cdf`
  - $W_2^2$: `model.target_mode=icdf`
- CA-CNFs (= CA-learner with CNFs, this paper): `+model=ca_cnfs`, with two variants:
  - CRPS: `model.target_mode=cdf`
  - $W_2^2$: `model.target_mode=icdf`
- IPTW-CNF (= IPTW-learner with CNF, this paper): `+model=iptw_plugin_cnfs`
- Conditional Normalizing Flows (CNF, plug-in learner): `+model=plugin_cnfs`
- Distributional Kernel Mean Embeddings (DKME): `+model=plugin_dkme`
The best hyperparameters are already saved for each model and dataset; one can access them via `+model/<dataset>_hparams=<model>` or `+model/<dataset>_hparams/<model>=<dataset_param>`. Hyperparameters for the three variants of AU-CNFs, CA-CNFs, IPTW-CNF, and CNF are the same: `+model/<dataset>_hparams=plugin_cnfs`.
To perform manual hyperparameter tuning, use the flag `model.tune_hparams=True`, and then see `model.hparams_grid`.
One needs to specify a dataset / dataset generator (and some additional parameters, e.g. the train size for the synthetic data: `dataset.n_samples_train=1000`):
- Synthetic data (adapted from https://arxiv.org/abs/1810.02894): `+dataset=sine`, with 3 settings:
  - Normal: `dataset.mode=normal`
  - Multi-modal: `dataset.mode=multimodal`
  - Exponential: `dataset.mode=exp`
- IHDP dataset: `+dataset=ihdp`
- HC-MNIST dataset: `+dataset=hcmnist`
Example of running an experiment with our AU-CNFs (CRPS) on synthetic data in the normal setting with three random seeds:

```console
PYTHONPATH=. python3 runnables/train.py -m +dataset=sine +model=dr_cnfs +model/sine_hparams/plugin_cnfs_normal=\'100\' model.target_mode=cdf model.correction_coeff=0.25 exp.seed=10,101,1010
```

Example of tuning the hyperparameters of the CNF on the HC-MNIST dataset:

```console
PYTHONPATH=. python3 runnables/train.py -m +dataset=hcmnist +model=plugin_cnfs +model/hcmnist_hparams=plugin_cnfs exp.seed=10 model.tune_hparams=True
```

Project based on the cookiecutter data science project template. #cookiecutterdatascience