community: Fix ValidationError on creating GPT4AllEmbeddings with no gpt4all_kwargs (#25124)

- **Description:** Instantiating `GPT4AllEmbeddings` with no
`gpt4all_kwargs` argument raised a `ValidationError`. Root cause: #21238
added the capability to pass `gpt4all_kwargs` through to the `GPT4All`
instance via `Embed4All`, but broke code that did not specify a
`gpt4all_kwargs` argument.
- **Issue:** #25119 
- **Dependencies:** None
- **Twitter handle:** [`@metadaddy`](https://twitter.com/metadaddy)
This commit is contained in:
Pat Patterson 2024-08-07 14:34:01 +01:00 committed by GitHub
parent 04dd8d3b0a
commit 7e7fcf5b1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 1 deletions

View File

@ -38,7 +38,7 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
model_name=values.get("model_name"),
n_threads=values.get("n_threads"),
device=values.get("device"),
**values.get("gpt4all_kwargs"),
**(values.get("gpt4all_kwargs") or {}),
)
except ImportError:
raise ImportError(

View File

@ -0,0 +1,62 @@
import sys
from typing import Any, Optional
from unittest.mock import MagicMock, patch
from langchain_community.embeddings import GPT4AllEmbeddings
_GPT4ALL_MODEL_NAME = "all-MiniLM-L6-v2.gguf2.f16.gguf"
_GPT4ALL_NTHREADS = 4
_GPT4ALL_DEVICE = "gpu"
_GPT4ALL_KWARGS = {"allow_download": False}
class MockEmbed4All(MagicMock):
"""Mock Embed4All class."""
def __init__(
self,
model_name: Optional[str] = None,
*,
n_threads: Optional[int] = None,
device: Optional[str] = None,
**kwargs: Any,
): # type: ignore[no-untyped-def]
assert model_name == _GPT4ALL_MODEL_NAME
class MockGpt4AllPackage(MagicMock):
"""Mock gpt4all package."""
Embed4All = MockEmbed4All
def test_create_gpt4all_embeddings_no_kwargs() -> None:
"""Test fix for #25119"""
with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}):
embedding = GPT4AllEmbeddings( # type: ignore[call-arg]
model_name=_GPT4ALL_MODEL_NAME,
n_threads=_GPT4ALL_NTHREADS,
device=_GPT4ALL_DEVICE,
)
assert embedding.model_name == _GPT4ALL_MODEL_NAME
assert embedding.n_threads == _GPT4ALL_NTHREADS
assert embedding.device == _GPT4ALL_DEVICE
assert embedding.gpt4all_kwargs == {}
assert isinstance(embedding.client, MockEmbed4All)
def test_create_gpt4all_embeddings_with_kwargs() -> None:
with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}):
embedding = GPT4AllEmbeddings( # type: ignore[call-arg]
model_name=_GPT4ALL_MODEL_NAME,
n_threads=_GPT4ALL_NTHREADS,
device=_GPT4ALL_DEVICE,
gpt4all_kwargs=_GPT4ALL_KWARGS,
)
assert embedding.model_name == _GPT4ALL_MODEL_NAME
assert embedding.n_threads == _GPT4ALL_NTHREADS
assert embedding.device == _GPT4ALL_DEVICE
assert embedding.gpt4all_kwargs == _GPT4ALL_KWARGS
assert isinstance(embedding.client, MockEmbed4All)