mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-11-04 10:10:09 +00:00 
			
		
		
		
	- **Description:** Instantiating `GPT4AllEmbeddings` with no `gpt4all_kwargs` argument raised a `ValidationError`. Root cause: #21238 added the capability to pass `gpt4all_kwargs` through to the `GPT4All` instance via `Embed4All`, but broke code that did not specify a `gpt4all_kwargs` argument.
	- **Issue:** #25119
	- **Dependencies:** None
	- **Twitter handle:** [`@metadaddy`](https://twitter.com/metadaddy)
		
			
				
	
	
		
			63 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			63 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import sys
 | 
						|
from typing import Any, Optional
 | 
						|
from unittest.mock import MagicMock, patch
 | 
						|
 | 
						|
from langchain_community.embeddings import GPT4AllEmbeddings
 | 
						|
 | 
						|
# Fixture values shared by the tests in this module.
# Model name every MockEmbed4All construction is asserted against.
_GPT4ALL_MODEL_NAME = "all-MiniLM-L6-v2.gguf2.f16.gguf"
# Value passed as ``n_threads`` when building the embeddings object.
_GPT4ALL_NTHREADS = 4
# Value passed as ``device`` when building the embeddings object.
_GPT4ALL_DEVICE = "gpu"
# Value passed as ``gpt4all_kwargs`` by the with-kwargs test.
_GPT4ALL_KWARGS = {"allow_download": False}
 | 
						|
 | 
						|
 | 
						|
class MockEmbed4All(MagicMock):
    """Mock Embed4All class.

    Construction asserts that the model name it receives equals
    ``_GPT4ALL_MODEL_NAME``; all other arguments are accepted and ignored.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        *,
        n_threads: Optional[int] = None,
        device: Optional[str] = None,
        **kwargs: Any,
    ):  # type: ignore[no-untyped-def]
        """Check the requested model name.

        Args:
            model_name: Name of the model; must equal ``_GPT4ALL_MODEL_NAME``.
            n_threads: Accepted but unused.
            device: Accepted but unused.
            **kwargs: Accepted but unused.

        Raises:
            AssertionError: If ``model_name`` differs from
                ``_GPT4ALL_MODEL_NAME``.
        """
        # NOTE(review): MagicMock.__init__ is not invoked here. The tests in
        # this module only perform isinstance checks on the instance --
        # confirm before relying on other mock features.
        assert model_name == _GPT4ALL_MODEL_NAME
 | 
						|
 | 
						|
 | 
						|
class MockGpt4AllPackage(MagicMock):
    """Mock gpt4all package.

    Instances are placed in ``sys.modules`` under the key ``"gpt4all"`` by
    the tests in this module.
    """

    # Expose the mock embedder under the attribute name ``Embed4All``.
    Embed4All = MockEmbed4All
 | 
						|
 | 
						|
 | 
						|
def test_create_gpt4all_embeddings_no_kwargs() -> None:
    """Test fix for #25119.

    Builds ``GPT4AllEmbeddings`` without a ``gpt4all_kwargs`` argument and
    checks that the field ends up as an empty dict and the client is the mock.
    """
    # Swap the "gpt4all" entry in sys.modules for the mock package for the
    # duration of the block; patch.dict restores the mapping on exit.
    with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}):
        embedding = GPT4AllEmbeddings(  # type: ignore[call-arg]
            model_name=_GPT4ALL_MODEL_NAME,
            n_threads=_GPT4ALL_NTHREADS,
            device=_GPT4ALL_DEVICE,
        )

        assert embedding.model_name == _GPT4ALL_MODEL_NAME
        assert embedding.n_threads == _GPT4ALL_NTHREADS
        assert embedding.device == _GPT4ALL_DEVICE
        # With no gpt4all_kwargs supplied, the field is expected to be empty.
        assert embedding.gpt4all_kwargs == {}
        assert isinstance(embedding.client, MockEmbed4All)
 | 
						|
 | 
						|
 | 
						|
def test_create_gpt4all_embeddings_with_kwargs() -> None:
    """Test construction with an explicit ``gpt4all_kwargs`` argument.

    Builds ``GPT4AllEmbeddings`` with ``_GPT4ALL_KWARGS`` and checks that the
    value is stored on the instance unchanged and the client is the mock.
    """
    # Swap the "gpt4all" entry in sys.modules for the mock package for the
    # duration of the block; patch.dict restores the mapping on exit.
    with patch.dict(sys.modules, {"gpt4all": MockGpt4AllPackage()}):
        embedding = GPT4AllEmbeddings(  # type: ignore[call-arg]
            model_name=_GPT4ALL_MODEL_NAME,
            n_threads=_GPT4ALL_NTHREADS,
            device=_GPT4ALL_DEVICE,
            gpt4all_kwargs=_GPT4ALL_KWARGS,
        )

        assert embedding.model_name == _GPT4ALL_MODEL_NAME
        assert embedding.n_threads == _GPT4ALL_NTHREADS
        assert embedding.device == _GPT4ALL_DEVICE
        assert embedding.gpt4all_kwargs == _GPT4ALL_KWARGS
        assert isinstance(embedding.client, MockEmbed4All)
 |