diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py index 1c6b6501898..338acc50389 100644 --- a/libs/community/langchain_community/embeddings/gpt4all.py +++ b/libs/community/langchain_community/embeddings/gpt4all.py @@ -38,7 +38,7 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): model_name=values.get("model_name"), n_threads=values.get("n_threads"), device=values.get("device"), - **values.get("gpt4all_kwargs"), + **(values.get("gpt4all_kwargs") or {}), ) except ImportError: raise ImportError( diff --git a/libs/community/tests/unit_tests/embeddings/test_gpt4all.py b/libs/community/tests/unit_tests/embeddings/test_gpt4all.py new file mode 100644 index 00000000000..19cfc134b5a --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_gpt4all.py @@ -0,0 +1,62 @@ +import sys +from typing import Any, Optional +from unittest.mock import MagicMock, patch + +from langchain_community.embeddings import GPT4AllEmbeddings + +_GPT4ALL_MODEL_NAME = "all-MiniLM-L6-v2.gguf2.f16.gguf" +_GPT4ALL_NTHREADS = 4 +_GPT4ALL_DEVICE = "gpu" +_GPT4ALL_KWARGS = {"allow_download": False} + + +class MockEmbed4All(MagicMock): + """Mock Embed4All class.""" + + def __init__( + self, + model_name: Optional[str] = None, + *, + n_threads: Optional[int] = None, + device: Optional[str] = None, + **kwargs: Any, + ): # type: ignore[no-untyped-def] + assert model_name == _GPT4ALL_MODEL_NAME + + +class MockGpt4AllPackage(MagicMock): + """Mock gpt4all package.""" + + Embed4All = MockEmbed4All + + +def test_create_gpt4all_embeddings_no_kwargs() -> None: + """Test fix for #25119""" + with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}): + embedding = GPT4AllEmbeddings( # type: ignore[call-arg] + model_name=_GPT4ALL_MODEL_NAME, + n_threads=_GPT4ALL_NTHREADS, + device=_GPT4ALL_DEVICE, + ) + + assert embedding.model_name == _GPT4ALL_MODEL_NAME + assert embedding.n_threads == _GPT4ALL_NTHREADS + assert embedding.device == _GPT4ALL_DEVICE + assert embedding.gpt4all_kwargs == {} + assert isinstance(embedding.client, MockEmbed4All) + + +def test_create_gpt4all_embeddings_with_kwargs() -> None: + with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}): + embedding = GPT4AllEmbeddings( # type: ignore[call-arg] + model_name=_GPT4ALL_MODEL_NAME, + n_threads=_GPT4ALL_NTHREADS, + device=_GPT4ALL_DEVICE, + gpt4all_kwargs=_GPT4ALL_KWARGS, + ) + + assert embedding.model_name == _GPT4ALL_MODEL_NAME + assert embedding.n_threads == _GPT4ALL_NTHREADS + assert embedding.device == _GPT4ALL_DEVICE + assert embedding.gpt4all_kwargs == _GPT4ALL_KWARGS + assert isinstance(embedding.client, MockEmbed4All)