mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 15:03:21 +00:00
fix: issue a warning if np.nan
or np.inf
are in _cosine_similarity
argument Matrices (#31532)
- **Description**: issues a warning if inf and nan are passed as inputs to langchain_core.vectorstores.utils._cosine_similarity - **Issue**: Fixes #31496 - **Dependencies**: no external dependencies added, only warnings module imported --------- Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
@@ -7,6 +7,7 @@ as they can change without notice.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -46,6 +47,23 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray:
|
|||||||
|
|
||||||
x = np.array(x)
|
x = np.array(x)
|
||||||
y = np.array(y)
|
y = np.array(y)
|
||||||
|
|
||||||
|
# Check for NaN
|
||||||
|
if np.any(np.isnan(x)) or np.any(np.isnan(y)):
|
||||||
|
warnings.warn(
|
||||||
|
"NaN found in input arrays, unexpected return might follow",
|
||||||
|
category=RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for Inf
|
||||||
|
if np.any(np.isinf(x)) or np.any(np.isinf(y)):
|
||||||
|
warnings.warn(
|
||||||
|
"Inf found in input arrays, unexpected return might follow",
|
||||||
|
category=RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
if x.shape[1] != y.shape[1]:
|
if x.shape[1] != y.shape[1]:
|
||||||
msg = (
|
msg = (
|
||||||
f"Number of columns in X and Y must be the same. X has shape {x.shape} "
|
f"Number of columns in X and Y must be the same. X has shape {x.shape} "
|
||||||
|
Reference in New Issue
Block a user