causallib.contrib.adversarial_balancing.adversarial_balancing module
- class causallib.contrib.adversarial_balancing.adversarial_balancing.AdversarialBalancing(learner, iterations=20, lr=0.5, decay=1, loss_type='01', use_stabilized=True, verbose=False, *args, **kwargs)
Bases: causallib.estimation.base_weight.WeightEstimator, causallib.estimation.base_estimator.PopulationOutcomeEstimator
Adversarial Balancing finds sample weights such that the weighted population under any treatment a looks similar (distribution-wise) to the true population. Borrowing from GANs, the main idea is that, for each treatment a, the algorithm finds weights such that a specified classifier cannot distinguish between the entire population and the weighted population under treatment a.
At each step, the weights are updated using the gradient of the exponential loss function, and the classifier is re-trained. For a given classifier family, an optimal solution is a set of weights that maximizes the minimal error of the classifiers in this family.
- For more details about the algorithm see:
Adversarial Balancing for Causal Inference by Ozery-Flato and Thodoroff et al. https://arxiv.org/abs/1810.07406
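As a rough illustration of the weight update described above, the following sketch applies one multiplicative step to a small weight vector. It is not causallib's implementation: the function name update_weights, the exponentiated form of the update, and the renormalization are assumptions made for illustration. They are chosen so that samples the discriminator misclassifies gain weight, which pushes toward maximizing the classifier's error.

```python
import numpy as np


def update_weights(weights, sample_loss, lr):
    """One illustrative multiplicative update (hypothetical helper).

    Samples with a higher discriminator loss are up-weighted, then the
    weights are renormalized to sum to 1.
    """
    updated = weights * np.exp(lr * sample_loss)
    return updated / updated.sum()


# Four samples from one treatment group, starting from uniform weights.
weights = np.full(4, 0.25)

# Hypothetical per-sample zero-one losses of the current discriminator:
# 1.0 means the discriminator misclassified that sample.
zero_one_loss = np.array([0.0, 0.0, 1.0, 1.0])

new_weights = update_weights(weights, zero_one_loss, lr=0.5)
print(new_weights)
```

In the full procedure, a step like this would alternate with re-training the classifier on the re-weighted data for the configured number of iterations.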
- Parameters
learner – An initialized classifier object implementing fit and predict (scikit-learn compatible). It is used to discriminate between the population under treatment a and the entire population. A separate classifier can be selected for each treatment value, either by providing an initialized scikit-learn SearchCV model (GridSearchCV or RandomizedSearchCV) or by providing a list of classifiers. Given a list of classifiers, cross-validation is used to select the best-performing classifier for each treatment value. See the select_classifier module.
iterations (int) – The number of iterations to adjust the weights of each sample
lr (float) – Learning rate used to update the weights
decay (float) – Parameter to decay the learning rate through the iterations
loss_type (str) – Use ‘01’ for zero-one loss; otherwise cross-entropy is used (in which case the provided learner must also implement the predict_proba method).
use_stabilized (bool) – Whether to re-weigh the learned weights with the prevalence of the treatment. Note: Adversarial balancing already has an inherent component that weights by treatment prevalence. Setting to False will “de-stabilize” the weights after they are calculated.
verbose (bool) – Whether to print out statistics to console during training.
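To make the two loss_type options concrete, the sketch below computes per-sample zero-one and cross-entropy losses from hypothetical classifier outputs. The arrays stand in for the results of predict and predict_proba; how causallib computes and aggregates these losses internally is not shown here and is not assumed.

```python
import numpy as np

# True labels and hypothetical classifier outputs for four samples.
y_true = np.array([1, 1, 0, 0])
proba_of_one = np.array([0.9, 0.4, 0.2, 0.7])  # stand-in for predict_proba(X)[:, 1]
y_pred = (proba_of_one >= 0.5).astype(int)     # stand-in for predict(X)

# Zero-one loss: 1 for a wrong prediction, 0 for a correct one.
# Only hard predictions are needed.
zero_one = (y_pred != y_true).astype(float)

# Cross-entropy: negative log of the probability assigned to the true label.
# This requires predicted probabilities.
proba_of_truth = np.where(y_true == 1, proba_of_one, 1.0 - proba_of_one)
cross_entropy = -np.log(proba_of_truth)

print(zero_one)
print(cross_entropy)
```

The zero-one loss only records whether a sample was misclassified, while cross-entropy also reflects how confident the classifier was, which is why it needs predict_proba.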
- iterative_models_
np.ndarray of size (n_treatment_values, iterations) holding all the models created during the training process.
- iterative_normalizing_consts_
np.ndarray of size (n_treatment_values, iterations) holding all the normalizing constants calculated during the training process.
- discriminator_loss_
np.ndarray of size (n_treatment_values, iterations) holding the loss of the learner throughout the training process.
- treatments_frequency_
If use_stabilized=True, the proportions of the treatment values.
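The relation between stabilized and “de-stabilized” weights can be sketched as follows, assuming the usual definition in which stabilized weights equal unstabilized weights multiplied by the treatment prevalence Pr[A=a]. The weight values are made up, and the exact arithmetic causallib performs is an assumption here.

```python
import numpy as np

# Treatment assignment and hypothetical stabilized weights for four samples.
a = np.array([0, 0, 0, 1])
stabilized = np.array([0.9, 1.2, 0.9, 1.0])

# Proportion of each treatment value in the data.
values, counts = np.unique(a, return_counts=True)
frequency = dict(zip(values.tolist(), (counts / counts.sum()).tolist()))

# De-stabilize: divide each weight by the prevalence of that sample's treatment.
prevalence_per_sample = np.array([frequency[v] for v in a.tolist()])
destabilized = stabilized / prevalence_per_sample

print(frequency)
print(destabilized)
```

Dividing by the prevalence inflates the weights of the rarer treatment group the most, which is the variance that stabilization is meant to limit.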
- compute_weight_matrix(X, a, use_stabilized=None, **kwargs)
Computes individual weights across all possible treatment values: f(Pr[A=a_j | X_i]) for every individual i and treatment j.
- Parameters
X (pd.DataFrame) – Covariate matrix of size (num_subjects, num_features).
a (pd.Series) – Treatment assignment of size (num_subjects,).
use_stabilized (bool) – Whether to re-weigh the learned weights with the prevalence of the treatment. This overrides the use_stabilized parameter provided at initialization. See Also: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4351790/#S6title
**kwargs –
- Returns
A matrix of size (num_subjects, num_treatments) with a weight for every individual and every treatment.
- Return type
pd.DataFrame
- compute_weights(X, a, treatment_values=None, use_stabilized=None, **kwargs)
Computes each individual’s weight given that individual’s treatment assignment: f(Pr[A=a_i | X_i]) for each individual i.
- Parameters
X (pd.DataFrame) – Covariate matrix of size (num_subjects, num_features).
a (pd.Series) – Treatment assignment of size (num_subjects,).
treatment_values (Any | None) – The treatment value(s) for which to extract weights (i.e., the treatment value the weights should be calculated for). If not specified, the weights are chosen according to each individual’s actual treatment assignment.
use_stabilized (bool) – Whether to re-weigh the learned weights with the prevalence of the treatment. This overrides the use_stabilized parameter provided at initialization. See Also: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4351790/#S6title
**kwargs –
- Returns
A vector of size (num_subjects,) with a weight for each individual
- Return type
pd.Series
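The descriptions of compute_weight_matrix and compute_weights suggest that the weight vector corresponds to one entry per row of the weight matrix. The sketch below illustrates that relationship with plain NumPy arrays and made-up numbers. It is an assumption about how the two outputs relate, not a reproduction of the library's code, and the real methods return pandas objects.

```python
import numpy as np

# Hypothetical weight matrix: one row per subject, one column per treatment value.
treatment_values = np.array([0, 1])
weight_matrix = np.array([
    [1.1, 3.0],
    [1.4, 2.2],
    [2.5, 1.3],
])

# Observed treatment assignment of each subject.
a = np.array([0, 1, 1])

# Default behavior: for each subject, take the weight in the column matching
# that subject's own treatment (treatment_values must be sorted for searchsorted).
column_of_own_treatment = np.searchsorted(treatment_values, a)
weights = weight_matrix[np.arange(len(a)), column_of_own_treatment]

# With a requested treatment value, take that value's column for every subject.
weights_for_treatment_1 = weight_matrix[:, 1]

print(weights)
print(weights_for_treatment_1)
```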
- fit(X, a, y=None, w_init=None, **select_kwargs)
Trains an Adversarial Balancing model.
- Parameters
X (pd.DataFrame) – Covariate matrix of size (num_subjects, num_features).
a (pd.Series) – Treatment assignment of size (num_subjects,).
y – IGNORED.
w_init (pd.Series) – Initial sample weights. If not provided, assumes uniform.
select_kwargs – Keyword arguments to pass to select_classifier. Relevant only if the model was initialized with a list of classifiers as learner.
- Returns
AdversarialBalancing