# Update GPT4ALL integration (#4567)

GPT4All has completely changed its bindings. The new implementation is
somewhat unusual and doesn't fit well into base.py, and it will probably
change again, so this is a temporary solution.

Fixes #3839, #4628
Alexey Nominas, 2023-05-18 22:38:54 +06:00
commit c9e2a01875, parent e2d7677526
2 changed files with 44 additions and 72 deletions
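
For context, here is the gist of the binding change as a minimal sketch (the constructor arguments mirror the ones used in the wrapper below; the file and directory names are just examples):

```python
# Old bindings: pygpt4all opened the model directly from a file path,
# with a separate class per backend (GPT4All vs. GPT4All_J).
#
#     from pygpt4all import GPT4All
#     model = GPT4All(model_path="./models/ggml-gpt4all-l13b-snoozy.bin")

# New bindings: one class that takes a model *name* plus a model
# directory, and downloads the file on demand unless told otherwise.
from gpt4all import GPT4All

model = GPT4All(
    model_name="ggml-gpt4all-l13b-snoozy.bin",  # example model file
    model_path="./models/",  # directory containing the file, not the file itself
    allow_download=False,  # keep the wrapper's offline behaviour
)
```

This is why the wrapper now has to split its single `model` path into a name and a directory; see the `rpartition` logic in the second file.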

**Changed file 1 of 2: the GPT4All example notebook**

```diff
@@ -27,7 +27,7 @@
 }
 ],
 "source": [
-"%pip install pygpt4all > /dev/null"
+"%pip install gpt4all > /dev/null"
 ]
 },
 {
@@ -64,7 +64,7 @@
 "source": [
 "### Specify Model\n",
 "\n",
-"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/pygpt4all\n",
+"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
 "\n",
 "For full installation instructions go [here](https://gpt4all.io/index.html).\n",
 "\n",
@@ -102,7 +102,7 @@
 "\n",
 "# Path(local_path).parent.mkdir(parents=True, exist_ok=True)\n",
 "\n",
-"# # Example model. Check https://github.com/nomic-ai/pygpt4all for the latest models.\n",
+"# # Example model. Check https://github.com/nomic-ai/gpt4all for the latest models.\n",
 "# url = 'http://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin'\n",
 "\n",
 "# # send a GET request to the URL to download the file. Stream since it's large\n",
@@ -126,7 +126,8 @@
 "callbacks = [StreamingStdOutCallbackHandler()]\n",
 "# Verbose is required to pass to the callback manager\n",
 "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
-"# If you want to use GPT4ALL_J model add the backend parameter\n",
+"# If you want to use a custom model add the backend parameter\n",
+"# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n",
 "llm = GPT4All(model=local_path, backend='gptj', callbacks=callbacks, verbose=True)"
 ]
 },
```
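
Put together, the updated notebook boils down to this flow (a sketch: it assumes the snoozy model file was already downloaded to `local_path`, and the prompt/chain portion follows the notebook's usual pattern rather than the hunks shown above):

```python
from langchain import LLMChain, PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All

local_path = "./models/ggml-gpt4all-l13b-snoozy.bin"  # downloaded beforehand

# Callbacks support token-wise streaming to stdout
callbacks = [StreamingStdOutCallbackHandler()]
# Verbose is required to pass to the callback manager
llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])

chain = LLMChain(prompt=prompt, llm=llm)
chain.run("What NFL team won the Super Bowl in the year Justin Bieber was born?")
```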

**Changed file 2 of 2: the GPT4All LLM wrapper module**

```diff
@@ -12,7 +12,7 @@ from langchain.llms.utils import enforce_stop_tokens
 class GPT4All(LLM):
     r"""Wrapper around GPT4All language models.
 
-    To use, you should have the ``pygpt4all`` python package installed, the
+    To use, you should have the ``gpt4all`` python package installed, the
     pre-trained model file, and the model's config information.
 
     Example:
@@ -28,7 +28,7 @@ class GPT4All(LLM):
     model: str
     """Path to the pre-trained GPT4All model file."""
 
-    backend: str = Field("llama", alias="backend")
+    backend: Optional[str] = Field(None, alias="backend")
 
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -88,6 +88,10 @@ class GPT4All(LLM):
     streaming: bool = False
     """Whether to stream the results or not."""
 
+    context_erase: float = 0.5
+    """Leave (n_ctx * context_erase) tokens
+    starting from beginning if the context has run out."""
+
     client: Any = None  #: :meta private:
 
     class Config:
@@ -95,86 +99,55 @@ class GPT4All(LLM):
 
         extra = Extra.forbid
 
-    def _llama_default_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "n_predict": self.n_predict,
-            "n_threads": self.n_threads,
-            "repeat_last_n": self.repeat_last_n,
-            "repeat_penalty": self.repeat_penalty,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "temp": self.temp,
-        }
-
-    def _gptj_default_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "n_predict": self.n_predict,
-            "n_threads": self.n_threads,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "temp": self.temp,
-        }
-
     @staticmethod
-    def _llama_param_names() -> Set[str]:
-        """Get the identifying parameters."""
+    def _model_param_names() -> Set[str]:
         return {
-            "seed",
             "n_ctx",
-            "n_parts",
-            "f16_kv",
-            "logits_all",
-            "vocab_only",
-            "use_mlock",
-            "embedding",
+            "n_predict",
+            "top_k",
+            "top_p",
+            "temp",
+            "n_batch",
+            "repeat_penalty",
+            "repeat_last_n",
+            "context_erase",
         }
 
-    @staticmethod
-    def _gptj_param_names() -> Set[str]:
-        """Get the identifying parameters."""
-        return set()
-
-    @staticmethod
-    def _model_param_names(backend: str) -> Set[str]:
-        if backend == "llama":
-            return GPT4All._llama_param_names()
-        else:
-            return GPT4All._gptj_param_names()
-
     def _default_params(self) -> Dict[str, Any]:
-        if self.backend == "llama":
-            return self._llama_default_params()
-        else:
-            return self._gptj_default_params()
+        return {
+            "n_ctx": self.n_ctx,
+            "n_predict": self.n_predict,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temp": self.temp,
+            "n_batch": self.n_batch,
+            "repeat_penalty": self.repeat_penalty,
+            "repeat_last_n": self.repeat_last_n,
+            "context_erase": self.context_erase,
+        }
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in the environment."""
         try:
-            backend = values["backend"]
-            if backend == "llama":
-                from pygpt4all import GPT4All as GPT4AllModel
-            elif backend == "gptj":
-                from pygpt4all import GPT4All_J as GPT4AllModel
-            else:
-                raise ValueError(f"Incorrect gpt4all backend {cls.backend}")
-            model_kwargs = {
-                k: v
-                for k, v in values.items()
-                if k in GPT4All._model_param_names(backend)
-            }
+            from gpt4all import GPT4All as GPT4AllModel
+
+            full_path = values["model"]
+            model_path, delimiter, model_name = full_path.rpartition("/")
+            model_path += delimiter
+
             values["client"] = GPT4AllModel(
-                model_path=values["model"],
-                **model_kwargs,
+                model_name=model_name,
+                model_path=model_path or None,
+                model_type=values["backend"],
+                allow_download=False,
             )
+            values["backend"] = values["client"].model.model_type
+
         except ImportError:
             raise ValueError(
-                "Could not import pygpt4all python package. "
-                "Please install it with `pip install pygpt4all`."
+                "Could not import gpt4all python package. "
+                "Please install it with `pip install gpt4all`."
             )
         return values
@@ -185,9 +158,7 @@ class GPT4All(LLM):
             "model": self.model,
             **self._default_params(),
             **{
-                k: v
-                for k, v in self.__dict__.items()
-                if k in self._model_param_names(self.backend)
+                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
             },
         }
```
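
One detail worth calling out: the new bindings want a model name plus a containing directory rather than one file path, so `validate_environment` splits the wrapper's `model` value with `rpartition`. A small standalone illustration of that splitting (`split_model_path` is a hypothetical helper, written only to make the behaviour concrete):

```python
def split_model_path(full_path: str) -> tuple:
    """Mirror the path splitting done in validate_environment."""
    model_path, delimiter, model_name = full_path.rpartition("/")
    model_path += delimiter  # keep the trailing slash on the directory part
    # A bare filename yields an empty model_path, which is passed as None
    # so the gpt4all package falls back to its default model directory.
    return model_name, model_path or None

print(split_model_path("./models/ggml-gpt4all-l13b-snoozy.bin"))
# ('ggml-gpt4all-l13b-snoozy.bin', './models/')
print(split_model_path("ggml-gpt4all-l13b-snoozy.bin"))
# ('ggml-gpt4all-l13b-snoozy.bin', None)
```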