causallib.contrib.hemm.hemm.HEMM
- class HEMM(*args, **kwargs)
This is the model definition. The model has two parts:
the subgroup discovery component
the outcome prediction from the subgroup assignment and the interaction with confounders through an MLP.
- Parameters:
D_in – the number of input features of the data.
K – number of components to discover.
homo (bool) – Flag specifying whether the final outcome model is the same for each discovered subgroup. Default is True, i.e., the same outcome model is used for every subgroup.
mu – initialize the components with the means of the training data.
std – initialize the components with the standard deviations of the training data.
bc – the first bc components are treated as Bernoulli variables.
lamb – strength of the Beta(0.5, 0.5) prior on the Bernoulli variables.
spread – how far from their means the components should be initialized.
outcomeModel – ‘linear’ to specify a linear outcome function, or a PyTorch model (torch.nn.Module) to use as the outcome model.
sep_heads – Setting this to False forces the adjustment for confounding to be the same regardless of treatment assignment.
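The lamb parameter weights a Beta(0.5, 0.5) prior on the Bernoulli variables. As a worked illustration, the sketch below evaluates that prior's log-density. The helper name beta_half_log_prior and the assumption that the prior enters the objective as lamb times the summed log-density are hypothetical, not taken from the causallib source.

```python
import math

import numpy as np


def beta_half_log_prior(theta, lamb=1e-4, eps=1e-12):
    """Hypothetical helper: lamb-weighted log-density of a Beta(0.5, 0.5) prior.

    theta: Bernoulli parameters in (0, 1).
    Assumption: the prior is added to the objective as lamb * sum(log p(theta)).
    """
    a = b = 0.5
    # log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b); equals log(pi) for a = b = 0.5
    log_beta_fn = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    # Clip away from 0 and 1 so the logs stay finite
    theta = np.clip(np.asarray(theta, dtype=float), eps, 1.0 - eps)
    log_pdf = (a - 1.0) * np.log(theta) + (b - 1.0) * np.log1p(-theta) - log_beta_fn
    return lamb * float(log_pdf.sum())


# Beta(0.5, 0.5) is U-shaped: values near 0 or 1 score higher than 0.5.
print(beta_half_log_prior([0.5], lamb=1.0))  # equals log(2 / pi)
print(beta_half_log_prior([0.05], lamb=1.0) > beta_half_log_prior([0.5], lamb=1.0))  # True
```

Because the density is U-shaped, a larger lamb pushes the Bernoulli parameters toward 0 or 1, i.e. toward crisper binary features within a component.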
- __init__(D_in, K, homo=True, mu=None, std=None, bc=0, lamb=0.0001, spread=0.1, outcomeModel='linear', sep_heads=True)
- fit(train, epochs=100, batch_size=10, lr=0.001, wd=0, ltype='log', dev=None, metric='AP', response='bin', use_p_correction=True, imb_fun='mmd2_lin', p_alpha=0.0001)
This method optimizes the ELBO, updating the parameters with a first-order optimizer routine.
- Parameters:
train – (x_, t_, Y_): tuple of torch tensors holding the input features, treatment and outcome.
epochs – Maximum number of epochs.
batch_size – Batch size for optimizer.
lr – Learning rate for the optimizer.
wd (float) – Weight decay.
ltype – ‘log’ for the ELBO.
dev – Validation dataset as a tuple of torch tensors (x_, t_, Y_). If provided, early stopping based on metric is applied.
metric (str) – ‘AP’: mean average precision, ‘AuROC’: area under the ROC curve, ‘MSE’: mean squared error. If not specified, the optimizer cost is used as the metric. The metric is computed on the training set (or on the dev set, if provided); when a dev set is provided, it is used for early stopping to prevent overfitting.
response – Specify whether y is continuous (‘cont’) or binary (‘bin’).
use_p_correction (bool) – Whether to use the population size p(treated) in the imbalance penalty (IPM).
imb_fun (str) – Which imbalance penalty to use (‘mmd_lin’, ‘wass’).
p_alpha (float) – Imbalance regularization parameter.
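The imb_fun, p_alpha and use_p_correction options share their names with the CFRNet reference implementation, where the linear MMD compares the treated and control feature means. Assuming HEMM's default 'mmd2_lin' penalty follows that formulation, a minimal NumPy sketch looks like this; mmd2_lin and imbalance_penalty here are illustrative helpers, not causallib functions.

```python
import numpy as np


def mmd2_lin(X, t, p=0.5):
    """Squared linear-kernel MMD between treated and control rows (CFRNet-style sketch).

    X: (num_subjects, num_features); t: binary treatment of shape (num_subjects,).
    p: weight of the treated group; p = 0.5 means no population-size correction.
    """
    X = np.asarray(X, dtype=float)
    t = np.asarray(t).astype(bool)
    mean_treated = X[t].mean(axis=0)
    mean_control = X[~t].mean(axis=0)
    return float(np.sum((2.0 * p * mean_treated - 2.0 * (1.0 - p) * mean_control) ** 2))


def imbalance_penalty(X, t, use_p_correction=True, p_alpha=1e-4):
    """Hypothetical wrapper: p_alpha-weighted imbalance term."""
    # With the correction, p is the empirical treated fraction p(treated); otherwise 0.5
    p = float(np.mean(t)) if use_p_correction else 0.5
    return p_alpha * mmd2_lin(X, t, p)


X = [[0, 0], [0, 0], [1, 1], [1, 1]]
t = [0, 0, 1, 1]
print(mmd2_lin(X, t, 0.5))  # 2.0
```

Here the control and treated means are (0, 0) and (1, 1), so with p = 0.5 the penalty reduces to the squared distance between the group means, 1 + 1 = 2.0.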
- lcost(x_, t_, y_, ltype='log', response='bin', use_p_correction=True, imb_fun=None, p_alpha=0.0001, rbf_sigma=0.1, wass_its=20, wass_lambda=10.0)
Implements the ELBO as the objective function (Eq. 12 in the paper).
- Parameters:
x_ (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).
t_ (torch.Tensor) – Treatment assignment of size (num_subjects,).
y_ (torch.Tensor) – Outcome of size (num_subjects,).
ltype – ‘log’ specifies the ELBO.
response – specifies whether the outcome is binary (‘bin’) or continuous (‘cont’).
use_p_correction (bool) – whether to use the population size p(treated) in the imbalance penalty (IPM).
imb_fun (str) – which imbalance penalty to use (‘mmd_lin’, ‘mmd_rbf’, ‘wass’, ‘wass2’).
p_alpha (float) – imbalance regularization parameter.
rbf_sigma (float) – RBF MMD sigma.
wass_its (int) – number of iterations in the Wasserstein computation.
wass_lambda (float) – Wasserstein lambda.
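For an RBF-kernel MMD penalty, rbf_sigma sets the kernel bandwidth. The sketch below follows the unbiased, p-weighted RBF estimator from the CFRNet reference implementation, with kernel exp(-||a - b||^2 / sigma^2), and returns the squared MMD estimate. Whether lcost uses exactly this estimator and bandwidth convention is an assumption; mmd2_rbf is an illustrative name. It needs at least two treated and two control samples.

```python
import numpy as np


def _sq_dists(A, B):
    # Pairwise squared Euclidean distances between the rows of A and the rows of B
    return np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)


def mmd2_rbf(X, t, p=0.5, sigma=0.1):
    """Unbiased, p-weighted RBF-kernel MMD^2 between control and treated rows (sketch)."""
    X = np.asarray(X, dtype=float)
    t = np.asarray(t).astype(bool)
    Xc, Xt = X[~t], X[t]
    m, n = len(Xc), len(Xt)
    # Gaussian kernel k(a, b) = exp(-||a - b||^2 / sigma^2)
    Kcc = np.exp(-_sq_dists(Xc, Xc) / sigma**2)
    Kct = np.exp(-_sq_dists(Xc, Xt) / sigma**2)
    Ktt = np.exp(-_sq_dists(Xt, Xt) / sigma**2)
    # Subtract the diagonal (k(x, x) = 1) from within-group sums: unbiased estimate
    mmd = (1 - p) ** 2 / (m * (m - 1)) * (Kcc.sum() - m)
    mmd += p**2 / (n * (n - 1)) * (Ktt.sum() - n)
    mmd -= 2 * p * (1 - p) / (m * n) * Kct.sum()
    return float(4.0 * mmd)


t = [0, 0, 1, 1]
separated = [[0.0], [0.1], [10.0], [10.1]]  # control near 0, treated near 10
interleaved = [[0.0], [10.0], [0.1], [10.1]]  # both groups span both clusters
print(mmd2_rbf(separated, t, sigma=1.0) > mmd2_rbf(interleaved, t, sigma=1.0))  # True
```

Since the estimator is unbiased, it can be slightly negative when the two groups are similar; that is why the comparison above, rather than a zero check, is the meaningful signal.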
- group_sizes(X)
Returns the number of data points assigned to each subgroup.
- Parameters:
X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).
- Returns:
the size of each group
- Return type:
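Assuming group sizes are counts of the hard assignments (as returned by get_groups), a hypothetical helper could compute them with np.unique. The dictionary output is a choice made for this sketch, since the return type is not documented here.

```python
import numpy as np


def group_sizes_from_labels(groups):
    """Hypothetical helper: number of samples carrying each hard group label."""
    labels, counts = np.unique(np.asarray(groups), return_counts=True)
    return {int(label): int(count) for label, count in zip(labels, counts)}


print(group_sizes_from_labels([0, 2, 2, 1, 2]))  # {0: 1, 1: 1, 2: 3}
```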
- estimate_individual_outcomes(X, T, response='bin', soft=True)
Return individual treatment outcomes.
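The soft flag suggests two ways of combining per-group predictions. Assuming soft=True weights each group's predicted outcome by P(Z|X) and soft=False uses only the most probable group, the combination step could look like this; combine_group_outcomes is a hypothetical helper, not part of causallib.

```python
import numpy as np


def combine_group_outcomes(group_proba, group_outcomes, soft=True):
    """Hypothetical sketch: combine per-group outcome predictions into one per subject.

    group_proba: (num_subjects, num_groups), rows are P(Z|X) and sum to 1.
    group_outcomes: (num_subjects, num_groups), predicted outcome under each group.
    """
    group_proba = np.asarray(group_proba, dtype=float)
    group_outcomes = np.asarray(group_outcomes, dtype=float)
    if soft:
        # Probability-weighted average over groups
        return np.sum(group_proba * group_outcomes, axis=1)
    # Hard: take the prediction of the most probable group only
    hard = np.argmax(group_proba, axis=1)
    return group_outcomes[np.arange(len(hard)), hard]


P = [[0.75, 0.25]]
Y = [[1.0, 0.0]]
print(combine_group_outcomes(P, Y, soft=True), combine_group_outcomes(P, Y, soft=False))
```

For this single subject, the soft estimate is 0.75 while the hard estimate is 1.0: the soft version keeps the uncertainty about subgroup membership, whereas the hard one commits to a single subgroup.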
- get_groups(X)
Return the hard group assignment for each sample.
- Parameters:
X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).
- Returns:
Most probable group assignment of each sample, size = (num_samples,)
- Return type:
groups (numpy.ndarray)
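The most probable group is the argmax of the soft assignment P(Z|X) returned by get_groups_proba. A short NumPy sketch (hard_assignments is a hypothetical name) also shows that the result is the same whether the argmax is taken over probabilities or log-probabilities, because the logarithm is monotonic.

```python
import numpy as np


def hard_assignments(proba):
    """Hypothetical helper: index of the most probable group for each row."""
    return np.argmax(np.asarray(proba, dtype=float), axis=1)


proba = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
print(hard_assignments(proba))          # [0 2]
print(hard_assignments(np.log(proba)))  # [0 2]
```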
- get_groups_proba(X, log=False)
Return the soft assignment, i.e. the probability of each sample belonging to each group.
- Parameters:
X (torch.Tensor) – Covariate matrix of size (num_subjects, num_features).
log (bool) – If True, returns log probabilities.
- Returns:
probability of group membership given X, P(Z|X), of size (num_subjects, num_components+1).
- Return type:
z_pred (numpy.ndarray)
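Assuming the subgroup probabilities come from a generative mixture over X (as the mu, std and bc parameters suggest), P(Z|X) follows from Bayes' rule: each component's weight times the likelihood of x under that component, normalized over components. The sketch below does this for a diagonal-Gaussian mixture with log-sum-exp normalization and returns log-probabilities when log=True. It is a simplified stand-in: it models neither the Bernoulli components nor the extra column of the documented (num_components+1)-wide output, and gaussian_mixture_proba is a hypothetical name.

```python
import numpy as np


def gaussian_mixture_proba(X, weights, means, stds, log=False):
    """Sketch of P(Z|X) for a diagonal-Gaussian mixture (not HEMM's exact model).

    X: (n, d); weights: (K,); means, stds: (K, d).
    """
    X = np.asarray(X, dtype=float)[:, None, :]          # (n, 1, d)
    means = np.asarray(means, dtype=float)[None, :, :]  # (1, K, d)
    stds = np.asarray(stds, dtype=float)[None, :, :]    # (1, K, d)
    # log N(x | mu_k, diag(std_k^2)), summed over features -> (n, K)
    log_lik = -0.5 * np.sum(((X - means) / stds) ** 2 + 2 * np.log(stds) + np.log(2 * np.pi), axis=2)
    log_joint = log_lik + np.log(np.asarray(weights, dtype=float))[None, :]
    # Normalize over components with log-sum-exp for numerical stability
    max_lj = log_joint.max(axis=1, keepdims=True)
    log_norm = max_lj + np.log(np.sum(np.exp(log_joint - max_lj), axis=1, keepdims=True))
    log_post = log_joint - log_norm
    return log_post if log else np.exp(log_post)


P = gaussian_mixture_proba([[0.0], [10.0]], [0.5, 0.5], [[0.0], [10.0]], [[1.0], [1.0]])
print(P.round(3))  # each row sums to 1; sample 0 -> component 0, sample 1 -> component 1
```

Working in log space avoids underflow when a sample lies far from some components, which is also why a log option is useful.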