Source code for causallib.survival.univariate_curve_fitter

import pandas as pd
import numpy as np
from typing import Optional
from sklearn.base import BaseEstimator as SKLearnBaseEstimator
from .survival_utils import safe_join
from .regression_curve_fitter import RegressionCurveFitter


[docs]class UnivariateCurveFitter: def __init__(self, learner: Optional[SKLearnBaseEstimator] = None): """ Default implementation of a univariate survival curve fitter. Construct a curve fitter, either non-parametric (Kaplan-Meier) or parametric. API follows 'lifelines' convention for univariate models, see here for example: https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter.fit Args: learner: optional scikit-learn estimator (needs to implement `predict_proba`). If provided, will compute parametric curve by fitting a time-varying hazards model. if None, will compute non-parametric Kaplan-Meier estimator. """ self.learner = learner
[docs] def fit(self, durations, event_observed=None, weights=None): """ Fits a univariate survival curve (Kaplan-Meier or parametric, if a learner was provided in constructor) Args: durations (Iterable): Duration subject was observed event_observed (Optional[Iterable]): Boolean or 0/1 iterable, where True means 'outcome event' and False means 'right censoring'. If unspecified, assumes that all events are 'outcome' (no censoring). weights (Optional[Iterable]): Optional subject weights Returns: Self """ # If 'event_observed' is unspecified, assumes that all events are 'outcome' (no censoring). if event_observed is None: event_observed = pd.Series(data=1, index=durations.index) if weights is None: weights = pd.Series(data=1, index=durations.index, name='weights') else: weights = pd.Series(data=weights, index=durations.index, name='weights') self.timeline_ = np.sort(np.unique(durations)) # If sklearn classifier is provided, fit parametric curve if self.learner is not None: self.curve_fitter_ = RegressionCurveFitter(learner=self.learner) fit_data, (duration_col_name, event_col_name, weights_col_name) = safe_join( df=None, list_of_series=[durations, event_observed, weights], return_series_names=True ) self.curve_fitter_.fit(df=fit_data, duration_col=duration_col_name, event_col=event_col_name, weights_col=weights_col_name) # Else, compute Kaplan Meier estimator non parametrically else: # Code inspired by lifelines KaplanMeierFitter df = pd.DataFrame({ 't': durations, 'removed': weights.to_numpy(), 'observed': weights.to_numpy() * (event_observed.to_numpy(dtype=bool)) }) death_table = df.groupby("t").sum() death_table['censored'] = (death_table['removed'] - death_table['observed']).astype(int) births = pd.DataFrame(np.zeros(durations.shape[0]), columns=["t"]) births['entrance'] = np.asarray(weights) births_table = births.groupby("t").sum() event_table = death_table.join(births_table, how="outer", sort=True).fillna( 0) # http://wesmckinney.com/blog/?p=414 event_table['at_risk'] = event_table['entrance'].cumsum() - event_table['removed'].cumsum().shift(1).fillna( 0) self.event_table_ = event_table return self
[docs] def predict(self, times=None, interpolate=False): """ Compute survival curve for time points given in 'times' param. Args: times: sequence of time points for prediction interpolate: if True, linearly interpolate non-observed times. Otherwise, repeat last observed time point. Returns: pd.Series: with times index and survival values """ if times is None: times = self.timeline_ else: times = sorted(times) if self.learner is not None: # Predict parametric survival curve survival = self.curve_fitter_.predict_survival_function(X=None, times=pd.Series(times)) else: # Compute hazard at each time step hazard = self.event_table_['observed'] / self.event_table_['at_risk'] timeline = hazard.index # if computed non-parametrically, timeline is all observed data points # Compute survival from hazards survival = pd.Series(data=np.cumprod(1 - hazard), index=timeline, name='survival') if interpolate: survival = pd.Series(data=np.interp(times, survival.index.values, survival.values), index=pd.Index(data=times, name='t'), name='survival') else: survival = survival.asof(times).squeeze() # Round near-zero values (may occur when using weights and all observed subjects "died" at some point) survival[np.abs(survival) < np.finfo(float).resolution] = 0 return survival