"""Evaluation results objects for plotting and further analysis.
These objects are generated by the `evaluate` method.
"""
import abc
import dataclasses
import inspect
from typing import Dict, List, Tuple, Union
import pandas as pd
from ..estimation.base_estimator import IndividualOutcomeEstimator
from ..estimation.base_weight import PropensityEstimator, WeightEstimator
from .predictions import PropensityEvaluatorScores, SingleFoldPrediction
from .plots import mixins, data_extractors
@dataclasses.dataclass
class EvaluationResults(abc.ABC):
    """Data structure to hold evaluation results including cross-validation.

    Attrs:
        evaluated_metrics (Union[pd.DataFrame, PropensityEvaluatorScores, None]):
            the metrics computed during evaluation.
        models (Union[list[WeightEstimator], list[IndividualOutcomeEstimator], list[PropensityEstimator]]):
            Models trained during evaluation. May be dict or list or a model
            directly.
        predictions (dict[str, List[SingleFoldPrediction]]): dict with keys
            "train" and "valid" (if produced through cross-validation) and
            values of the predictions for the respective fold
        cv (list[tuple[list[int], list[int]]]): the cross validation indices,
            used to generate the results, used for constructing plots correctly
        X (pd.DataFrame): features data
        a (pd.Series): treatment assignment data
        y (pd.Series): outcome data
    """

    evaluated_metrics: Union[pd.DataFrame, PropensityEvaluatorScores]
    models: Union[
        List[WeightEstimator],
        List[IndividualOutcomeEstimator],
        List[PropensityEstimator],
    ]
    predictions: Dict[str, List[SingleFoldPrediction]]
    cv: List[Tuple[List[int], List[int]]]
    X: pd.DataFrame
    a: pd.Series
    y: pd.Series

    @property
    def _extractor(self):
        """Plot-data extractor for these results.

        Implemented by child classes.
        """
        raise NotImplementedError

    @staticmethod
    def make(
        evaluated_metrics: Union[pd.DataFrame, PropensityEvaluatorScores],
        models: Union[
            List[WeightEstimator],
            List[IndividualOutcomeEstimator],
            List[PropensityEstimator],
        ],
        predictions: Dict[str, List[SingleFoldPrediction]],
        cv: List[Tuple[List[int], List[int]]],
        X: pd.DataFrame,
        a: pd.Series,
        y: pd.Series,
    ):
        """Make EvaluationResults object of correct type.

        This is a factory method to dispatch the initializing data to the correct subclass of
        EvaluationResults. This is the only supported way to instantiate EvaluationResults objects.

        Args:
            evaluated_metrics (Union[pd.DataFrame, PropensityEvaluatorScores]): evaluated metrics
            models (Union[
                List[WeightEstimator],
                List[IndividualOutcomeEstimator],
                List[PropensityEstimator], ]): fitted models
            predictions (Dict[str, List[SingleFoldPrediction]]): predictions by phase and fold
            cv (List[Tuple[List[int], List[int]]]): cross validation indices
            X (pd.DataFrame): features data
            a (pd.Series): treatment assignment data
            y (pd.Series): outcome data

        Raises:
            ValueError: raised if invalid estimator is passed

        Returns:
            EvaluationResults: object with results of correct type
        """
        # `models` may arrive as a per-phase dict of fold-model lists, a plain
        # list of fold models, or a single model; peek at one representative
        # fitted model to decide which results subclass to construct.
        if isinstance(models, dict):
            fitted_model = models["train"][0]
        elif isinstance(models, list):
            fitted_model = models[0]
        else:
            fitted_model = models
        # NOTE(review): PropensityEstimator is tested before WeightEstimator —
        # presumably because it is a subclass of it, so the order matters; confirm.
        if isinstance(fitted_model, PropensityEstimator):
            return PropensityEvaluationResults(
                evaluated_metrics, models, predictions, cv, X, a, y
            )
        if isinstance(fitted_model, WeightEstimator):
            return WeightEvaluationResults(
                evaluated_metrics, models, predictions, cv, X, a, y
            )
        if isinstance(fitted_model, IndividualOutcomeEstimator):
            # Dispatch on outcome type: binary if any fold prediction in any
            # phase flags a binary outcome, otherwise continuous.
            if any(
                fold_predictions
                and any(p.is_binary_outcome for p in fold_predictions)
                for fold_predictions in predictions.values()
            ):
                return BinaryOutcomeEvaluationResults(
                    evaluated_metrics, models, predictions, cv, X, a, y
                )
            return ContinuousOutcomeEvaluationResults(
                evaluated_metrics, models, predictions, cv, X, a, y
            )
        raise ValueError(
            f"Unable to find suitable results object for estimator of type {type(fitted_model)}"
        )

    @property
    def all_plot_names(self):
        """Available plot names.

        Returns:
            set[str]: string names of supported plot names for these results
        """
        return self._extractor.plot_names

    def get_data_for_plot(self, plot_name, phase="train"):
        """Get data for a given plot

        Args:
            plot_name (str): plot name from `self.all_plot_names`
            phase (str, optional): phase of interest. Defaults to "train".

        Returns:
            Any: the data required for the plot in question
        """
        return self._extractor.get_data_for_plot(plot_name, phase)

    def remove_spurious_cv(self):
        """Remove redundant information accumulated due to the use of cross-validation process."""
        # Keep only the single fitted model (no meaningful per-fold models remain).
        self.models = self.models[0]
        # Drop the now-degenerate "phase"/"fold" index levels from the metrics.
        if isinstance(self.evaluated_metrics, pd.DataFrame):
            self.evaluated_metrics.reset_index(level=["phase", "fold"], drop=True, inplace=True)
        elif isinstance(self.evaluated_metrics, PropensityEvaluatorScores):
            for metric in self.evaluated_metrics:
                metric.reset_index(level=["phase", "fold"], drop=True, inplace=True)
class WeightEvaluationResults(
    EvaluationResults,
    mixins.WeightPlotterMixin,
    mixins.PlotAllMixin,
):
    # Reuse the base-class docstring so the shared contract renders here too.
    __doc__ = inspect.getdoc(EvaluationResults)

    @property
    def _extractor(self):
        """Plot-data extractor for weight-model evaluation results."""
        extractor_cls = data_extractors.WeightPlotDataExtractor
        return extractor_cls(self)
class BinaryOutcomeEvaluationResults(
    EvaluationResults,
    mixins.ClassificationPlotterMixin,
    mixins.PlotAllMixin,
):
    # Reuse the base-class docstring so the shared contract renders here too.
    __doc__ = inspect.getdoc(EvaluationResults)

    @property
    def _extractor(self):
        """Plot-data extractor for binary-outcome evaluation results."""
        extractor_cls = data_extractors.BinaryOutcomePlotDataExtractor
        return extractor_cls(self)
class ContinuousOutcomeEvaluationResults(
    EvaluationResults,
    mixins.ContinuousOutcomePlotterMixin,
    mixins.PlotAllMixin,
):
    # Reuse the base-class docstring so the shared contract renders here too.
    __doc__ = inspect.getdoc(EvaluationResults)

    @property
    def _extractor(self):
        """Plot-data extractor for continuous-outcome evaluation results."""
        extractor_cls = data_extractors.ContinuousOutcomePlotDataExtractor
        return extractor_cls(self)
class PropensityEvaluationResults(
    EvaluationResults,
    mixins.ClassificationPlotterMixin,
    mixins.WeightPlotterMixin,
    mixins.PlotAllMixin,
):
    # Reuse the base-class docstring so the shared contract renders here too.
    __doc__ = inspect.getdoc(EvaluationResults)

    @property
    def _extractor(self):
        """Plot-data extractor for propensity-model evaluation results."""
        extractor_cls = data_extractors.PropensityPlotDataExtractor
        return extractor_cls(self)