diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 2919de1f9f4..1cbad1b3633 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -64,6 +64,9 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: # Ignore divide by zero errors run time warnings as those are handled below. with np.errstate(divide="ignore", invalid="ignore"): similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) + if np.isnan(similarity).all(): + msg = "NaN values found, please remove the NaN values and try again" + raise ValueError(msg) from None similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity diff --git a/libs/core/tests/unit_tests/vectorstores/test_utils.py b/libs/core/tests/unit_tests/vectorstores/test_utils.py index 1bf4f9d7d96..2ff9817f5bf 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_utils.py +++ b/libs/core/tests/unit_tests/vectorstores/test_utils.py @@ -42,9 +42,8 @@ class TestCosineSimilarity: """Test cosine similarity with zero vector.""" x: list[list[float]] = [[0, 0, 0]] y: list[list[float]] = [[1, 2, 3]] - result = _cosine_similarity(x, y) - expected = np.array([[0.0]]) - np.testing.assert_array_almost_equal(result, expected) + with pytest.raises(ValueError, match="NaN values found"): + _cosine_similarity(x, y) def test_multiple_vectors(self) -> None: """Test cosine similarity with multiple vectors.""" @@ -115,13 +114,8 @@ class TestCosineSimilarity: # Create vectors that would result in NaN/inf in similarity calculation x: list[list[float]] = [[0, 0]] # Zero vector y: list[list[float]] = [[0, 0]] # Zero vector - result = _cosine_similarity(x, y) - - # Should return 0.0 instead of NaN - expected = np.array([[0.0]]) - np.testing.assert_array_equal(result, expected) - assert not np.isnan(result).any() - assert not np.isinf(result).any() + with pytest.raises(ValueError, match="NaN values found"): + _cosine_similarity(x, y) def test_large_values(self) -> None: """Test with large values to check numerical stability."""