Source code for causallib.metrics.scorers

"""
This submodule implements a scorer interface for the
various causal metrics implemented.
These scorers can then be incorporated into model selection objects
to select the models (or hyperparameters) that optimizes these scores.
For example, as a `scoring` parameter to a `causallib.model_selection.GridSearchCV
object.

The signature of the call is ``(estimator, X, a, y)`` where ``estimator``
is the causal model to be evaluated, ``X`` are the covariates,
`a` is the treatment assignment, and ``y`` is the ground truth target
"""
# from sklearn.metrics._scorer import _BaseScorer
import abc

from . import (
    propensity_metrics,
    weight_metrics,
    outcome_metrics,
)


class _BaseCausalScorer:
    def __init__(self, score_func, sign, **kwargs):
        self._score_func = score_func
        self._sign = sign
        self._kwargs = kwargs

    def __call__(self, estimator, X, a_true, y_true, sample_weight=None, **kwargs):
        return self._sign * self._score(
            estimator,
            X,
            a_true,
            y_true,
            sample_weight=sample_weight,
            **{**self._kwargs, **kwargs},
        )

    @abc.abstractmethod
    def _score(self, estimator, X, a, y, sample_weight=None, **kwargs):
        raise NotImplementedError


class PropensityScorerBase(_BaseCausalScorer):
    def _score(self, estimator, X, a, y, sample_weight=None, **kwargs):
        propensities = estimator.compute_propensity(X, a)
        weights = estimator.compute_weights(X, a)
        score = self._score_func(
            a, propensities, sample_weight=weights,
            **kwargs
        )
        return score


weighted_roc_auc_error_scorer = PropensityScorerBase(
    propensity_metrics.weighted_roc_auc_error, -1,
)
weighted_roc_curve_error_scorer = PropensityScorerBase(
    propensity_metrics.weighted_roc_curve_error, -1,
)
expected_roc_auc_error_scorer = PropensityScorerBase(
    propensity_metrics.expected_roc_auc_error, -1,
)
expected_roc_curve_error_scorer = PropensityScorerBase(
    propensity_metrics.expected_roc_curve_error, -1
)
ici_error_scorer = PropensityScorerBase(
    propensity_metrics.ici_error, -1,
)

_PROPENSITY_SCORERS = dict(
    weighted_roc_auc_error=weighted_roc_auc_error_scorer,
    weighted_roc_curve_error=weighted_roc_curve_error_scorer,
    expected_roc_auc_error=expected_roc_auc_error_scorer,
    expected_roc_curve_error=expected_roc_curve_error_scorer,
    ici_error=ici_error_scorer,

)


class WeightScorerBase(_BaseCausalScorer):
    def _score(self, estimator, X, a, y, sample_weight=None, **kwargs):
        weights = estimator.compute_weights(X, a)
        score = self._score_func(
            X, a, sample_weight=weights,
            **kwargs
        )
        return score


covariate_balancing_error_scorer = WeightScorerBase(
    weight_metrics.covariate_balancing_error, -1,
)

covariate_imbalance_count_error_scorer = WeightScorerBase(
    weight_metrics.covariate_imbalance_count_error, -1,
)

_WEIGHT_SCORERS = dict(
    covariate_balancing_error=covariate_balancing_error_scorer,
    covariate_imbalance_count_error=covariate_imbalance_count_error_scorer,
)


class OutcomeScorerBase(_BaseCausalScorer):
    def _score(self, estimator, X, a, y, sample_weight=None, **kwargs):
        potential_outcomes_pred = estimator.estimate_individual_outcome(X, a)
        score = self._score_func(
            y, potential_outcomes_pred, a,
            **kwargs
            # Is this a good generic API call to the outcome metrics?
        )
        return score


balanced_residuals_error_scorer = OutcomeScorerBase(
    outcome_metrics.balanced_residuals_error, -1,
)

_OUTCOME_SCORERS = dict(
    balanced_residuals_error=balanced_residuals_error_scorer,
)

_SCORERS = {
    **_PROPENSITY_SCORERS,
    **_WEIGHT_SCORERS,
    **_OUTCOME_SCORERS
}


[docs] def get_scorer(scoring): """Gets a scorer callable from string. see `causallib.metrics.get_scorer_names` to retrieve available score names. """ if callable(scoring): return scoring try: return _SCORERS.get(scoring) except KeyError: raise ValueError( f"Scoring name {scoring} is not a valid scoring name." f"use the `causallib.metrics.get_scorer_names` to get all valid names." )
[docs] def get_scorer_names(score_type="all"): """Get the name of all available scorers. These names can be passed to `causallib.metrics.get_scorer` to retrieve a scorer object. Args: score_type (str): any of {"all", "propensity", "weight", "outcome"}. Returns only scorers relevant to the `score_type` type of model. Returns: """ scores_types_map = { "all": _SCORERS, "propensity": _PROPENSITY_SCORERS, "weight": _WEIGHT_SCORERS, "outcome": _OUTCOME_SCORERS, } try: return sorted(scores_types_map[score_type]) except KeyError: raise ValueError( f"`score_type` {score_type} is not valid." f"Please use one of {scores_types_map.keys()}." )