From d3ce6aad2e89cb91e4e6e999b45e96fe279597f1 Mon Sep 17 00:00:00 2001 From: Alex JW <57624157+Alex-J-W@users.noreply.github.com> Date: Wed, 8 May 2024 23:44:47 +0200 Subject: [PATCH] community: Instantiate GPT4AllEmbeddings with parameters (#21238) ### GPT4AllEmbeddings parameters --- **Description:** As of right now the **Embed4All** class inside _GPT4AllEmbeddings_ is instantiated as it's default which leaves no room to customize the chosen model and it's behavior. Thus: - GPT4AllEmbeddings can now be instantiated with custom parameters like a different model that shall be used. --------- Co-authored-by: AlexJauchWalser --- .../langchain_community/embeddings/gpt4all.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py index f7983a5968b..87c0bdca4bf 100644 --- a/libs/community/langchain_community/embeddings/gpt4all.py +++ b/libs/community/langchain_community/embeddings/gpt4all.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, root_validator @@ -14,9 +14,18 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): from langchain_community.embeddings import GPT4AllEmbeddings - embeddings = GPT4AllEmbeddings() + model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf" + gpt4all_kwargs = {'allow_download': 'True'} + embeddings = GPT4AllEmbeddings( + model_name=model_name, + gpt4all_kwargs=gpt4all_kwargs + ) """ + model_name: str + n_threads: Optional[int] = None + device: Optional[str] = "cpu" + gpt4all_kwargs: Optional[dict] = {} client: Any #: :meta private: @root_validator() @@ -26,7 +35,12 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): try: from gpt4all import Embed4All - values["client"] = Embed4All() + values["client"] = Embed4All( + model_name=values["model_name"], + n_threads=values.get("n_threads"), + device=values.get("device"), + **values.get("gpt4all_kwargs"), + ) except ImportError: raise ImportError( "Could not import gpt4all library. "