causallib.model_selection.TreatmentOutcomeStratifiedKFold
- class TreatmentOutcomeStratifiedKFold(n_splits=5, *, shuffle=False, random_state=None)
Creates stratified folds based on both the treatment assignment and the outcome. That is, every fold preserves both the treatment prevalence and outcome prevalence within each treatment.
For non-class outcomes, stratification is done only based on treatment.
Class-wise stratified K-Fold cross-validator.
Provides train/test indices to split data in train/test sets.
This cross-validation object is a variation of KFold that returns stratified folds. The folds are made by preserving the percentage of samples for each class in y in a binary or multiclass classification setting.
Read more in the User Guide.
For a visualisation of cross-validation behaviour and a comparison of common scikit-learn split methods, refer to Visualizing cross-validation behavior in scikit-learn.
Note
Stratification on the class label solves an engineering problem rather than a statistical one. See Cross-validation iterators with stratification based on class labels for more details.
- Parameters:
n_splits (int, default=5) – Number of folds. Must be at least 2.
Changed in version 0.22: n_splits default value changed from 3 to 5.
shuffle (bool, default=False) – Whether to shuffle each class’s samples before splitting into batches. Note that the samples within each split will not be shuffled.
random_state (int, RandomState instance or None, default=None) – When shuffle is True, random_state affects the ordering of the indices, which controls the randomness of each fold for each class. Otherwise, leave random_state as None. Pass an int for reproducible output across multiple function calls. See Glossary.
Examples
>>> import numpy as np
>>> from sklearn.model_selection import StratifiedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> skf = StratifiedKFold(n_splits=2)
>>> skf.get_n_splits()
2
>>> print(skf)
StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(skf.split(X, y)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"  Test:  index={test_index}")
Fold 0:
  Train: index=[1 3]
  Test:  index=[0 2]
Fold 1:
  Train: index=[0 2]
  Test:  index=[1 3]
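The joint stratification described at the top of this page can be approximated with plain scikit-learn by stratifying on a combined treatment-and-outcome label. The sketch below illustrates the idea only and is not causallib's implementation; the toy arrays `a` and `y` and the `strata` encoding are assumptions made for the example.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical toy data: binary treatment `a` and binary outcome `y`.
a = np.array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1])
y = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1])
X = np.zeros((len(a), 1))  # placeholder features; only the labels drive the split

# Encode every (treatment, outcome) combination as a single class label (0..3).
strata = 2 * a + y

skf = StratifiedKFold(n_splits=2)
fold_counts = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, strata)):
    # Count how many test samples fall into each (treatment, outcome) cell.
    counts = np.bincount(strata[test_idx], minlength=4)
    fold_counts.append(counts)
    print(f"Fold {fold}: test counts per (a, y) cell = {counts}")
```

Because each of the four cells holds four samples, every test fold receives two samples from each cell, so both the treatment prevalence and the outcome prevalence within each treatment group are preserved across folds.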
Notes
The implementation is designed to:
- Generate test sets such that all contain the same distribution of classes, or as close as possible.
- Be invariant to class label: relabelling y = ["Happy", "Sad"] to y = [1, 0] should not change the indices generated.
- Preserve order dependencies in the dataset ordering, when shuffle=False: all samples from class k in some test set were contiguous in y, or separated in y by samples from classes other than k.
- Generate test sets where the smallest and largest differ by at most one sample.
Changed in version 0.22: The previous implementation did not follow the last constraint.
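Two of these design goals, label invariance and balanced test-set sizes, can be checked directly. The following is a minimal sketch using scikit-learn's `StratifiedKFold` (the parent class documented here), with made-up labels chosen for illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# The same labelling expressed twice: as strings and as integers.
y_str = np.array(["Happy", "Sad", "Happy", "Sad", "Happy", "Sad", "Happy"])
y_int = np.array([1, 0, 1, 0, 1, 0, 1])
X = np.zeros(len(y_str))  # placeholder; y alone determines the splits

skf = StratifiedKFold(n_splits=3)  # shuffle=False by default
splits_str = list(skf.split(X, y_str))
splits_int = list(skf.split(X, y_int))

# Relabelling the classes should leave the generated indices unchanged.
same_indices = all(
    np.array_equal(train_s, train_i) and np.array_equal(test_s, test_i)
    for (train_s, test_s), (train_i, test_i) in zip(splits_str, splits_int)
)

# The smallest and largest test sets should differ by at most one sample.
test_sizes = [len(test) for _, test in splits_str]

print("Identical indices after relabelling:", same_indices)
print("Test set sizes:", test_sizes)
```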
See also
RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
- split(joinedXa, y, groups=None)
Generate indices to split data into training and test set.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data, where n_samples is the number of samples and n_features is the number of features.
Note that providing y is sufficient to generate the splits and hence np.zeros(n_samples) may be used as a placeholder for X instead of actual training data.
y (array-like of shape (n_samples,)) – The target variable for supervised learning problems. Stratification is done based on the y labels.
groups (array-like of shape (n_samples,), default=None) – Always ignored, exists for API compatibility.
- Yields:
train (numpy.ndarray) – The training set indices for that split.
test (numpy.ndarray) – The testing set indices for that split.
Notes
Randomized CV splitters may return different results for each call of split. You can make the results identical by setting random_state to an integer.
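As a minimal sketch of this note, again using scikit-learn's `StratifiedKFold` with toy labels assumed for illustration, an integer `random_state` makes repeated calls to `split` yield the same shuffled folds:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
X = np.zeros(len(y))  # placeholder; y alone determines the splits

# An integer seed re-creates the same random stream on every call to split().
seeded = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)
first = [test for _, test in seeded.split(X, y)]
second = [test for _, test in seeded.split(X, y)]

repeatable = all(np.array_equal(f, s) for f, s in zip(first, second))
print("Same test folds on both calls:", repeatable)
```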
- set_split_request(*, joinedXa='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the split method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to split if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to split.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.