# Source code for causallib.survival.standardized_survival

from .regression_curve_fitter import RegressionCurveFitter
from sklearn.base import BaseEstimator as SKLearnBaseEstimator
from typing import Any, Optional
import pandas as pd
from copy import deepcopy
from .survival_utils import canonize_dtypes_and_names, safe_join
from .base_survival import SurvivalBase


class StandardizedSurvival(SurvivalBase):
    """Standardization-based survival estimator.

    Wraps a survival curve fitter (either a scikit-learn estimator wrapped in a
    `RegressionCurveFitter`, or an object exposing a lifelines-like
    `fit` / `predict_survival_function` API) and uses it to predict individual
    survival curves per treatment value, which can then be averaged into
    population-level curves.
    """

    def __init__(self, survival_model: Any, stratify: bool = True, **kwargs):
        """
        Standardization survival estimator.
        Computes parametric curve by fitting a time-varying hazards model that includes baseline covariates.

        Args:
            survival_model: Two alternatives:
                1. Scikit-Learn estimator (needs to implement `predict_proba`) - compute parametric curve by fitting a
                    time-varying hazards model that includes baseline covariates. Note that the model is fitted on a
                    person-time table with all covariates, and might be computationally and memory expensive.
                2. lifelines RegressionFitter - use lifelines fitter to compute survival curves from baseline
                    covariates, events and durations
            stratify (bool): if True, fit a separate model per treatment group
            **kwargs: Accepted for signature compatibility; not used by this constructor.
        """
        self.stratify = stratify
        if isinstance(survival_model, SKLearnBaseEstimator):
            # Construct default curve fitter, parametric with a scikit-learn estimator
            self.survival_model = RegressionCurveFitter(survival_model)
        else:
            # Initialized lifelines RegressionFitter (or any implementation with a compatible API)
            self.survival_model = survival_model

    def fit(self,
            X: pd.DataFrame,
            a: pd.Series,
            t: pd.Series,
            y: pd.Series,
            w: Optional[pd.Series] = None,
            fit_kwargs: Optional[dict] = None):
        """
        Fits parametric models and calculates internal survival functions.

        Args:
            X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features).
            a (pd.Series): Treatment assignment of size (num_subjects,).
            t (pd.Series): Followup duration, size (num_subjects,).
            y (pd.Series): Observed outcome (1) or right censoring event (0), size (num_subjects,).
            w (pd.Series): Optional subject weights.
            fit_kwargs (dict): Optional kwargs for fit call of survival model

        Returns:
            self
        """
        a, t, y, w, X = canonize_dtypes_and_names(a=a, t=t, y=y, w=w, X=X)

        # Weights (when given) travel with the covariates as an extra column,
        # and the fitter is told which column holds them.
        if w is not None:
            fit_data, (w_name,) = safe_join(df=X, list_of_series=[w], return_series_names=True)
        else:
            fit_data = X
            w_name = None

        if fit_kwargs is None:
            fit_kwargs = {}

        # Reset on every fit so stale per-treatment models never leak between fits.
        self.stratified_curve_fitters_ = {}

        if self.stratify:
            # One independent copy of the model per treatment value; the treatment
            # itself is not a covariate in this mode.
            fit_data, (t_name, y_name) = safe_join(
                df=fit_data, list_of_series=[t, y], return_series_names=True
            )
            for treatment_value in a.unique():
                stratum_curve_fitter = deepcopy(self.survival_model)
                stratum_curve_fitter.fit(
                    df=fit_data[a == treatment_value],
                    duration_col=t_name,
                    event_col=y_name,
                    weights_col=w_name,
                    **fit_kwargs
                )
                self.stratified_curve_fitters_[treatment_value] = stratum_curve_fitter
        else:
            # Single model: the treatment assignment is joined in as a regular column.
            fit_data, (_, t_name, y_name) = safe_join(
                df=fit_data, list_of_series=[a, t, y], return_series_names=True
            )
            self.survival_model.fit(
                df=fit_data,
                duration_col=t_name,
                event_col=y_name,
                weights_col=w_name,
                **fit_kwargs
            )
        return self

    def estimate_individual_outcome(self,
                                    X: pd.DataFrame,
                                    a: pd.Series,
                                    t: pd.Series,
                                    y: Optional[Any] = None,
                                    timeline_start: Optional[int] = None,
                                    timeline_end: Optional[int] = None
                                    ) -> pd.DataFrame:
        """
        Returns individual survival curves for each subject row in X/a/t

        Args:
            X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features).
            a (pd.Series): Treatment assignment of size (num_subjects,).
            t (pd.Series): Followup durations, size (num_subjects,).
            y: NOT USED (for API compatibility only).
            timeline_start (int): Common start time-step. If provided, will generate survival curves starting
                                  from 'timeline_start' for all patients.
                                  If None, will predict from first observed event (t.min()).
            timeline_end (int): Common end time-step. If provided, will generate survival curves up to 'timeline_end'
                                for all patients.
                                If None, will predict up to last observed event (t.max()).

        Returns:
            pd.DataFrame: with time-step index, subject IDs (X.index) as columns and point survival as entries
        """
        min_time = timeline_start if timeline_start is not None else int(t.min())
        max_time = timeline_end if timeline_end is not None else int(t.max())
        # Contiguous time steps for inference (inclusive of max_time).
        contiguous_times = pd.Series(data=range(min_time, max_time + 1), name=t.name)

        a, _, _, _, X = canonize_dtypes_and_names(a=a, w=None, X=X)
        unique_treatment_values = sorted(a.unique())

        res = {}
        for treatment_value in unique_treatment_values:
            if self.stratify:
                # Per-treatment model: covariates alone are the model input.
                predict_data = X
                model = self.stratified_curve_fitters_[treatment_value]
            else:
                # Single model: assign every subject to the current treatment value
                # and feed that as a column alongside the covariates.
                assignment = pd.Series(treatment_value, index=a.index, name=a.name)
                predict_data, _ = safe_join(
                    df=X, list_of_series=[assignment], return_series_names=True
                )
                model = self.survival_model

            treatment_individual_survival_curves = model.predict_survival_function(
                X=predict_data, times=contiguous_times
            )
            res[treatment_value] = treatment_individual_survival_curves

        # Dict keys (treatment values) become the outer column level.
        res = pd.concat(res, axis="columns", names=[a.name])
        return res

    def estimate_population_outcome(self,
                                    X: pd.DataFrame,
                                    a: pd.Series,
                                    t: pd.Series,
                                    y: Optional[Any] = None,
                                    timeline_start: Optional[int] = None,
                                    timeline_end: Optional[int] = None
                                    ) -> pd.DataFrame:
        """
        Returns population averaged survival curves.

        Args:
            X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features).
            a (pd.Series): Treatment assignment of size (num_subjects,).
            t (pd.Series): Followup durations, size (num_subjects,).
            y: NOT USED (for API compatibility only).
            timeline_start (int): Common start time-step. If provided, will generate survival curves starting
                                  from 'timeline_start' for all patients.
                                  If None, will predict from first observed event (t.min()).
            timeline_end (int): Common end time-step. If provided, will generate survival curves up to 'timeline_end'
                                for all patients.
                                If None, will predict up to last observed event (t.max()).

        Returns:
            pd.DataFrame: with time-step index, treatment values as columns and survival as entries
        """
        a, t, _, _, X = canonize_dtypes_and_names(a=a, t=t, X=X)
        individual_survival_curves = self.estimate_individual_outcome(
            X=X,
            a=a,
            t=t,
            timeline_start=timeline_start,
            timeline_end=timeline_end,
        )

        # Average over subjects within each treatment value (outer column level).
        # `groupby(axis="columns")` is deprecated in recent pandas releases, so the
        # frame is transposed, grouped on its (now row) outer level, and transposed back.
        res = individual_survival_curves.T.groupby(level=0).mean().T

        # Setting index/column names
        res.index.name = t.name
        res.columns.name = a.name
        return res