causallib.model_selection.TreatmentStratifiedKFold

class TreatmentStratifiedKFold(n_splits=5, *, shuffle=False, random_state=None)

Creates stratified folds based on the treatment assignment. That is, every fold preserves the treatment prevalence. It is a class-wise stratified K-Fold cross-validator in which the treatment assignment plays the role of the class label.

Provides train/test indices to split data in train/test sets.

This cross-validation object is a variation of KFold that returns stratified folds. The folds are made by preserving the percentage of samples for each class in y in a binary or multiclass classification setting.
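The treatment-prevalence property can be illustrated with scikit-learn's StratifiedKFold directly. This is a sketch that assumes TreatmentStratifiedKFold stratifies on the treatment vector the same way StratifiedKFold stratifies on y. It does not call the causallib class itself, because the expected layout of its joinedXa input is not described on this page. The toy data and the name a (treatment assignment) are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical toy data: 12 units, 4 treated (a == 1) and 8 untreated (a == 0).
X = np.arange(24).reshape(12, 2)
a = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0])

# Using the treatment vector as the stratification label keeps the
# treated fraction (4/12) in every test fold.
skf = StratifiedKFold(n_splits=4)
test_prevalences = []
for train_idx, test_idx in skf.split(X, a):
    prevalence = a[test_idx].mean()
    test_prevalences.append(prevalence)
    print(f"test={test_idx}, treated fraction={prevalence:.2f}")
```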

Read more in the User Guide.

For a visualisation of cross-validation behaviour and a comparison between common scikit-learn split methods, refer to Visualizing cross-validation behavior in scikit-learn.

Note

Stratification on the class label solves an engineering problem rather than a statistical one. See Cross-validation iterators with stratification based on class labels for more details.

Parameters:
  • n_splits (int, default 5) –

    Number of folds. Must be at least 2.

    Changed in version 0.22: n_splits default value changed from 3 to 5.

  • shuffle (bool, default False) – Whether to shuffle each class’s samples before splitting into batches. Note that the samples within each split will not be shuffled.

  • random_state (int, RandomState instance or None, default None) – When shuffle is True, random_state affects the ordering of the indices, which controls the randomness of each fold for each class. Otherwise, leave random_state as None. Pass an int for reproducible output across multiple function calls. See Glossary.
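The parameter constraints above can be checked with a short sketch. It uses scikit-learn's StratifiedKFold (from which these parameter descriptions are inherited) on made-up labels, and shows two things: n_splits below 2 is rejected, and with shuffle=True the index arrays returned for each split still come back in ascending order while the class balance of each test fold is kept.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# n_splits below 2 is rejected when the splitter is constructed.
try:
    StratifiedKFold(n_splits=1)
    rejected = False
except ValueError:
    rejected = True

y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
X = np.zeros((len(y), 1))  # placeholder features

# shuffle=True reorders each class's samples before they are assigned to folds,
# but the indices returned for each split are not themselves shuffled.
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)
splits = list(skf.split(X, y))
for train_idx, test_idx in splits:
    print("train:", train_idx, "test:", test_idx)
```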

Examples

>>> import numpy as np
>>> from sklearn.model_selection import StratifiedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> skf = StratifiedKFold(n_splits=2)
>>> skf.get_n_splits()
2
>>> print(skf)
StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(skf.split(X, y)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"  Test:  index={test_index}")
Fold 0:
  Train: index=[1 3]
  Test:  index=[0 2]
Fold 1:
  Train: index=[0 2]
  Test:  index=[1 3]

Notes

The implementation is designed to:

  • Generate test sets such that all contain the same distribution of classes, or as close as possible.

  • Be invariant to class label: relabelling y = ["Happy", "Sad"] to y = [1, 0] should not change the indices generated.

  • Preserve order dependencies in the dataset ordering, when shuffle=False: all samples from class k in some test set were contiguous in y, or separated in y by samples from classes other than k.

  • Generate test sets where the smallest and largest differ by at most one sample.

Changed in version 0.22: The previous implementation did not follow the last constraint.
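Two of the design goals listed above, label invariance and test sets differing in size by at most one sample, can be observed with a small sketch. It assumes scikit-learn's StratifiedKFold behaviour and uses made-up labels: the string labels group the samples exactly like the integer labels, so the generated folds should match.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y_int = np.array([0] * 6 + [1] * 4)
y_str = np.array(["Sad"] * 6 + ["Happy"] * 4)  # same grouping, different labels
X = np.zeros((len(y_int), 1))  # placeholder features

skf = StratifiedKFold(n_splits=3)  # shuffle=False, so the folds are deterministic
folds_int = [test for _, test in skf.split(X, y_int)]
folds_str = [test for _, test in skf.split(X, y_str)]

# 10 samples over 3 folds cannot split evenly; sizes differ by at most one.
sizes = [len(test) for test in folds_int]
print("test sizes:", sizes)
```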

See also

RepeatedStratifiedKFold

Repeats Stratified K-Fold n times.
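As a brief sketch of the related splitter, assuming scikit-learn's RepeatedStratifiedKFold and made-up labels: each repetition runs a full, randomly shuffled stratified K-fold, so the total number of splits is n_splits * n_repeats.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
X = np.zeros((len(y), 1))  # placeholder features

# 2 folds repeated 3 times, seeded for reproducibility.
rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=3, random_state=0)
splits = list(rskf.split(X, y))
print(rskf.get_n_splits(), len(splits))
```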

split(joinedXa, y=None, groups=None)

Generate indices to split data into training and test set.

Parameters:
  • X (array-like of shape (n_samples, n_features)) –

    Training data, where n_samples is the number of samples and n_features is the number of features.

    Note that providing y is sufficient to generate the splits and hence np.zeros(n_samples) may be used as a placeholder for X instead of actual training data.

  • y (array-like of shape (n_samples,)) – The target variable for supervised learning problems. Stratification is done based on the y labels.

  • groups (array-like of shape (n_samples,), default None) – Always ignored, exists for API compatibility.

Yields:
  • train (numpy.ndarray) – The training set indices for that split.

  • test (numpy.ndarray) – The testing set indices for that split.
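The note about using np.zeros(n_samples) as a placeholder for X can be sketched with scikit-learn's StratifiedKFold, whose split signature (X, y, groups) these parameter descriptions come from. The labels are made up. The sketch collects the yielded (train, test) index arrays so their properties can be inspected: in each split they are disjoint and together cover every sample index.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
n_samples = len(y)

# Only y drives the stratification, so zeros can stand in for the features.
X_placeholder = np.zeros(n_samples)

skf = StratifiedKFold(n_splits=5)
splits = list(skf.split(X_placeholder, y))
for train, test in splits:
    print("train:", train, "test:", test)
```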

Notes

Randomized CV splitters may return different results for each call of split. You can make the results identical by setting random_state to an integer.
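A minimal sketch of the reproducibility note, assuming scikit-learn's StratifiedKFold with shuffle=True and made-up labels: with an integer random_state, calling split twice on the same object, or on a fresh object built with the same seed, yields the same folds. The helper collect_test_folds is hypothetical, introduced only for this illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.array([0] * 5 + [1] * 5)
X = np.zeros((len(y), 1))  # placeholder features

def collect_test_folds(splitter):
    """Hypothetical helper: test indices of every fold as plain lists."""
    return [test.tolist() for _, test in splitter.split(X, y)]

# With an integer seed, repeated calls to split yield identical folds.
seeded = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
first = collect_test_folds(seeded)
second = collect_test_folds(seeded)
print(first == second)
```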

set_split_request(*, joinedXa='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the split method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to split if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to split.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • joinedXa (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the joinedXa parameter in split.

Returns:

self – The updated object.

Return type:

object