From b1e90b3075a4f9be7a30ec9e58c73ab8d2d830e9 Mon Sep 17 00:00:00 2001 From: wenngong <76683249+wenngong@users.noreply.github.com> Date: Fri, 5 Jul 2024 22:46:34 +0800 Subject: [PATCH] community: add model_name param valid for GPT4AllEmbeddings (#23867) Description: add model_name param valid for GPT4AllEmbeddings Issue: #23863 #22819 --------- Co-authored-by: gongwn1 --- libs/community/langchain_community/embeddings/gpt4all.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py index 87c0bdca4bf..1c6b6501898 100644 --- a/libs/community/langchain_community/embeddings/gpt4all.py +++ b/libs/community/langchain_community/embeddings/gpt4all.py @@ -22,21 +22,20 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): ) """ - model_name: str + model_name: Optional[str] = None n_threads: Optional[int] = None device: Optional[str] = "cpu" gpt4all_kwargs: Optional[dict] = {} client: Any #: :meta private: - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that GPT4All library is installed.""" - try: from gpt4all import Embed4All values["client"] = Embed4All( - model_name=values["model_name"], + model_name=values.get("model_name"), n_threads=values.get("n_threads"), device=values.get("device"), **values.get("gpt4all_kwargs"),