causallib.contrib.adversarial_balancing.classifier_selection module

causallib.contrib.adversarial_balancing.classifier_selection.select_classifier(model, X, A, n_splits=5, loss_type='01', seed=None)

Utility for selecting the best classifier using cross-validation.

Parameters
  • model – One of: a scikit-learn classifier, a scikit-learn SearchCV model (GridSearchCV, RandomizedSearchCV), or a list of classifiers.

  • X (np.ndarray) – Covariate matrix of shape (num_samples, num_features).

  • A (np.ndarray) – Binary labels of shape (num_samples,) indicating the source and target populations.

  • n_splits (int) – Number of cross-validation splits. Relevant only if a list of classifiers is passed.

  • loss_type (str) – Name of the loss metric by which to select the classifier: ‘01’ for zero-one loss; any other value selects cross-entropy, in which case the classifiers must implement predict_proba. Relevant only if a list of classifiers is passed.

  • seed (int) – Random seed for the cross-validation split. Relevant only if a list of classifiers is passed.
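For reference, the two losses that loss_type chooses between have the standard definitions below, written for n validation samples with binary labels a_i, predicted labels \hat{a}_i, and predicted probabilities \hat{p}_i = P(a_i = 1). This notation is introduced here for illustration only; how per-fold values are aggregated is not specified in this documentation.

```latex
L_{01} = \frac{1}{n} \sum_{i=1}^{n} \mathbb{1}\left[\hat{a}_i \neq a_i\right]

L_{\mathrm{CE}} = -\frac{1}{n} \sum_{i=1}^{n} \left[ a_i \log \hat{p}_i + (1 - a_i) \log\left(1 - \hat{p}_i\right) \right]
```

Zero-one loss only needs hard predictions, whereas cross-entropy needs \hat{p}_i, which is why that option requires predict_proba.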

Returns

The best-performing classifier on the validation set.

Return type

classifier
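To illustrate the list-of-classifiers case, the following is a minimal sketch written directly against scikit-learn. It is not causallib's implementation: the helper name select_by_cv, the use of StratifiedKFold, averaging the loss across folds, and refitting the winner on the full data are all assumptions made for this example.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss, zero_one_loss
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier


def select_by_cv(classifiers, X, A, n_splits=5, loss_type="01", seed=None):
    """Hypothetical helper: pick the classifier with the lowest mean CV loss."""
    # Stratified, shuffled folds; `seed` controls the split (assumed design).
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    mean_losses = []
    for clf in classifiers:
        fold_losses = []
        for train_idx, val_idx in kfold.split(X, A):
            # Fit a fresh, unfitted copy so candidates are never mutated.
            model = clone(clf).fit(X[train_idx], A[train_idx])
            if loss_type == "01":
                loss = zero_one_loss(A[val_idx], model.predict(X[val_idx]))
            else:
                # Cross-entropy needs class probabilities, hence predict_proba.
                proba = model.predict_proba(X[val_idx])
                loss = log_loss(A[val_idx], proba, labels=model.classes_)
            fold_losses.append(loss)
        mean_losses.append(float(np.mean(fold_losses)))
    best = int(np.argmin(mean_losses))
    # Refit the winning candidate on all data (an assumption of this sketch).
    return clone(classifiers[best]).fit(X, A), mean_losses


# Usage on synthetic data: A depends on the first covariate plus noise.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
A = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)
candidates = [LogisticRegression(), DecisionTreeClassifier(max_depth=1)]
best, losses = select_by_cv(candidates, X, A, n_splits=5, loss_type="01", seed=0)
print(type(best).__name__, [round(l, 3) for l in losses])
```

The sketch also returns the per-candidate mean losses so the comparison can be inspected; the documented select_classifier returns only the classifier.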