Source code for causallib.contrib.faissknn

# (C) Copyright 2021 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import faiss


[docs] class FaissNearestNeighbors:
[docs] def __init__(self, metric="mahalanobis", index_type="flatl2", n_cells=100, n_probes=10): """NearestNeighbors object utilizing the faiss library for speed Implements the same API as sklearn but runs 5-10x faster. Utilizes the `faiss` library https://github.com/facebookresearch/faiss . Tested with version 1.7.0. If `faiss-gpu` is installed from pypi, GPU acceleration will be used if available. Args: metric (str) : Distance metric for finding nearest neighbors (default: "mahalanobis") index_type (str) : Index type within faiss to use (supported: "flatl2" and "ivfflat") n_cells (int) : Number of voronoi cells (only used for "ivfflat", default: 100) n_probes (int) : Number of voronoi cells to search in (only used for "ivfflat", default: 10) Attributes (after running `fit`): index_ : the faiss index fit from the data. For details about faiss indices, see the faiss documentation at https://github.com/facebookresearch/faiss/wiki/Faiss-indexes . """ self.metric = metric self.n_cells = n_cells self.n_probes = n_probes self.index_type = index_type
[docs] def fit(self, X): """Create faiss index and train with data. Args: X (np.array): Array of N samples of shape (NxM) Returns: self: Fitted object """ X = self._transform_covariates(X) if self.index_type == "flatl2": self.index_ = faiss.IndexFlatL2(X.shape[1]) self.index_.add(X) elif self.index_type == "ivfflat": quantizer = faiss.IndexFlatL2(X.shape[1]) n_cells = max(1, min(self.n_cells, X.shape[0]//200)) n_probes = min(self.n_probes, n_cells) self.index_ = faiss.IndexIVFFlat( quantizer, X.shape[1], n_cells) self.index_.train(X) self.index_.nprobe = n_probes self.index_.add(X) else: raise NotImplementedError( "Index type {} not implemented. Please select" "one of [\"flatl2\", \"ivfflat\"]".format(self.index_type)) return self
[docs] def kneighbors(self, X, n_neighbors=1): """Find the k nearest neighbors of each sample in X Args: X (np.array): Array of shape (N,M) of samples to search for neighbors of. M must be the same as the fit data. n_neighbors (int, optional): Number of neighbors to find. Defaults to 1. Returns: (distances, indices): Two np.array objects of shape (N,n_neighbors) containing the distances and indices of the closest neighbors. """ X = self._transform_covariates(X) distances, indices = self.index_.search(X, n_neighbors) # faiss returns euclidean distance squared return np.sqrt(distances), indices
def _transform_covariates(self, X): if self.metric == "mahalanobis": if not hasattr(self, "VI"): raise AttributeError("Set inverse covariance VI first.") X = np.dot(X, self.VI.T) return np.ascontiguousarray(X).astype("float32")
[docs] def set_params(self, **parameters): for parameter, value in parameters.items(): if parameter == "metric_params": self.set_params(**value) else: self._setattr(parameter, value) return self
[docs] def get_params(self, deep=True): # `deep` plays no role because there are no sublearners params_to_return = ["metric", "n_cells", "n_probes", "index_type"] return {i: self.__getattribute__(i) for i in params_to_return}
def _setattr(self, parameter, value): # based on faiss docs https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances if parameter == "VI": value = np.linalg.inv(value) chol = np.linalg.cholesky(value) cholvi = np.linalg.inv(chol) value = cholvi setattr(self, parameter, value)