diff --git a/libs/langchain/langchain/utils/math.py b/libs/langchain/langchain/utils/math.py index 77784ba2a49..99d47368197 100644 --- a/libs/langchain/langchain/utils/math.py +++ b/libs/langchain/langchain/utils/math.py @@ -1,8 +1,11 @@ """Math utils.""" +import logging from typing import List, Optional, Tuple, Union import numpy as np +logger = logging.getLogger(__name__) + Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] @@ -10,6 +13,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: """Row-wise cosine similarity between two equal-width matrices.""" if len(X) == 0 or len(Y) == 0: return np.array([]) + X = np.array(X) Y = np.array(Y) if X.shape[1] != Y.shape[1]: @@ -17,14 +21,27 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"and Y has shape {Y.shape}." ) + try: + import simsimd as simd - X_norm = np.linalg.norm(X, axis=1) - Y_norm = np.linalg.norm(Y, axis=1) - # Ignore divide by zero errors run time warnings as those are handled below. - with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) - similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 - return similarity + X = np.array(X, dtype=np.float32) + Y = np.array(Y, dtype=np.float32) + Z = 1 - simd.cdist(X, Y, metric="cosine") + if isinstance(Z, float): + return np.array([Z]) + return Z + except ImportError: + logger.info( + "Unable to import simsimd, defaulting to NumPy implementation. If you want " + "to use simsimd please install with `pip install simsimd`." + ) + X_norm = np.linalg.norm(X, axis=1) + Y_norm = np.linalg.norm(Y, axis=1) + # Ignore divide by zero errors run time warnings as those are handled below. + with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 + return similarity def cosine_similarity_top_k( diff --git a/libs/langchain/tests/unit_tests/test_math_utils.py b/libs/langchain/tests/unit_tests/utils/test_math.py similarity index 100% rename from libs/langchain/tests/unit_tests/test_math_utils.py rename to libs/langchain/tests/unit_tests/utils/test_math.py