causallib.contrib.hemm.hemm.HEMM

class HEMM(*args, **kwargs)

This is the model definition. The model has two parts:

  1. the subgroup discovery component

  2. the outcome prediction from the subgroup assignment and the interaction with confounders through an MLP.
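The two parts can be tied together roughly as follows. This is a minimal plain-Python sketch. It assumes a sample's prediction is a mixture over subgroups weighted by P(Z|X) when soft assignment is used. With hard assignment it assumes the prediction comes from the single most probable subgroup, mirroring the soft flag of forward. The function names are hypothetical. The MLP outcome model is replaced by precomputed per-subgroup predictions.

```python
import math

# Hypothetical stand-ins illustrating soft vs. hard subgroup assignment;
# not causallib's implementation.

def mixture_outcome(group_probs, group_outcomes):
    """Soft assignment: expected outcome over subgroup membership.

    group_probs: P(Z = k | x) for each subgroup k (must sum to 1).
    group_outcomes: predicted outcome for x under each subgroup k.
    """
    if not math.isclose(sum(group_probs), 1.0, rel_tol=1e-9):
        raise ValueError("group probabilities must sum to 1")
    return sum(p * y for p, y in zip(group_probs, group_outcomes))


def hard_outcome(group_probs, group_outcomes):
    """Hard assignment: use only the most probable subgroup's prediction."""
    k = max(range(len(group_probs)), key=group_probs.__getitem__)
    return group_outcomes[k]


print(mixture_outcome([0.25, 0.75], [0.2, 0.6]))  # approximately 0.5
print(hard_outcome([0.25, 0.75], [0.2, 0.6]))     # 0.6
```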

Parameters:
  • D_in – the number of input features (dimensionality of the data).

  • K – number of components to discover.

  • homo (bool) – Flag specifying whether the final outcome model is the same for each discovered subgroup. Default is True, i.e., the same outcome model is used for each subgroup.

  • mu – initialize the components with means of the training data.

  • std – initialize the components with std dev of the training data.

  • bc – the first bc components are treated as Bernoulli variables.

  • lamb – strength of the Beta(0.5, 0.5) prior on the Bernoulli variables.

  • spread – how far from their means the components are initialized.

  • outcomeModel – ‘linear’ for a linear outcome function, or another PyTorch model (torch.nn.Module) to use as the outcome model.

  • sep_heads – If False, the adjustment for confounding is forced to be the same regardless of treatment assignment.
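The lamb parameter weights a Beta(0.5, 0.5) prior on the Bernoulli parameters. Below is a minimal sketch of such a penalty. It assumes the prior enters the loss as lamb times the negative log prior density. The function names are hypothetical, not part of causallib.

```python
import math

# Hypothetical sketch: lamb-weighted Beta(a, b) prior on Bernoulli means.

def beta_log_pdf(mu, a=0.5, b=0.5):
    """Log density of Beta(a, b) at mu, for mu strictly inside (0, 1)."""
    log_beta_fn = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1) * math.log(mu) + (b - 1) * math.log(1 - mu) - log_beta_fn


def bernoulli_prior_penalty(bernoulli_means, lamb=1e-4):
    """Negative log prior, scaled by lamb, to add to a loss being minimized."""
    return -lamb * sum(beta_log_pdf(m) for m in bernoulli_means)


print(bernoulli_prior_penalty([0.1]) < bernoulli_prior_penalty([0.5]))  # True
```

Beta(0.5, 0.5) is U-shaped. As a result, this penalty is smaller for Bernoulli means near 0 or 1 than for means near 0.5.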

__init__(D_in, K, homo=True, mu=None, std=None, bc=0, lamb=0.0001, spread=0.1, outcomeModel='linear', sep_heads=True)
static gaussian_pdf(x, mu, std)
static bernoulli_pdf(x, mu)
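The densities behind gaussian_pdf and bernoulli_pdf can be sketched in plain Python for a single sample. The sketch assumes independent features, i.e. a diagonal covariance for the Gaussian part. These stand-ins take lists rather than torch tensors and are not the library's implementation.

```python
import math

# Hypothetical per-sample stand-ins; the real static methods operate on tensors.

def gaussian_pdf(x, mu, std):
    """Product of independent univariate normal densities (diagonal covariance)."""
    density = 1.0
    for xi, mi, si in zip(x, mu, std):
        z = (xi - mi) / si
        density *= math.exp(-0.5 * z * z) / (si * math.sqrt(2 * math.pi))
    return density


def bernoulli_pdf(x, mu):
    """Product of independent Bernoulli probabilities; entries of x are 0 or 1."""
    prob = 1.0
    for xi, mi in zip(x, mu):
        prob *= mi if xi == 1 else 1 - mi
    return prob


print(gaussian_pdf([0.0], [0.0], [1.0]))     # about 0.3989, i.e. 1/sqrt(2*pi)
print(bernoulli_pdf([1, 0], [0.8, 0.3]))     # about 0.56
```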
regularization(alpha, beta)
forward(x, t, soft=True, infer=True, response='bin')
fit(train, epochs=100, batch_size=10, lr=0.001, wd=0, ltype='log', dev=None, metric='AP', response='bin', use_p_correction=True, imb_fun='mmd2_lin', p_alpha=0.0001)

This method optimizes the ELBO, updating the parameters with a first-order optimizer routine.

Parameters:
  • train – Tuple (x_, t_, Y_) of torch tensors holding the input features, treatment, and outcome.

  • epochs – Max number of epochs.

  • batch_size – Batch size for optimizer.

  • lr – Learning rate for the optimizer.

  • wd (float) – Weight decay

  • ltype – ‘log’ for the ELBO.

  • dev – Validation dataset as a tuple of torch tensors (x_, t_, Y_). If provided, early stopping based on metric is applied.

  • metric (str) – ‘AP’: mean average precision, ‘AuROC’: area under the ROC curve, ‘MSE’: mean squared error. If not specified, the optimizer cost is used as the metric. The metric is computed on the training set (or on the dev set, if provided). When a dev set is provided, the metric is used for early stopping to prevent overfitting.

  • response – Specify if y is continuous (‘cont’) or binary (‘bin’).

  • use_p_correction (bool) – Whether to use population size p(treated) in imbalance penalty (IPM).

  • imb_fun (str) – Which imbalance penalty to use (‘mmd_lin’, ‘wass’).

  • p_alpha (float) – Imbalance regularization parameter.
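The early-stopping behaviour described for dev and metric can be sketched as a generic loop. This is an illustration only. The names train_step, evaluate, and patience are hypothetical and do not correspond to fit's signature. The exact stopping rule used by HEMM.fit is not documented here.

```python
# Hypothetical early-stopping loop; not the stopping rule of HEMM.fit.

def train_with_early_stopping(train_step, evaluate, epochs=100, patience=5,
                              higher_is_better=True):
    """Run up to `epochs` epochs and stop when the dev metric stops improving.

    train_step(epoch): performs one epoch of parameter updates.
    evaluate(): returns the current dev-set metric.
    patience: epochs without improvement tolerated before stopping (assumption).
    Returns (best_epoch, best_score).
    """
    best_score, best_epoch, waited = None, -1, 0
    for epoch in range(epochs):
        train_step(epoch)
        score = evaluate()
        improved = best_score is None or (
            score > best_score if higher_is_better else score < best_score
        )
        if improved:
            best_score, best_epoch, waited = score, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_epoch, best_score
```

For ‘AP’ and ‘AuROC’, higher is better. For ‘MSE’ or a raw loss, pass higher_is_better=False.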

lcost(x_, t_, y_, ltype='log', response='bin', use_p_correction=True, imb_fun=None, p_alpha=0.0001, rbf_sigma=0.1, wass_its=20, wass_lambda=10.0)

Implements the ELBO as the objective function (Eq. 12 in the paper).

Parameters:
  • x_ (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).

  • t_ (torch.Tensor) – Treatment assignment of size (num_subjects,).

  • y_ (torch.Tensor) – Outcome of size (num_subjects,).

  • ltype – ‘log’ specifies ELBO

  • response – specifies whether outcome y is binary (‘bin’) or continuous (‘cont’).

  • use_p_correction (bool) – whether to use population size p(treated) in imbalance penalty (IPM)

  • imb_fun (str) – which imbalance penalty to use (‘mmd_lin’, ‘mmd_rbf’, ‘wass’, ‘wass2’)

  • p_alpha (float) – imbalance regularization parameter

  • rbf_sigma (float) – RBF MMD sigma

  • wass_its (int) – Number of iterations in Wasserstein computation.

  • wass_lambda (float) – Wasserstein lambda.
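Below is a minimal plain-Python sketch of a linear-kernel MMD imbalance penalty, the family of the default imb_fun ‘mmd2_lin’ in fit. It assumes the form used in the CFRNet code of Shalit et al. There, p is the treated fraction when the p correction is enabled and 0.5 otherwise. Whether causallib computes the penalty identically is an assumption, and the function is a hypothetical stand-in.

```python
# Hypothetical stand-in for a linear-kernel MMD imbalance penalty (CFRNet-style).

def mmd2_lin(reps, t, p=0.5):
    """Linear-kernel squared MMD between treated and control representations.

    reps: list of feature vectors; t: list of 0/1 treatment indicators;
    p: probability of treatment (0.5 if no population-size correction).
    With p = 0.5 this is the squared distance between the group means.
    """
    treated = [r for r, ti in zip(reps, t) if ti == 1]
    control = [r for r, ti in zip(reps, t) if ti == 0]
    if not treated or not control:
        raise ValueError("need at least one treated and one control sample")
    dim = len(reps[0])
    mean_t = [sum(r[j] for r in treated) / len(treated) for j in range(dim)]
    mean_c = [sum(r[j] for r in control) / len(control) for j in range(dim)]
    # CFRNet weighting of the two means: 2*p*mean_t - 2*(1-p)*mean_c
    return sum((2 * p * mt - 2 * (1 - p) * mc) ** 2
               for mt, mc in zip(mean_t, mean_c))


reps = [[0.0, 0.0], [2.0, 0.0], [1.0, 1.0], [3.0, 1.0]]
print(mmd2_lin(reps, [0, 0, 1, 1]))  # 2.0
```

In a training loss, this value would typically be multiplied by p_alpha and added to the objective.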

group_sizes(X)

Returns the number of data points assigned to each subgroup.

Parameters:

X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).

Returns:

the number of samples assigned to each group

Return type:

collections.Counter
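Counting group sizes from hard assignments is a direct use of collections.Counter. This is a sketch under assumptions. The function name and the labels and n_groups inputs are hypothetical. Listing empty subgroups with size 0 is a design choice made here and is not documented for group_sizes.

```python
from collections import Counter

# Hypothetical sketch: subgroup sizes from hard assignments.

def group_sizes(labels, n_groups):
    """Count how many samples were assigned to each subgroup.

    labels: hard subgroup index per sample; n_groups: total number of subgroups.
    """
    counts = Counter(labels)
    # Counter omits keys it never saw; list empty subgroups explicitly as 0.
    for k in range(n_groups):
        counts.setdefault(k, 0)
    return counts


print(group_sizes([0, 2, 2, 0, 2], 4).most_common(1))  # [(2, 3)]
```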

estimate_individual_outcomes(X, T, response='bin', soft=True)

Returns the estimated individual outcomes under each treatment value.

Parameters:
  • X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).

  • T (torch.Tensor) – Treatment assignment of size (num_subjects,).

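  • response (str) – Whether the outcome is binary (‘bin’) or continuous (‘cont’).

  • soft (bool) – If True, use soft (probabilistic) subgroup assignments; otherwise use hard assignments.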

Returns:

list of torch.Tensor, one for each unique value in tensor T
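The "one output per unique treatment value" behaviour can be sketched as follows. Every sample is predicted with its treatment set to each observed value in turn. This is a minimal plain-Python sketch. The predict callable, the function name, and the sorted ordering of treatment values are all hypothetical.

```python
# Hypothetical sketch: potential outcomes for every sample under each treatment.

def individual_outcomes(predict, X, T):
    """Predict every sample's outcome under each observed treatment value.

    predict(x, t): hypothetical per-sample outcome model.
    Returns one list of predictions per unique value in T, in sorted order.
    """
    return [[predict(x, t) for x in X] for t in sorted(set(T))]


outcomes = individual_outcomes(lambda x, t: x[0] + 2 * t, [[1.0], [3.0]], [1, 0])
print(outcomes)  # [[1.0, 3.0], [3.0, 5.0]]
```

For a binary treatment, an individual effect estimate is outcomes[1][i] - outcomes[0][i].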

get_groups(X)

Returns the hard group assignment (most probable group) for each sample.

Parameters:

X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).

Returns:

Most probable group assignment of each sample, size = (num_samples,)

Return type:

groups (numpy.ndarray)
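A hard assignment follows from the soft one by taking the most probable group in each row of P(Z|X). Below is a plain-Python sketch, assuming the probabilities are given as one list per sample. The function name is a hypothetical stand-in for the method.

```python
# Hypothetical sketch: hard assignment as the row-wise argmax of P(Z|X).

def hard_groups(proba_rows):
    """Index of the largest probability in each row (ties go to the first)."""
    return [max(range(len(row)), key=row.__getitem__) for row in proba_rows]


print(hard_groups([[0.1, 0.9], [0.6, 0.4], [0.5, 0.5]]))  # [1, 0, 0]
```

Python's max returns the first maximal element, so ties resolve to the lowest group index. numpy.argmax behaves the same way.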

get_groups_proba(X, log=False)

Return soft assignment of probability of each sample to be part of each group.

Parameters:
  • X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).

  • log (bool) – If True, returns log probabilities.

Returns:

Probability of group membership given X, P(Z|X), of size (num_subjects, num_components+1).

Return type:

z_pred (numpy.ndarray)
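In a mixture model, P(Z|X) is obtained by Bayes' rule: the prior times the likelihood, normalized over groups. The sketch below uses log-sum-exp for numerical stability. It is a minimal plain-Python sketch for one sample. It assumes the log prior and log likelihood per group are already available, and it does not reproduce HEMM's exact parameterization (including the extra component in the returned size). The function name is hypothetical.

```python
import math

# Hypothetical sketch: posterior group membership for a single sample.

def groups_proba(log_prior, log_lik, log=False):
    """Posterior P(Z = k | x) via Bayes' rule.

    log_prior: log P(Z = k) for each subgroup k.
    log_lik: log p(x | Z = k) for the same sample under each subgroup.
    If log is True, returns log probabilities.
    """
    joint = [lp + ll for lp, ll in zip(log_prior, log_lik)]
    m = max(joint)
    # log-sum-exp with the max subtracted, to avoid overflow/underflow
    log_norm = m + math.log(sum(math.exp(j - m) for j in joint))
    log_post = [j - log_norm for j in joint]
    return log_post if log else [math.exp(lp) for lp in log_post]


print(groups_proba([math.log(0.5)] * 2, [math.log(1.0), math.log(3.0)]))
# approximately [0.25, 0.75]
```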

get_groups_effect()