mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
community: Fix ValidationError on creating GPT4AllEmbeddings with no gpt4all_kwargs (#25124)
- **Description:** Instantiating `GPT4AllEmbeddings` with no `gpt4all_kwargs` argument raised a `ValidationError`. Root cause: #21238 added the capability to pass `gpt4all_kwargs` through to the `GPT4All` instance via `Embed4All`, but broke code that did not specify a `gpt4all_kwargs` argument. - **Issue:** #25119 - **Dependencies:** None - **Twitter handle:** [`@metadaddy`](https://twitter.com/metadaddy)
This commit is contained in:
parent
04dd8d3b0a
commit
7e7fcf5b1f
@ -38,7 +38,7 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
|
||||
model_name=values.get("model_name"),
|
||||
n_threads=values.get("n_threads"),
|
||||
device=values.get("device"),
|
||||
**values.get("gpt4all_kwargs"),
|
||||
**(values.get("gpt4all_kwargs") or {}),
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
|
62
libs/community/tests/unit_tests/embeddings/test_gpt4all.py
Normal file
62
libs/community/tests/unit_tests/embeddings/test_gpt4all.py
Normal file
@ -0,0 +1,62 @@
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_community.embeddings import GPT4AllEmbeddings
|
||||
|
||||
_GPT4ALL_MODEL_NAME = "all-MiniLM-L6-v2.gguf2.f16.gguf"
|
||||
_GPT4ALL_NTHREADS = 4
|
||||
_GPT4ALL_DEVICE = "gpu"
|
||||
_GPT4ALL_KWARGS = {"allow_download": False}
|
||||
|
||||
|
||||
class MockEmbed4All(MagicMock):
|
||||
"""Mock Embed4All class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
*,
|
||||
n_threads: Optional[int] = None,
|
||||
device: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
): # type: ignore[no-untyped-def]
|
||||
assert model_name == _GPT4ALL_MODEL_NAME
|
||||
|
||||
|
||||
class MockGpt4AllPackage(MagicMock):
|
||||
"""Mock gpt4all package."""
|
||||
|
||||
Embed4All = MockEmbed4All
|
||||
|
||||
|
||||
def test_create_gpt4all_embeddings_no_kwargs() -> None:
|
||||
"""Test fix for #25119"""
|
||||
with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}):
|
||||
embedding = GPT4AllEmbeddings( # type: ignore[call-arg]
|
||||
model_name=_GPT4ALL_MODEL_NAME,
|
||||
n_threads=_GPT4ALL_NTHREADS,
|
||||
device=_GPT4ALL_DEVICE,
|
||||
)
|
||||
|
||||
assert embedding.model_name == _GPT4ALL_MODEL_NAME
|
||||
assert embedding.n_threads == _GPT4ALL_NTHREADS
|
||||
assert embedding.device == _GPT4ALL_DEVICE
|
||||
assert embedding.gpt4all_kwargs == {}
|
||||
assert isinstance(embedding.client, MockEmbed4All)
|
||||
|
||||
|
||||
def test_create_gpt4all_embeddings_with_kwargs() -> None:
|
||||
with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}):
|
||||
embedding = GPT4AllEmbeddings( # type: ignore[call-arg]
|
||||
model_name=_GPT4ALL_MODEL_NAME,
|
||||
n_threads=_GPT4ALL_NTHREADS,
|
||||
device=_GPT4ALL_DEVICE,
|
||||
gpt4all_kwargs=_GPT4ALL_KWARGS,
|
||||
)
|
||||
|
||||
assert embedding.model_name == _GPT4ALL_MODEL_NAME
|
||||
assert embedding.n_threads == _GPT4ALL_NTHREADS
|
||||
assert embedding.device == _GPT4ALL_DEVICE
|
||||
assert embedding.gpt4all_kwargs == _GPT4ALL_KWARGS
|
||||
assert isinstance(embedding.client, MockEmbed4All)
|
Loading…
Reference in New Issue
Block a user