Source code for causallib.utils.general_tools

"""
(C) Copyright 2019 IBM Corp.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Created on Jun 27, 2018

General (i.e. non-scientific) utils used throughout the package.
"""
from numpy import isscalar as np_is_scalar
from pandas import Series


def get_iterable_treatment_values(treatment_values, treatment_assignment, sort=True):
    """
    Normalize a treatment-values specification into an iterable of treatment values.

    Callers may pass a single value, an iterable of values, or nothing at all.
    Funnelling all three cases into one iterable lets downstream code simply
    loop over the result.

    Args:
        treatment_values (None|Any|list[Any]): The treatment values of interest.
            Either a single (scalar) value, an iterable of values, or None.
            When None, the values are taken from the unique entries of
            `treatment_assignment`.
        treatment_assignment (Series): The observed treatment assignment.
            Only consulted when `treatment_values` is None.
        sort (bool): Whether to sort the treatment values.

    Returns:
        list[Any]: The treatment values. A sorted list when `sort` is True.
            When `sort` is False the values are returned as-is, i.e. a
            one-element list for a scalar input, the caller's own iterable,
            or whatever `treatment_assignment.unique()` returned.
    """
    if treatment_values is None:
        # Nothing specified by the caller: infer from the observed assignment.
        values = treatment_assignment.unique()
    else:
        values = treatment_values

    if np_is_scalar(values):
        # A lone value is wrapped so that it can be iterated over.
        values = [values]

    if not sort:
        return values
    return sorted(values)
def create_repr_string(o):
    """
    Build a constructor-like representation string from an object's data attributes.

    Attribute names are collected with `dir(o)`, so they appear in the
    alphabetical order `dir` returns them in. The result has the form
    ``ClassName(attr1=value1, attr2=value2)``. If the object has a "learner"
    attribute, it is always placed last; when other parameters precede it,
    it goes on its own line, indented to align under the first parameter.

    Args:
        o (object): any core object

    Returns:
        str: repr string based on internal attributes
    """
    cls_name = o.__class__.__name__

    # Keep only the names that describe the object's configuration:
    attribute_names = [
        attr for attr in dir(o)
        if not attr.startswith("__")  # Data-model dunder names
        and not callable(getattr(o, attr, None))  # Drop methods, keep only fields
        and not attr.startswith("_abc")  # Abstract-class related attributes
        and not attr.endswith("_")  # Attributes set after initialization
        and attr != "CALCULATE_EFFECT"  # Originally noted as the EffectEstimator attribute
    ]

    # The "learner" attribute is pulled out so it can be rendered last:
    learner_included = "learner" in attribute_names
    if learner_included:
        attribute_names.remove("learner")

    params_string = ", ".join(
        "{}={}".format(attr, getattr(o, attr)) for attr in attribute_names
    )

    if learner_included:
        learner_string = "learner={}".format(getattr(o, "learner"))
        if params_string:
            # Continue on a new line, aligned just past "ClassName(".
            indentation = " " * (len(cls_name) + 1)
            params_string += ",\n" + indentation + learner_string
        else:
            # No preceding parameters: emitting the ",\n" separator here would
            # produce a malformed "ClassName(,\n learner=...)" string.
            params_string = learner_string

    return "{cls_name}({params})".format(cls_name=cls_name, params=params_string)
def check_learner_is_fitted(learner):
    """
    Tell whether a learner appears to have been fitted.

    The check relies on scikit-learn's convention that fitting a model adds
    attributes whose names end with an underscore
    (see: https://scikit-learn.org/dev/glossary.html#term-fitted).
    Only the learner's own instance attributes (its ``__dict__``) are inspected.

    Args:
        learner (object): The learner to inspect.

    Returns:
        bool: True if at least one instance attribute name ends with "_",
            False otherwise.
    """
    instance_attributes = learner.__dict__
    # A single trailing-underscore attribute is enough to count as fitted.
    return any(name.endswith("_") for name in instance_attributes)