Source code for causallib.evaluation.plots.mixins

"""Mixins for plotting.

For a mixin to work, the host class must implement `get_data_for_plot` for each
of the supported plot names. See `.data_extractors` for examples.
"""

from . import plots


class WeightPlotterMixin:
    """Mixin that adds weight-estimation plotting methods to a class.

    The host class must implement:
        * `get_data_for_plot(plots.COVARIATE_BALANCE_GENERIC_PLOT)`
        * `get_data_for_plot(plots.WEIGHT_DISTRIBUTION_PLOT)`
    """

    def plot_covariate_balance(
        self,
        kind="love",
        phase="train",
        ax=None,
        aggregate_folds=True,
        thresh=None,
        plot_semi_grid=True,
        label_imbalanced=True,
        **kwargs,
    ):
        """Plot covariate balance before and after weighting.

        Args:
            kind (str, optional): Plot kind: "love", "slope" or "scatter".
                Defaults to "love".
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.
            aggregate_folds (bool, optional): Whether to aggregate folds.
                Defaults to True. Only forwarded when kind="love".
            thresh (float, optional): Draw threshold line at value.
                Defaults to None.
            plot_semi_grid (bool, optional): Defaults to True. Only forwarded
                when kind="love".
            label_imbalanced (bool): Label covariates that weren't properly
                balanced. Only forwarded when kind is "slope" or "scatter".
            **kwargs: Passed through unchanged to the underlying plotting
                function selected by `kind`.

        Returns:
            matplotlib.axes.Axes: axis with plot (the return value of the
            underlying plotting function).

        Raises:
            ValueError: If `kind` is not one of the supported plot kinds.
        """
        # Reject an unsupported kind before extracting any data, so a typo in
        # `kind` fails immediately instead of after the per-fold extraction.
        if kind not in ("love", "slope", "scatter"):
            raise ValueError(f"Unsupported covariate balance plot kind {kind}")

        # The extractor returns a one-element tuple; unpacking enforces that.
        (table1_folds,) = self.get_data_for_plot(
            plots.COVARIATE_BALANCE_GENERIC_PLOT, phase=phase
        )

        if kind == "love":
            return plots.plot_mean_features_imbalance_love_folds(
                table1_folds=table1_folds,
                ax=ax,
                aggregate_folds=aggregate_folds,
                thresh=thresh,
                plot_semi_grid=plot_semi_grid,
                **kwargs,
            )
        if kind == "slope":
            return plots.plot_mean_features_imbalance_slope_folds(
                table1_folds=table1_folds,
                ax=ax,
                thresh=thresh,
                label_imbalanced=label_imbalanced,
                **kwargs,
            )
        # Only "scatter" can reach this point thanks to the validation above.
        return plots.plot_mean_features_imbalance_scatter_plot(
            table1_folds=table1_folds,
            ax=ax,
            thresh=thresh,
            label_imbalanced=label_imbalanced,
            **kwargs,
        )

    def plot_weight_distribution(
        self,
        phase="train",
        reflect=True,
        kde=False,
        cumulative=False,
        norm_hist=True,
        ax=None,
    ):
        """Plot the distribution of propensity score.

        The data is fetched with `plots.WEIGHT_DISTRIBUTION_PLOT` and drawn by
        `plots.plot_propensity_score_distribution_folds`, with the treatment
        assignment used as the hue.

        Args:
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            reflect (bool): Whether to plot treatment groups on opposite sides
                of the x-axis. This can only work if there are exactly two
                groups.
            kde (bool): Whether to plot kernel density estimation.
            cumulative (bool): Whether to plot cumulative distribution.
            norm_hist (bool): If False - use raw counts on the y-axis.
                If kde=True, then norm_hist should be True as well.
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.

        Returns:
            matplotlib.axes.Axes
        """
        weights, treatments, cv = self.get_data_for_plot(
            plots.WEIGHT_DISTRIBUTION_PLOT, phase=phase
        )
        return plots.plot_propensity_score_distribution_folds(
            predictions=weights,
            hue_by=treatments,
            cv=cv,
            reflect=reflect,
            kde=kde,
            cumulative=cumulative,
            norm_hist=norm_hist,
            ax=ax,
        )
class ClassificationPlotterMixin:
    """Mixin that adds plots for classification / binary prediction models.

    This applies to propensity models (treatment assignment is inherently
    binary) and to outcome models where the outcome is binary.

    The host class must implement:
        * `get_data_for_plot(plots.ROC_CURVE_PLOT)`
        * `get_data_for_plot(plots.PR_CURVE_PLOT)`
        * `get_data_for_plot(plots.CALIBRATION_PLOT)`
    """

    def plot_roc_curve(
        self,
        phase="train",
        plot_folds=False,
        label_folds=False,
        label_std=False,
        ax=None,
    ):
        """Plot ROC curve.

        Args:
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            plot_folds (bool, optional): Whether to plot individual folds.
                Defaults to False.
            label_folds (bool, optional): Whether to label folds.
                Defaults to False.
            label_std (bool, optional): Whether to label std.
                Defaults to False.
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.

        Returns:
            matplotlib.axes.Axes
        """
        # The extractor yields exactly one item for this plot.
        [roc_data] = self.get_data_for_plot(plots.ROC_CURVE_PLOT, phase=phase)
        fold_options = {
            "plot_folds": plot_folds,
            "label_folds": label_folds,
            "label_std": label_std,
        }
        return plots.plot_roc_curve_folds(roc_data, ax=ax, **fold_options)

    def plot_pr_curve(
        self,
        phase="train",
        plot_folds=False,
        label_folds=False,
        label_std=False,
        ax=None,
    ):
        """Plot precision-recall (PR) curve.

        Args:
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            plot_folds (bool, optional): Whether to plot individual folds.
                Defaults to False.
            label_folds (bool, optional): Whether to label folds.
                Defaults to False.
            label_std (bool, optional): Whether to label std.
                Defaults to False.
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.

        Returns:
            matplotlib.axes.Axes
        """
        # The extractor yields exactly one item for this plot.
        [pr_data] = self.get_data_for_plot(plots.PR_CURVE_PLOT, phase=phase)
        fold_options = {
            "plot_folds": plot_folds,
            "label_folds": label_folds,
            "label_std": label_std,
        }
        return plots.plot_precision_recall_curve_folds(
            pr_data, ax=ax, **fold_options
        )

    def plot_calibration_curve(
        self,
        phase="train",
        n_bins=10,
        plot_se=True,
        plot_rug=False,
        plot_histogram=False,
        quantile=False,
        ax=None,
    ):
        """Plot calibration curves for multiple models (presumably in folds).

        Args:
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            n_bins (int): number of bins to evaluate in the plot.
            plot_se (bool): Whether to plot standard errors around the mean
                bin-probability estimation.
            plot_rug (bool): Forwarded as-is to
                `plots.plot_calibration_folds`.
            plot_histogram (bool): Forwarded as-is to
                `plots.plot_calibration_folds`.
            quantile (bool): If true, the binning of the calibration curve is
                by quantiles. Defaults to False.
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.

        Returns:
            matplotlib.axes.Axes
        """
        scores, labels, folds = self.get_data_for_plot(
            plots.CALIBRATION_PLOT, phase=phase
        )
        return plots.plot_calibration_folds(
            predictions=scores,
            targets=labels,
            cv=folds,
            n_bins=n_bins,
            plot_se=plot_se,
            plot_rug=plot_rug,
            plot_histogram=plot_histogram,
            quantile=quantile,
            ax=ax,
        )
class ContinuousOutcomePlotterMixin:
    """Mixin that adds plots for continuous outcome estimation.

    The host class must implement:
        * `get_data_for_plot(plots.CONTINUOUS_ACCURACY_PLOT)`
        * `get_data_for_plot(plots.RESIDUALS_PLOT)`
        * `get_data_for_plot(plots.COMMON_SUPPORT_PLOT)`
    """

    def plot_continuous_accuracy(
        self, phase="train", alpha_by_density=True, plot_residuals=False, ax=None
    ):
        """Plot continuous accuracy.

        Args:
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            alpha_by_density (bool, optional): Whether to calculate points
                alpha value (transparent-opaque) with density estimation.
                This can take some time to compute for a large number of
                points. If False, alpha calculation will be a simple fast
                heuristic.
            plot_residuals (bool, optional): Whether to plot residuals.
                Defaults to False.
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.

        Returns:
            matplotlib.axes.Axes
        """
        predicted, outcome, treatment, folds = self.get_data_for_plot(
            plots.CONTINUOUS_ACCURACY_PLOT, phase=phase
        )
        return plots.plot_continuous_prediction_accuracy_folds(
            predictions=predicted,
            y=outcome,
            a=treatment,
            cv=folds,
            alpha_by_density=alpha_by_density,
            plot_residuals=plot_residuals,
            ax=ax,
        )

    def plot_residuals(self, phase="train", alpha_by_density=True, ax=None):
        """Plot residuals of predicted outcome vs ground truth.

        Args:
            phase (str, optional): Phase to plot: "train" or "valid".
                Defaults to "train".
            alpha_by_density (bool, optional): Whether to calculate points
                alpha value (transparent-opaque) with density estimation.
                This can take some time to compute for a large number of
                points. If False, alpha calculation will be a simple fast
                heuristic.
            ax (matplotlib.axes.Axes, optional): axis to plot on, if None
                creates new axis. Defaults to None.

        Returns:
            matplotlib.axes.Axes
        """
        predicted, outcome, treatment, folds = self.get_data_for_plot(
            plots.RESIDUALS_PLOT, phase=phase
        )
        return plots.plot_residual_folds(
            predictions=predicted,
            y=outcome,
            a=treatment,
            cv=folds,
            alpha_by_density=alpha_by_density,
            ax=ax,
        )

    def plot_common_support(self, phase="train", alpha_by_density=True, ax=None):
        """Plot the scatter plot of y0 vs. y1 for multiple scoring results,
        colored by the treatment.

        Args:
            phase (str, optional): Phase to plot; forwarded to
                `get_data_for_plot`. Defaults to "train".
            alpha_by_density (bool): Whether to calculate points alpha value
                (transparent-opaque) with density estimation. This can take
                some time to compute for a large number of points. If False,
                alpha calculation will be a simple fast heuristic.
            ax (plt.Axes): The axes on which the plot will be displayed.
                Optional.

        Returns:
            The return value of
            `plots.plot_counterfactual_common_support_folds`.
        """
        predicted, treatment, folds = self.get_data_for_plot(
            plots.COMMON_SUPPORT_PLOT, phase=phase
        )
        return plots.plot_counterfactual_common_support_folds(
            predictions=predicted,
            hue_by=treatment,
            cv=folds,
            alpha_by_density=alpha_by_density,
            ax=ax,
        )
class PlotAllMixin:
    """Mixin to make all the train and validation plots.

    The host class must implement:
        * `all_plot_names`
        * `get_data_for_plot(name)` for every name in `all_plot_names`

    `plot_all` also reads `self.predictions` (its keys are used as the phases
    to plot when no phase is given).
    """

    def plot_all(self, phase=None):
        """Create plot of all available EvaluationResults.

        Creates one figure per phase, with a subplot for each plot name in
        `all_plot_names`. When `phase` is None, every key of
        `self.predictions` is plotted; otherwise only the requested phase.

        Args:
            phase (Union[str, None], optional): phase to plot, "train" or
                "valid". If not supplied, defaults to both if available.

        Returns:
            Dict[str, Dict[str, matplotlib.axes.Axes]]: the plot results in a
            nested dictionary:
                * First key is the phase ("train" or "valid")
                * Second key is the plot name.
        """
        if phase is None:
            requested_phases = self.predictions.keys()
        else:
            requested_phases = [phase]

        axes_by_phase = {}
        for current_phase in requested_phases:
            axes_by_phase[current_phase] = self._make_multipanel_evaluation_plot(
                plot_names=self.all_plot_names, phase=current_phase
            )
        return axes_by_phase

    def _make_multipanel_evaluation_plot(self, plot_names, phase):
        """Draw every plot in `plot_names` for one phase on a shared figure.

        Args:
            plot_names: plot names to draw, one per subplot.
            phase (str): phase to plot; also used in the figure title.

        Returns:
            dict: maps each plot name to the result of
            `_make_single_panel_evaluation_plot` for that name.
        """
        figure, axes_grid = plots.get_subplots(len(plot_names))
        # Flatten the axes grid so it can be paired with names one-to-one.
        flat_axes = axes_grid.ravel()

        panels = {}
        for panel_name, panel_ax in zip(plot_names, flat_axes):
            panels[panel_name] = self._make_single_panel_evaluation_plot(
                panel_name, phase, panel_ax
            )

        figure.suptitle(f"Evaluation on {phase} phase")
        return panels

    def _make_single_panel_evaluation_plot(self, plot_name, phase, ax=None, **kwargs):
        """Create a single evaluation plot.

        For a single phase and a single plot name.

        Args:
            plot_name (str): plot name (from `all_plot_names`)
            phase (str): "train" or "valid"
            ax (matplotlib.axis.Axis, optional): axis to plot on.
                Defaults to None.
            **kwargs: passed to underlying plotting function

        Returns:
            Whatever the plotting function resolved by
            `plots.lookup_name(plot_name)` returns.
        """
        # Resolve the plotting function before extracting data (same order as
        # the lookups have always run in).
        drawer = plots.lookup_name(plot_name)
        extracted = self.get_data_for_plot(plot_name, phase=phase)
        return drawer(*extracted, ax=ax, **kwargs)