mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
Fix GPT4All bug w/ "n_ctx" param (#7093)
Running `GPT4All` per the [docs](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/gpt4all), I see: ``` $ from langchain.llms import GPT4All $ model = GPT4All(model=local_path) $ model("The capital of France is ", max_tokens=10) TypeError: generate() got an unexpected keyword argument 'n_ctx' ``` It appears `n_ctx` is [no longer a supported param](https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All.generate) in the GPT4All API from https://github.com/nomic-ai/gpt4all/pull/1090. It now uses `max_tokens`, so I set this. And I also set other defaults used in GPT4All client [here](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-bindings/python/gpt4all/gpt4all.py). Confirm it now works: ``` $ from langchain.llms import GPT4All $ model = GPT4All(model=local_path) $ model("The capital of France is ", max_tokens=10) < Model logging > "....Paris." ``` --------- Co-authored-by: R. Lance Martin <rlm@Rs-MacBook-Pro.local>
This commit is contained in:
parent
6631fd5168
commit
265c285057
@ -32,7 +32,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@ -45,7 +45,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@ -64,13 +64,20 @@
|
||||
"source": [
|
||||
"### Specify Model\n",
|
||||
"\n",
|
||||
"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
|
||||
"To run locally, download a compatible ggml-formatted model. \n",
|
||||
" \n",
|
||||
"**Download option 1**: The [gpt4all page](https://gpt4all.io/index.html) has a useful `Model Explorer` section:\n",
|
||||
"\n",
|
||||
"For full installation instructions go [here](https://gpt4all.io/index.html).\n",
|
||||
"* Select a model of interest\n",
|
||||
"* Download using the UI and move the `.bin` to the `local_path` (noted below)\n",
|
||||
"\n",
|
||||
"The GPT4All Chat installer needs to decompress a 3GB LLM model during the installation process!\n",
|
||||
"For more info, visit https://github.com/nomic-ai/gpt4all.\n",
|
||||
"\n",
|
||||
"Note that new models are uploaded regularly - check the link above for the most recent `.bin` URL"
|
||||
"--- \n",
|
||||
"\n",
|
||||
"**Download option 2**: Uncomment the below block to download a model. \n",
|
||||
"\n",
|
||||
"* You may want to update `url` to a new version, whih can be browsed using the [gpt4all page](https://gpt4all.io/index.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -81,22 +88,8 @@
|
||||
"source": [
|
||||
"local_path = (\n",
|
||||
" \"./models/ggml-gpt4all-l13b-snoozy.bin\" # replace with your desired local file path\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Uncomment the below block to download a model. You may want to update `url` to a new version."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
")\n",
|
||||
"\n",
|
||||
"# import requests\n",
|
||||
"\n",
|
||||
"# from pathlib import Path\n",
|
||||
@ -126,8 +119,10 @@
|
||||
"source": [
|
||||
"# Callbacks support token-wise streaming\n",
|
||||
"callbacks = [StreamingStdOutCallbackHandler()]\n",
|
||||
"\n",
|
||||
"# Verbose is required to pass to the callback manager\n",
|
||||
"llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
|
||||
"\n",
|
||||
"# If you want to use a custom model add the backend parameter\n",
|
||||
"# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n",
|
||||
"llm = GPT4All(model=local_path, backend=\"gptj\", callbacks=callbacks, verbose=True)"
|
||||
@ -170,7 +165,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -19,7 +19,7 @@ class GPT4All(LLM):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import GPT4All
|
||||
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
|
||||
model = GPT4All(model="./models/gpt4all-model.bin", n_threads=8)
|
||||
|
||||
# Simplest invocation
|
||||
response = model("Once upon a time, ")
|
||||
@ -30,7 +30,7 @@ class GPT4All(LLM):
|
||||
|
||||
backend: Optional[str] = Field(None, alias="backend")
|
||||
|
||||
n_ctx: int = Field(512, alias="n_ctx")
|
||||
max_tokens: int = Field(200, alias="max_tokens")
|
||||
"""Token context window."""
|
||||
|
||||
n_parts: int = Field(-1, alias="n_parts")
|
||||
@ -61,10 +61,10 @@ class GPT4All(LLM):
|
||||
n_predict: Optional[int] = 256
|
||||
"""The maximum number of tokens to generate."""
|
||||
|
||||
temp: Optional[float] = 0.8
|
||||
temp: Optional[float] = 0.7
|
||||
"""The temperature to use for sampling."""
|
||||
|
||||
top_p: Optional[float] = 0.95
|
||||
top_p: Optional[float] = 0.1
|
||||
"""The top-p value to use for sampling."""
|
||||
|
||||
top_k: Optional[int] = 40
|
||||
@ -79,19 +79,15 @@ class GPT4All(LLM):
|
||||
repeat_last_n: Optional[int] = 64
|
||||
"Last n tokens to penalize"
|
||||
|
||||
repeat_penalty: Optional[float] = 1.3
|
||||
repeat_penalty: Optional[float] = 1.18
|
||||
"""The penalty to apply to repeated tokens."""
|
||||
|
||||
n_batch: int = Field(1, alias="n_batch")
|
||||
n_batch: int = Field(8, alias="n_batch")
|
||||
"""Batch size for prompt processing."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
context_erase: float = 0.5
|
||||
"""Leave (n_ctx * context_erase) tokens
|
||||
starting from beginning if the context has run out."""
|
||||
|
||||
allow_download: bool = False
|
||||
"""If model does not exist in ~/.cache/gpt4all/, download it."""
|
||||
|
||||
@ -105,7 +101,7 @@ class GPT4All(LLM):
|
||||
@staticmethod
|
||||
def _model_param_names() -> Set[str]:
|
||||
return {
|
||||
"n_ctx",
|
||||
"max_tokens",
|
||||
"n_predict",
|
||||
"top_k",
|
||||
"top_p",
|
||||
@ -113,12 +109,11 @@ class GPT4All(LLM):
|
||||
"n_batch",
|
||||
"repeat_penalty",
|
||||
"repeat_last_n",
|
||||
"context_erase",
|
||||
}
|
||||
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"n_ctx": self.n_ctx,
|
||||
"max_tokens": self.max_tokens,
|
||||
"n_predict": self.n_predict,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
@ -126,7 +121,6 @@ class GPT4All(LLM):
|
||||
"n_batch": self.n_batch,
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
"repeat_last_n": self.repeat_last_n,
|
||||
"context_erase": self.context_erase,
|
||||
}
|
||||
|
||||
@root_validator()
|
||||
|
Loading…
Reference in New Issue
Block a user