From 7e7fcf5b1fd1a7e7f3ec49651b20b12c8f639844 Mon Sep 17 00:00:00 2001 From: Pat Patterson Date: Wed, 7 Aug 2024 14:34:01 +0100 Subject: [PATCH] community: Fix ValidationError on creating GPT4AllEmbeddings with no gpt4all_kwargs (#25124) - **Description:** Instantiating `GPT4AllEmbeddings` with no `gpt4all_kwargs` argument raised a `ValidationError`. Root cause: #21238 added the capability to pass `gpt4all_kwargs` through to the `GPT4All` instance via `Embed4All`, but broke code that did not specify a `gpt4all_kwargs` argument. - **Issue:** #25119 - **Dependencies:** None - **Twitter handle:** [`@metadaddy`](https://twitter.com/metadaddy) --- .../langchain_community/embeddings/gpt4all.py | 2 +- .../unit_tests/embeddings/test_gpt4all.py | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 libs/community/tests/unit_tests/embeddings/test_gpt4all.py diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py index 1c6b6501898..338acc50389 100644 --- a/libs/community/langchain_community/embeddings/gpt4all.py +++ b/libs/community/langchain_community/embeddings/gpt4all.py @@ -38,7 +38,7 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): model_name=values.get("model_name"), n_threads=values.get("n_threads"), device=values.get("device"), - **values.get("gpt4all_kwargs"), + **(values.get("gpt4all_kwargs") or {}), ) except ImportError: raise ImportError( diff --git a/libs/community/tests/unit_tests/embeddings/test_gpt4all.py b/libs/community/tests/unit_tests/embeddings/test_gpt4all.py new file mode 100644 index 00000000000..19cfc134b5a --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_gpt4all.py @@ -0,0 +1,62 @@ +import sys +from typing import Any, Optional +from unittest.mock import MagicMock, patch + +from langchain_community.embeddings import GPT4AllEmbeddings + +_GPT4ALL_MODEL_NAME = "all-MiniLM-L6-v2.gguf2.f16.gguf" +_GPT4ALL_NTHREADS = 4 +_GPT4ALL_DEVICE = "gpu" +_GPT4ALL_KWARGS = {"allow_download": False} + + +class MockEmbed4All(MagicMock): + """Mock Embed4All class.""" + + def __init__( + self, + model_name: Optional[str] = None, + *, + n_threads: Optional[int] = None, + device: Optional[str] = None, + **kwargs: Any, + ): # type: ignore[no-untyped-def] + assert model_name == _GPT4ALL_MODEL_NAME + + +class MockGpt4AllPackage(MagicMock): + """Mock gpt4all package.""" + + Embed4All = MockEmbed4All + + +def test_create_gpt4all_embeddings_no_kwargs() -> None: + """Test fix for #25119""" + with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}): + embedding = GPT4AllEmbeddings( # type: ignore[call-arg] + model_name=_GPT4ALL_MODEL_NAME, + n_threads=_GPT4ALL_NTHREADS, + device=_GPT4ALL_DEVICE, + ) + + assert embedding.model_name == _GPT4ALL_MODEL_NAME + assert embedding.n_threads == _GPT4ALL_NTHREADS + assert embedding.device == _GPT4ALL_DEVICE + assert embedding.gpt4all_kwargs == {} + assert isinstance(embedding.client, MockEmbed4All) + + +def test_create_gpt4all_embeddings_with_kwargs() -> None: + with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}): + embedding = GPT4AllEmbeddings( # type: ignore[call-arg] + model_name=_GPT4ALL_MODEL_NAME, + n_threads=_GPT4ALL_NTHREADS, + device=_GPT4ALL_DEVICE, + gpt4all_kwargs=_GPT4ALL_KWARGS, + ) + + assert embedding.model_name == _GPT4ALL_MODEL_NAME + assert embedding.n_threads == _GPT4ALL_NTHREADS + assert embedding.device == _GPT4ALL_DEVICE + assert embedding.gpt4all_kwargs == _GPT4ALL_KWARGS + assert isinstance(embedding.client, MockEmbed4All)