causallib.survival.WeightedStandardizedSurvival
- class WeightedStandardizedSurvival(weight_model, survival_model, stratify=True, outcome_covariates=None, weight_covariates=None)
- Combines WeightedSurvival and StandardizedSurvival:
Adjusts for treatment assignment by creating a weighted pseudo-population (e.g., via inverse propensity weighting).
Computes a parametric survival curve by fitting a time-varying hazards model that includes baseline covariates.
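As a minimal sketch of the weighting half, assuming a binary treatment and propensity scores Pr[A=1 | X] that have already been estimated (hard-coded here), each subject is weighted by the inverse probability of the treatment they actually received. The helper `inverse_propensity_weights` is hypothetical and not part of causallib; in the class itself this job belongs to `weight_model`.

```python
import numpy as np
import pandas as pd


def inverse_propensity_weights(a: pd.Series, propensity: pd.Series) -> pd.Series:
    """Hypothetical helper: 1 / Pr[A = a_i | X_i] for a binary treatment.

    `propensity` holds Pr[A=1 | X] for each subject.
    """
    # Probability of the treatment each subject actually received.
    p_received = np.where(a == 1, propensity, 1.0 - propensity)
    return pd.Series(1.0 / p_received, index=a.index, name="weight")


# Toy example: three subjects with known propensity scores.
a = pd.Series([1, 0, 1], index=["s1", "s2", "s3"])
propensity = pd.Series([0.8, 0.25, 0.5], index=a.index)
w = inverse_propensity_weights(a, propensity)
print(w.round(3).tolist())  # [1.25, 1.333, 2.0]
```

Subjects whose received treatment was unlikely given their covariates get larger weights, which is what balances the pseudo-population.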
- Parameters:
weight_model (WeightEstimator) – causallib-compatible weight model (e.g., IPW).
survival_model (Any) –
- Two alternatives:
- Scikit-Learn estimator (must implement predict_proba): computes a parametric curve by fitting a time-varying hazards model that includes baseline covariates. Note that the model is fitted on a person-time table with all covariates, which can be computationally expensive and memory intensive.
- lifelines RegressionFitter: uses a lifelines fitter to compute survival curves from baseline covariates, events, and durations.
stratify (bool) – If True, fit a separate model per treatment group.
outcome_covariates (array) – Covariates to use for the outcome model. If None, all covariates passed will be used. Either a list of column names or a boolean mask.
weight_covariates (array) – Covariates to use for the weight model. If None, all covariates passed will be used. Either a list of column names or a boolean mask.
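The covariate-subset options and the person-time table mentioned for scikit-learn survival models can be sketched in pandas. This is an illustration under stated assumptions, not causallib's code: `select_covariates` and `to_person_time` are hypothetical helpers, follow-up times are assumed to be positive integers, and time steps are assumed to run from 1 to t_i. The person-time table has one row per subject per time step, so it grows with the total follow-up sum(t_i) rather than with the number of subjects, which is why memory use can become a concern.

```python
import pandas as pd


def select_covariates(X, covariates=None):
    """None keeps every column; otherwise accept column names or a boolean mask."""
    return X if covariates is None else X.loc[:, covariates]


def to_person_time(X, t, y):
    """Expand (X, t, y) into one row per subject per time step 1..t_i.

    `event` is 1 only on a subject's last row, and only if y_i == 1;
    right-censored subjects contribute rows with event == 0 throughout.
    """
    rows = []
    for sid in X.index:
        for k in range(1, int(t[sid]) + 1):
            rows.append({"id": sid, "time": k,
                         "event": int(k == t[sid] and y[sid] == 1)})
    # Attach each subject's baseline covariates to every one of their rows.
    return pd.DataFrame(rows).join(X, on="id")


X = pd.DataFrame({"age": [50, 60], "sex": [0, 1], "bmi": [22.0, 30.0]},
                 index=["s1", "s2"])
t = pd.Series([3, 2], index=X.index)  # follow-up durations
y = pd.Series([1, 0], index=X.index)  # s1 has the event, s2 is censored

by_name = select_covariates(X, ["age", "sex"])
by_mask = select_covariates(X, [True, True, False])
assert by_name.equals(by_mask)  # both forms pick the same columns

pt = to_person_time(by_name, t, y)
print(len(pt))  # 5 rows (3 + 2) for only 2 subjects
```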
- __init__(weight_model, survival_model, stratify=True, outcome_covariates=None, weight_covariates=None)
- fit(X, a, t, y, w=None, fit_kwargs=None)
Fits parametric models and calculates internal survival functions.
- Parameters:
X (pandas.DataFrame) – Baseline covariate matrix of size (num_subjects, num_features).
a (pandas.Series) – Treatment assignment of size (num_subjects,).
t (pandas.Series) – Follow-up duration, size (num_subjects,).
y (pandas.Series) – Observed outcome (1) or right-censoring event (0), size (num_subjects,).
w (pandas.Series) – NOT USED (for compatibility only); optional subject weights.
fit_kwargs (dict) – Optional kwargs for the fit call of the survival model.
- Returns:
self
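What fitting involves can be sketched conceptually in plain NumPy. This is a hypothetical sketch, not causallib's implementation, and it assumes: a binary treatment, a single continuous covariate, integer follow-up times, a logistic propensity model standing in for `weight_model`, and a pooled logistic regression (hazard linear in time) standing in for a scikit-learn `survival_model` with `stratify=True`. Consistent with the parameter list above, the subject weights come from the weight model rather than from a `w` argument. The weighted logistic regressions are fitted by Newton-Raphson so the example needs no scikit-learn; every function name is illustrative.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fit_weighted_logistic(Z, d, w, n_iter=25):
    """Weighted logistic regression fitted by Newton-Raphson."""
    beta = np.zeros(Z.shape[1])
    for _ in range(n_iter):
        p = sigmoid(Z @ beta)
        grad = Z.T @ (w * (d - p))
        # Weighted Fisher information, plus a tiny diagonal jitter for stability.
        info = (Z * (w * p * (1 - p))[:, None]).T @ Z + 1e-8 * np.eye(Z.shape[1])
        beta = beta + np.linalg.solve(info, grad)
    return beta


def to_person_time(t, y):
    """Subject index, time step (1..t_i) and event flag for every person-time row."""
    ids = np.repeat(np.arange(len(t)), t)
    time = np.concatenate([np.arange(1, ti + 1) for ti in t])
    event = (time == t[ids]) & (y[ids] == 1)
    return ids, time, event.astype(float)


def fit_weighted_standardized(x, a, t, y):
    """Hypothetical sketch: IPW weights, then a weighted hazard model per arm."""
    # 1) Weight model: logistic propensity Pr[A=1 | x] -> inverse-probability weights.
    Zx = np.column_stack([np.ones_like(x), x])
    ps = sigmoid(Zx @ fit_weighted_logistic(Zx, a.astype(float), np.ones_like(x)))
    w = np.where(a == 1, 1.0 / ps, 1.0 / (1.0 - ps))
    # 2) Outcome model: pooled logistic hazard on person-time rows, each row
    #    carrying its subject's weight; one model per arm (stratify=True).
    ids, time, event = to_person_time(t, y)
    models = {}
    for arm in (0, 1):
        rows = a[ids] == arm
        Z = np.column_stack([np.ones(rows.sum()), time[rows], x[ids][rows]])
        models[arm] = fit_weighted_logistic(Z, event[rows], w[ids][rows])
    return models  # arm -> [intercept, time, x] log-odds coefficients


# Simulated data: confounder x drives both treatment and the hazard.
rng = np.random.default_rng(0)
n, horizon = 300, 10
x = rng.normal(size=n)
a = rng.binomial(1, sigmoid(x))
hazard = sigmoid(-2.0 + 0.5 * x - 0.5 * a)
t = np.full(n, horizon)
y = np.zeros(n, dtype=int)
for k in range(1, horizon + 1):
    new_event = (y == 0) & (rng.random(n) < hazard)
    t[new_event] = k  # event time
    y[new_event] = 1  # subjects never flagged stay censored at the horizon
models = fit_weighted_standardized(x, a, t, y)
```

Fitting one model per arm mirrors `stratify=True`; with `stratify=False` one would instead add the treatment as a column of a single design matrix.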
- estimate_individual_outcome(X, a, t, y=None, timeline_start=None, timeline_end=None)
Returns individual survival curves for each subject row in X/a/t.
- Parameters:
X (pandas.DataFrame) – Baseline covariate matrix of size (num_subjects, num_features).
a (pandas.Series) – Treatment assignment of size (num_subjects,).
t (pandas.Series) – Follow-up durations, size (num_subjects,).
y (Any | None) – NOT USED (for API compatibility only).
timeline_start (int) – Common start time-step. If provided, survival curves will start from ‘timeline_start’ for all patients. If None, will predict from the first observed event (t.min()).
timeline_end (int) – Common end time-step. If provided, survival curves will extend up to ‘timeline_end’ for all patients. If None, will predict up to the last observed event (t.max()).
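The output layout and the role of `timeline_start` / `timeline_end` can be illustrated with a hypothetical pooled-logistic hazard model whose coefficients are hard-coded. This sketch rests on assumptions, not on causallib's code: hazards are h_i(k) = sigmoid(b0 + b_time·k + x_i·b_x), survival is the running product of (1 − h), and the first returned row already includes the hazard at the start time. `individual_survival_curves` is an illustrative name.

```python
import numpy as np
import pandas as pd


def individual_survival_curves(X, beta, t, timeline_start=None, timeline_end=None):
    """Hypothetical sketch: survival curves from a pooled-logistic hazard model.

    beta = [intercept, time coefficient, *covariate coefficients].
    Returns a DataFrame indexed by time step with one column per subject.
    """
    # Default timeline: first to last observed follow-up time, as documented.
    start = int(t.min()) if timeline_start is None else timeline_start
    end = int(t.max()) if timeline_end is None else timeline_end
    times = np.arange(start, end + 1)
    b0, b_time, b_x = beta[0], beta[1], np.asarray(beta[2:])
    # Rows are time steps, columns are subjects.
    logits = b0 + b_time * times[:, None] + (X.to_numpy() @ b_x)[None, :]
    hazards = 1.0 / (1.0 + np.exp(-logits))
    survival = np.cumprod(1.0 - hazards, axis=0)
    return pd.DataFrame(survival, index=pd.Index(times, name="t"), columns=X.index)


X = pd.DataFrame({"age_std": [0.0, 1.0]}, index=["s1", "s2"])
t = pd.Series([3, 5], index=X.index)
beta = [-2.0, 0.0, 0.5]  # hard-coded, illustrative coefficients

curves = individual_survival_curves(X, beta, t)  # time steps 3..5
full = individual_survival_curves(X, beta, t, timeline_start=1, timeline_end=10)
print(curves.shape, full.shape)  # (3, 2) (10, 2)
```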
- Returns:
A DataFrame with a time-step index, subject IDs (X.index) as columns, and point survival probabilities as entries.
- Return type: