diff --git a/libs/community/langchain_community/embeddings/gpt4all.py b/libs/community/langchain_community/embeddings/gpt4all.py index 87c0bdca4bf..1c6b6501898 100644 --- a/libs/community/langchain_community/embeddings/gpt4all.py +++ b/libs/community/langchain_community/embeddings/gpt4all.py @@ -22,21 +22,20 @@ class GPT4AllEmbeddings(BaseModel, Embeddings): ) """ - model_name: str + model_name: Optional[str] = None n_threads: Optional[int] = None device: Optional[str] = "cpu" gpt4all_kwargs: Optional[dict] = {} client: Any #: :meta private: - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that GPT4All library is installed.""" - try: from gpt4all import Embed4All values["client"] = Embed4All( - model_name=values["model_name"], + model_name=values.get("model_name"), n_threads=values.get("n_threads"), device=values.get("device"), **values.get("gpt4all_kwargs"),