causallib.contrib.adversarial_balancing.classifier_selection module
- causallib.contrib.adversarial_balancing.classifier_selection.select_classifier(model, X, A, n_splits=5, loss_type='01', seed=None)
Utility for selecting the best classifier using cross-validation.
- Parameters
model – One of: a scikit-learn classifier, a scikit-learn SearchCV model (GridSearchCV or RandomizedSearchCV), or a list of classifiers.
X (np.ndarray) – Covariate matrix of shape (num_samples, num_features).
A (np.ndarray) – Binary labels indicating the source and target populations, of shape (num_samples,).
n_splits (int) – Number of splits in cross-validation. Relevant only if a list of classifiers is passed.
loss_type (str) – Name of the loss metric used to select the classifier. Use '01' for zero-one loss; any other value selects cross-entropy (in which case the classifiers must implement predict_proba). Relevant only if a list of classifiers is passed.
seed (int) – Random seed for the cross-validation split. Relevant only if a list of classifiers is passed.
- Returns
The best-performing classifier on the validation set.
- Return type
classifier
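The choice of loss_type can change which classifier is selected. The following is a minimal standard-library sketch of the two metrics named above, not causallib's implementation. The helper names (zero_one_loss, cross_entropy_loss, pick_best), the candidate names, and the held-out probabilities are all hypothetical and exist only for illustration. The string "cross_entropy" is an arbitrary stand-in for "any value other than '01'".

```python
import math


def zero_one_loss(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded prediction differs from the label."""
    errors = sum(int(p >= threshold) != y for y, p in zip(y_true, y_prob))
    return errors / len(y_true)


def cross_entropy_loss(y_true, y_prob, eps=1e-15):
    """Mean negative log-likelihood of the binary labels under the probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(y_true)


def pick_best(candidates, y_true, loss_type="01"):
    """Return the name of the candidate with the lowest loss (hypothetical helper)."""
    loss = zero_one_loss if loss_type == "01" else cross_entropy_loss
    return min(candidates, key=lambda name: loss(y_true, candidates[name]))


# Illustrative held-out labels and predicted P(A = 1) for two made-up classifiers.
labels = [1, 1, 0, 0]
held_out_probs = {
    "cautious": [0.6, 0.6, 0.4, 0.4],       # always right, never confident
    "confident": [0.99, 0.99, 0.01, 0.6],   # one mistake, otherwise very sure
}

best_by_01 = pick_best(held_out_probs, labels, loss_type="01")
best_by_ce = pick_best(held_out_probs, labels, loss_type="cross_entropy")
print(best_by_01)  # cautious
print(best_by_ce)  # confident
```

Zero-one loss only counts misclassifications, so it favors the classifier that never errs after thresholding. Cross-entropy rewards well-separated probabilities, which is why it requires predict_proba and can prefer a different model on the same data.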