causallib.contrib.adversarial_balancing.AdversarialBalancing

class AdversarialBalancing(learner, iterations=20, lr=0.5, decay=1, loss_type='01', use_stabilized=True, verbose=False, *args, **kwargs)

Adversarial Balancing finds sample weights such that the weighted population under any treatment value a looks similar (distribution-wise) to the true population. Borrowing from GANs, the main idea is that, for each treatment value a, the algorithm finds weights such that a specified classifier cannot distinguish between the entire population and the weighted population under treatment a.

At each step, we update the weights using the gradient of the exponential loss function and re-train the classifier. For a given classifier family, an optimal solution is a set of weights that maximizes the minimal error of the classifiers in this family.
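The alternating loop described above (fit a discriminator, then reweight) can be sketched in plain Python. This is a simplified illustration under stated assumptions, not causallib's implementation: a hypothetical one-feature decision stump stands in for learner, the weight update is a boosting-style multiplicative step in the spirit of the exponential loss, and the `lr *= decay` schedule is an assumption. The names `fit_stump` and `adversarial_weights` are hypothetical.

```python
import math


def fit_stump(xs, labels, weights):
    """Fit a 1-D decision stump that minimizes the weighted zero-one error.

    Returns (threshold, polarity); the stump predicts +1 ("treated group")
    when polarity * (x - threshold) > 0 and -1 ("whole population") otherwise.
    """
    best = None
    for threshold in [min(xs) - 1.0] + sorted(set(xs)):
        for polarity in (1, -1):
            error = sum(
                w for x, y, w in zip(xs, labels, weights)
                if (1 if polarity * (x - threshold) > 0 else -1) != y
            )
            if best is None or error < best[0]:
                best = (error, threshold, polarity)
    return best[1], best[2]


def adversarial_weights(treated_x, population_x, iterations=20, lr=0.5, decay=1.0):
    """Reweight the treated group so a stump cannot tell it apart from the population."""
    n_t, n_p = len(treated_x), len(population_x)
    w = [1.0 / n_t] * n_t                 # uniform start, analogous to w_init=None
    xs = treated_x + population_x
    labels = [1] * n_t + [-1] * n_p       # +1: treated group, -1: whole population
    for _ in range(iterations):
        # Each side carries total mass 1: normalized treated weights vs. uniform population.
        threshold, polarity = fit_stump(xs, labels, w + [1.0 / n_p] * n_p)
        # Shrink weights of treated samples the stump recognizes as treated (h = +1),
        # grow weights of those it mistakes for the population (h = -1).
        w = [wi * math.exp(-lr * (1 if polarity * (x - threshold) > 0 else -1))
             for wi, x in zip(w, treated_x)]
        total = sum(w)
        w = [wi / total for wi in w]      # renormalize so the weights sum to 1
        lr *= decay                       # hypothetical multiplicative decay schedule
    return w


treated = [1, 2, 3, 4, 5, 5, 6, 6, 7, 8]
control = [0, 0, 1, 1, 2, 2, 3, 3, 4, 5]
population = treated + control
weights = adversarial_weights(treated, population)
weighted_mean = sum(w * x for w, x in zip(weights, treated))
print(sum(treated) / len(treated), sum(population) / len(population), round(weighted_mean, 2))
```

On this toy data the weighted mean of the treated covariate moves from 4.7 toward the population mean of 3.4, which is the "cannot distinguish" goal restricted to what a stump can detect.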

For more details about the algorithm see:

Adversarial Balancing for Causal Inference by Ozery-Flato, Thodoroff, et al. https://arxiv.org/abs/1810.07406

Parameters:
  • learner – An initialized classifier object implementing fit and predict (scikit-learn compatible). It will be used to discriminate between the population under treatment a and the entire global population. A selection can be performed for each treatment value to choose the best classifier for that treatment group, either by providing an initialized scikit-learn SearchCV model (GridSearchCV or RandomizedSearchCV) or by providing a list of classifiers. If a list of classifiers is provided, cross-validation is used for each treatment value to select the best-performing classifier among the list. See the select_classifier module.

  • iterations (int) – The number of iterations for adjusting the weight of each sample.

  • lr (float) – Learning rate used to update the weights.

  • decay (float) – Parameter that decays the learning rate over the iterations.

  • loss_type (str) – Use ‘01’ for zero-one loss; any other value uses cross-entropy (in which case the provided learner must also implement a predict_proba method).

  • use_stabilized (bool) – Whether to re-weight the learned weights with the prevalence of the treatment. Note: Adversarial balancing already has an inherent component that weights by treatment prevalence; setting this to False will “de-stabilize” the weights after they are calculated.

  • verbose (bool) – Whether to print statistics to the console during training.
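The loss_type choice matters because zero-one loss needs only hard predictions (predict), while cross-entropy needs predicted probabilities (predict_proba). The sketch below contrasts the two on the same discriminator output; the helper names and the 1/0 label convention are hypothetical, and how causallib aggregates the discriminator loss internally is not stated in this reference.

```python
import math


def weighted_zero_one_loss(labels, predictions, weights):
    """Weighted share of misclassified samples, computed from hard labels (predict)."""
    wrong = sum(w for y, p, w in zip(labels, predictions, weights) if y != p)
    return wrong / sum(weights)


def weighted_cross_entropy(labels, probabilities, weights, eps=1e-12):
    """Weighted log loss, computed from Pr[label == 1] (predict_proba's second column)."""
    total = 0.0
    for y, p, w in zip(labels, probabilities, weights):
        p = min(max(p, eps), 1.0 - eps)   # clip to avoid log(0)
        total -= w * (y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / sum(weights)


labels = [1, 1, 0, 0]                 # 1: treated group, 0: whole population
probabilities = [0.9, 0.4, 0.2, 0.6]  # hypothetical discriminator outputs
weights = [1.0, 1.0, 1.0, 1.0]
predictions = [1 if p >= 0.5 else 0 for p in probabilities]

print(weighted_zero_one_loss(labels, predictions, weights))              # 0.5
print(round(weighted_cross_entropy(labels, probabilities, weights), 4))  # 0.5403
```

If a mistaken probability becomes more confident (for example 0.4 dropping to 0.01) without changing the hard prediction, the zero-one loss stays the same while the cross-entropy grows.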

iterative_models_

numpy.ndarray of shape (n_treatment_values, iterations) holding all the models created during the training process.

iterative_normalizing_consts_

numpy.ndarray of shape (n_treatment_values, iterations) holding all the normalizing constants calculated during the training process.

discriminator_loss_

numpy.ndarray of shape (n_treatment_values, iterations) holding the loss of the learner throughout the training process.

treatments_frequency_

If use_stabilized=True, the proportions of the treatment values.
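The use_stabilized note says the learned weights already include a treatment-prevalence component, and that setting it to False "de-stabilizes" them. A minimal sketch, assuming de-stabilizing means dividing each weight by the prevalence of the individual's own treatment (the proportions kept in treatments_frequency_); the helper names are hypothetical and this is not causallib's code.

```python
from collections import Counter


def treatment_frequencies(a):
    """Proportion of each treatment value, analogous to treatments_frequency_."""
    counts = Counter(a)
    return {value: count / len(a) for value, count in counts.items()}


def destabilize(weights, a, frequencies):
    """Divide each weight by the prevalence of that individual's own treatment."""
    return [w / frequencies[ai] for w, ai in zip(weights, a)]


def stabilize(weights, a, frequencies):
    """Inverse operation: multiply each weight by its treatment's prevalence."""
    return [w * frequencies[ai] for w, ai in zip(weights, a)]


a = [1, 1, 1, 0]                        # treatment assignment
stabilized = [0.3, 0.6, 1.5, 0.5]       # hypothetical stabilized weights
freqs = treatment_frequencies(a)        # {1: 0.75, 0: 0.25}
unstabilized = destabilize(stabilized, a, freqs)
print([round(w, 6) for w in unstabilized])  # [0.4, 0.8, 2.0, 2.0]
```

Individuals in the rarer treatment group (here a = 0) are scaled up more, which is the usual difference between stabilized and unstabilized weights.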

__init__(learner, iterations=20, lr=0.5, decay=1, loss_type='01', use_stabilized=True, verbose=False, *args, **kwargs)

Initializes the estimator. Its parameters and attributes are the same as those documented for the class above.

fit(X, a, y=None, w_init=None, **select_kwargs)

Trains an Adversarial Balancing model.

Parameters:
  • X (pandas.DataFrame) – Covariate matrix of size (num_subjects, num_features).

  • a (pandas.Series) – Treatment assignment of size (num_subjects,).

  • y – IGNORED.

  • w_init (pandas.Series) – Initial sample weights. If not provided, uniform weights are assumed.

  • select_kwargs – Keyword arguments to pass into select_classifier. Relevant only if the model was initialized with a list of classifiers as learner.

Returns:

AdversarialBalancing
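When learner is a list of classifiers, cross-validation picks the best-performing one for each treatment value. The sketch below illustrates that idea for a single treatment value using two hypothetical toy classifiers and a hypothetical `cv_accuracy` helper with interleaved folds; it does not reproduce the select_classifier module's API, fold scheme, or scoring metric.

```python
import copy
import statistics
from collections import Counter


class MajorityClassifier:
    """Always predicts the most frequent training label (a weak baseline)."""

    def fit(self, X, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


class MidpointClassifier:
    """Predicts 1 when the single feature exceeds the midpoint of the two class means."""

    def fit(self, X, y):
        ones = [x for x, label in zip(X, y) if label == 1]
        zeros = [x for x, label in zip(X, y) if label == 0]
        self.threshold_ = (statistics.mean(ones) + statistics.mean(zeros)) / 2
        return self

    def predict(self, X):
        return [1 if x > self.threshold_ else 0 for x in X]


def cv_accuracy(model, X, y, n_folds=3):
    """Mean held-out accuracy over interleaved folds, refitting a fresh copy each time."""
    scores = []
    for k in range(n_folds):
        train = [i for i in range(len(X)) if i % n_folds != k]
        test = [i for i in range(len(X)) if i % n_folds == k]
        fitted = copy.deepcopy(model).fit([X[i] for i in train], [y[i] for i in train])
        predictions = fitted.predict([X[i] for i in test])
        correct = sum(1 for p, i in zip(predictions, test) if p == y[i])
        scores.append(correct / len(test))
    return statistics.mean(scores)


def select_best(candidates, X, y):
    """Return the candidate with the highest cross-validated accuracy."""
    return max(candidates, key=lambda model: cv_accuracy(model, X, y))


X = list(range(12))                     # one covariate
y = [0] * 6 + [1] * 6                   # toy labels: 0 population, 1 treated group
best = select_best([MajorityClassifier(), MidpointClassifier()], X, y)
print(type(best).__name__)              # MidpointClassifier
```

The selected candidate is returned unfitted here; in a full pipeline it would then be refit on all the data for that treatment value.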

compute_weights(X, a, treatment_values=None, use_stabilized=None, **kwargs)

Computes the weight of each individual given that individual’s treatment assignment, i.e., f(Pr[A=a_i | X_i]) for each individual i.

Parameters:
  • X (pandas.DataFrame) – Covariate matrix of size (num_subjects, num_features).

  • a (pandas.Series) – Treatment assignment of size (num_subjects,).

  • treatment_values (Any | None) – The desired treatment value(s) for which to extract weights (i.e., the treatment value toward which the weights should be calculated). If not specified, the weights are chosen according to each individual’s actual treatment assignment.

  • use_stabilized (bool) – Whether to re-weight the learned weights with the prevalence of the treatment. This overrides the use_stabilized parameter provided at initialization. See also: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4351790/#S6title

  • **kwargs

Returns:

A vector of size (num_subjects,) with a weight for each individual.

Return type:

pandas.Series
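The relationship between treatment_values and the per-treatment weights from compute_weight_matrix can be illustrated with a plain dictionary standing in for the weight matrix (one column per treatment value). The function `select_weights` is hypothetical and handles a single treatment value only; it is a sketch of the documented selection rule, not causallib's code.

```python
def select_weights(weight_matrix, a, treatment_value=None):
    """Pick one weight per individual from {treatment_value: [weight per individual]}.

    treatment_value=None: use each individual's actual treatment a[i].
    Otherwise: use the column of the requested treatment value for everyone.
    """
    if treatment_value is None:
        return [weight_matrix[ai][i] for i, ai in enumerate(a)]
    return list(weight_matrix[treatment_value])


# Hypothetical weight matrix: one column per treatment value, one row per individual.
weight_matrix = {0: [1.2, 0.8, 1.5], 1: [2.0, 3.0, 0.5]}
a = [0, 1, 1]

print(select_weights(weight_matrix, a))     # [1.2, 3.0, 0.5]
print(select_weights(weight_matrix, a, 1))  # [2.0, 3.0, 0.5]
```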

estimate_population_outcome(X, a, y, w=None, treatment_values=None)

compute_weight_matrix(X, a, use_stabilized=None, **kwargs)

Computes individual weights across all possible treatment values, i.e., f(Pr[A=a_j | X_i]) for every individual i and treatment value j.

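Parameters:
  • X (pandas.DataFrame) – Covariate matrix of size (num_subjects, num_features).

  • a (pandas.Series) – Treatment assignment of size (num_subjects,).

  • use_stabilized (bool) – Whether to re-weight the learned weights with the prevalence of the treatment. This overrides the use_stabilized parameter provided at initialization.

  • **kwargs

Returns: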

A matrix of size (num_subjects, num_treatments) with a weight for every individual and every treatment.

Return type:

pandas.DataFrame

set_fit_request(*, a='$UNCHANGED$', w_init='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). See the scikit-learn User Guide for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
  • a (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the a parameter in fit.

  • w_init (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for w_init parameter in fit.

Returns:

self – The updated object.

Return type:

object
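The request options documented for set_fit_request can be summarized with a small pure-Python model. This is a hypothetical illustration of the documented semantics only, not scikit-learn's routing code: `update_requests` and `route_metadata` are made-up names, and `UNCHANGED` here is a stand-in string mirroring the '$UNCHANGED$' default shown in the signature.

```python
UNCHANGED = "$UNCHANGED$"  # stand-in for sklearn.utils.metadata_routing.UNCHANGED


def update_requests(current, **changes):
    """Apply new request values; parameters left at UNCHANGED keep their old value."""
    updated = dict(current)
    for name, value in changes.items():
        if value != UNCHANGED:
            updated[name] = value
    return updated


def route_metadata(request, name, provided):
    """Return the keyword arguments a meta-estimator would forward to fit for `name`.

    `provided` holds the metadata the user passed to the meta-estimator.
    """
    if request is True:
        return {name: provided[name]} if name in provided else {}  # ignored if absent
    if request is False:
        return {}                                                 # never forwarded
    if request is None:
        if name in provided:
            raise ValueError(f"{name!r} was provided but not requested by fit")
        return {}
    if isinstance(request, str):                                  # alias
        return {name: provided[request]} if request in provided else {}
    raise TypeError("request must be True, False, None, or a str alias")


requests = update_requests({"a": None, "w_init": None}, a=True, w_init="init_weights")
provided = {"a": [0, 1, 1], "init_weights": [1.0, 1.0, 1.0]}
print(route_metadata(requests["a"], "a", provided))            # {'a': [0, 1, 1]}
print(route_metadata(requests["w_init"], "w_init", provided))  # {'w_init': [1.0, 1.0, 1.0]}
```

With an alias, the user passes the metadata to the meta-estimator as init_weights, and it reaches fit under its original name, w_init.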