Mirror of https://github.com/hwchase17/langchain.git
Update GPT4ALL integration (#4567)
# Update GPT4ALL integration

GPT4All has completely changed its bindings. The new bindings use a somewhat odd implementation that doesn't fit well into base.py, and they will probably change again, so this is a temporary solution.

Fixes #3839, #4628
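Both the notebook and the wrapper below now target the `gpt4all` package instead of `pygpt4all`. A minimal usage sketch under the new bindings (the model path is a placeholder; any locally downloaded ggml-formatted model should work):

```python
from langchain.llms import GPT4All
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Placeholder path; download a compatible ggml-formatted model first.
local_path = "./models/ggml-gpt4all-l13b-snoozy.bin"

# Callbacks are needed to stream tokens; verbose must be passed through
# to the callback manager.
callbacks = [StreamingStdOutCallbackHandler()]
llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)

llm("The first man on the moon was ")
```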
parent e2d7677526
commit c9e2a01875
@@ -27,7 +27,7 @@
 }
 ],
 "source": [
-"%pip install pygpt4all > /dev/null"
+"%pip install gpt4all > /dev/null"
 ]
 },
 {
@@ -64,7 +64,7 @@
 "source": [
 "### Specify Model\n",
 "\n",
-"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/pygpt4all\n",
+"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
 "\n",
 "For full installation instructions go [here](https://gpt4all.io/index.html).\n",
 "\n",
@@ -102,7 +102,7 @@
 "\n",
 "# Path(local_path).parent.mkdir(parents=True, exist_ok=True)\n",
 "\n",
-"# # Example model. Check https://github.com/nomic-ai/pygpt4all for the latest models.\n",
+"# # Example model. Check https://github.com/nomic-ai/gpt4all for the latest models.\n",
 "# url = 'http://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin'\n",
 "\n",
 "# # send a GET request to the URL to download the file. Stream since it's large\n",
@@ -126,7 +126,8 @@
 "callbacks = [StreamingStdOutCallbackHandler()]\n",
 "# Verbose is required to pass to the callback manager\n",
 "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
-"# If you want to use GPT4ALL_J model add the backend parameter\n",
+"# If you want to use a custom model add the backend parameter\n",
+"# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n",
 "llm = GPT4All(model=local_path, backend='gptj', callbacks=callbacks, verbose=True)"
 ]
 },
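The updated cell points to https://docs.gpt4all.io/gpt4all_python.html for the supported backends. As a sketch, passing one explicitly looks like this (the path is a placeholder, and `'gptj'` is simply the value the notebook itself demonstrates):

```python
from langchain.llms import GPT4All

# Explicit backend for a GPT4All-J style model; placeholder path.
llm = GPT4All(model="./models/ggml-gpt4all-j-v1.3-groovy.bin", backend="gptj")
```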
@@ -12,7 +12,7 @@ from langchain.llms.utils import enforce_stop_tokens
 class GPT4All(LLM):
     r"""Wrapper around GPT4All language models.

-    To use, you should have the ``pygpt4all`` python package installed, the
+    To use, you should have the ``gpt4all`` python package installed, the
     pre-trained model file, and the model's config information.

     Example:
@@ -28,7 +28,7 @@ class GPT4All(LLM):
     model: str
     """Path to the pre-trained GPT4All model file."""

-    backend: str = Field("llama", alias="backend")
+    backend: Optional[str] = Field(None, alias="backend")

     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -88,6 +88,10 @@ class GPT4All(LLM):
     streaming: bool = False
     """Whether to stream the results or not."""

+    context_erase: float = 0.5
+    """Leave (n_ctx * context_erase) tokens
+    starting from beginning if the context has run out."""
+
     client: Any = None  #: :meta private:

     class Config:
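Per the new field's docstring, when the context window overflows the model keeps the first `n_ctx * context_erase` tokens. A quick worked example with the defaults shown in the diff:

```python
# Defaults from the diff above: n_ctx=512, context_erase=0.5.
n_ctx = 512
context_erase = 0.5

# Tokens left, counted from the beginning of the context, once it runs out.
tokens_left = int(n_ctx * context_erase)
print(tokens_left)  # 256
```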
@@ -95,86 +99,55 @@ class GPT4All(LLM):

         extra = Extra.forbid

-    def _llama_default_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "n_predict": self.n_predict,
-            "n_threads": self.n_threads,
-            "repeat_last_n": self.repeat_last_n,
-            "repeat_penalty": self.repeat_penalty,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "temp": self.temp,
-        }
-
-    def _gptj_default_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "n_predict": self.n_predict,
-            "n_threads": self.n_threads,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "temp": self.temp,
-        }
-
     @staticmethod
-    def _llama_param_names() -> Set[str]:
-        """Get the identifying parameters."""
+    def _model_param_names() -> Set[str]:
         return {
-            "seed",
             "n_ctx",
-            "n_parts",
-            "f16_kv",
-            "logits_all",
-            "vocab_only",
-            "use_mlock",
-            "embedding",
+            "n_predict",
+            "top_k",
+            "top_p",
+            "temp",
+            "n_batch",
+            "repeat_penalty",
+            "repeat_last_n",
+            "context_erase",
         }

-    @staticmethod
-    def _gptj_param_names() -> Set[str]:
-        """Get the identifying parameters."""
-        return set()
-
-    @staticmethod
-    def _model_param_names(backend: str) -> Set[str]:
-        if backend == "llama":
-            return GPT4All._llama_param_names()
-        else:
-            return GPT4All._gptj_param_names()
-
     def _default_params(self) -> Dict[str, Any]:
-        if self.backend == "llama":
-            return self._llama_default_params()
-        else:
-            return self._gptj_default_params()
+        return {
+            "n_ctx": self.n_ctx,
+            "n_predict": self.n_predict,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temp": self.temp,
+            "n_batch": self.n_batch,
+            "repeat_penalty": self.repeat_penalty,
+            "repeat_last_n": self.repeat_last_n,
+            "context_erase": self.context_erase,
+        }

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in the environment."""
         try:
-            backend = values["backend"]
-            if backend == "llama":
-                from pygpt4all import GPT4All as GPT4AllModel
-            elif backend == "gptj":
-                from pygpt4all import GPT4All_J as GPT4AllModel
-            else:
-                raise ValueError(f"Incorrect gpt4all backend {cls.backend}")
-
-            model_kwargs = {
-                k: v
-                for k, v in values.items()
-                if k in GPT4All._model_param_names(backend)
-            }
+            from gpt4all import GPT4All as GPT4AllModel
+
+            full_path = values["model"]
+            model_path, delimiter, model_name = full_path.rpartition("/")
+            model_path += delimiter
+
             values["client"] = GPT4AllModel(
-                model_path=values["model"],
-                **model_kwargs,
+                model_name=model_name,
+                model_path=model_path or None,
+                model_type=values["backend"],
+                allow_download=False,
             )
+            values["backend"] = values["client"].model.model_type
+
         except ImportError:
             raise ValueError(
-                "Could not import pygpt4all python package. "
-                "Please install it with `pip install pygpt4all`."
+                "Could not import gpt4all python package. "
+                "Please install it with `pip install gpt4all`."
             )
         return values
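The validator now splits the single `model` path into the `model_name`/`model_path` pair that the new `gpt4all` constructor expects, using `rpartition`. A standalone sketch of that split (the path is illustrative):

```python
# Mirrors the splitting logic in validate_environment above.
full_path = "./models/ggml-gpt4all-l13b-snoozy.bin"  # illustrative path
model_path, delimiter, model_name = full_path.rpartition("/")
model_path += delimiter

print(model_path)  # ./models/
print(model_name)  # ggml-gpt4all-l13b-snoozy.bin

# A bare filename contains no "/", so model_path stays empty and the
# wrapper passes model_path=None via `model_path or None`.
```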
@@ -185,9 +158,7 @@ class GPT4All(LLM):
             "model": self.model,
             **self._default_params(),
             **{
-                k: v
-                for k, v in self.__dict__.items()
-                if k in self._model_param_names(self.backend)
+                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
             },
         }
