Mirror of https://github.com/nomic-ai/gpt4all.git
Python Bindings: Improved unit tests, documentation and unification of API (#1090)
* Makefiles, black, isort
* Black and isort
* unit tests and generation method
* chat context provider
* context does not reset
* Current state
* Fixup
* Python bindings with unit tests
* GPT4All Python Bindings: chat contexts, tests
* New python bindings and backend fixes
* Black and Isort
* Documentation error
* preserved n_predict for backwards compat with langchain

Co-authored-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
parent 40a3faeb05
commit 46a0762bd5
@@ -128,6 +128,9 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
    std::function<bool(bool)> recalc_func =
        std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));

    if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
        wrapper->promptContext.tokens.resize(ctx->n_past);

    // Copy the C prompt context
    wrapper->promptContext.n_past = ctx->n_past;
    wrapper->promptContext.n_ctx = ctx->n_ctx;
gpt4all-bindings/python/.isort.cfg (new file, 7 lines)
@@ -0,0 +1,7 @@
[settings]
known_third_party=geopy,nltk,np,numpy,pandas,pysbd,fire,torch

line_length=120
include_trailing_comma=True
multi_line_output=3
use_parentheses=True
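For illustration only (not part of the diff), roughly how these isort settings affect a long import: `multi_line_output=3` is isort's vertical hanging indent mode, and the wrapped form below assumes the single-line import exceeds the 120-character limit. The imported names are taken from `pyllmodel` as shown later in this diff.

```python
# Hypothetical long import line (over 120 characters) before isort:
from gpt4all.pyllmodel import LLModel, LLModelError, LLModelPromptContext, PromptCallback, ResponseCallback, RecalculateCallback

# After isort with line_length=120, multi_line_output=3 (vertical hanging indent),
# include_trailing_comma=True, and use_parentheses=True:
from gpt4all.pyllmodel import (
    LLModel,
    LLModelError,
    LLModelPromptContext,
    PromptCallback,
    RecalculateCallback,
    ResponseCallback,
)
```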
@@ -5,7 +5,7 @@ The [GPT4All Chat Client](https://gpt4all.io) lets you easily interact with any
It is optimized to run 7-13B parameter LLMs on the CPU's of any computer running OSX/Windows/Linux.

## Running LLMs on CPU
The GPT4All Chat UI supports models from all newer versions of `GGML`, `llama.cpp` including the `LLaMA`, `MPT`, `replit` and `GPT-J` architectures. The `falcon` architecture will soon also be supported.
The GPT4All Chat UI supports models from all newer versions of `GGML`, `llama.cpp` including the `LLaMA`, `MPT`, `replit`, `GPT-J` and `falcon` architectures

GPT4All maintains an official list of recommended models located in [models.json](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models.json). You can pull request new models to it and if accepted they will show up in the official download dialog.
@@ -1,6 +1,6 @@
# GPT4All Python API
The `GPT4All` package provides Python bindings and an API to our C/C++ model backend libraries.
The source code, README, and local build instructions can be found [here](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/python).
The `GPT4All` python package provides bindings to our C/C++ model backend libraries.
The source code and local build instructions can be found [here](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/python).


## Quickstart
@@ -9,29 +9,88 @@ The source code, README, and local build instructions can be found [here](https:
pip install gpt4all
```

In Python, run the following commands to retrieve a GPT4All model and generate a response
to a prompt.
=== "GPT4All Example"
    ``` py
    from gpt4all import GPT4All
    model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
    output = model.generate("The capital of France is ", max_tokens=3)
    print(output)
    ```
=== "Output"
    ```
    1. Paris
    ```

**Download Note:**
By default, models are stored in `~/.cache/gpt4all/` (you can change this with `model_path`). If the file already exists, model download will be skipped.
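As an aside on the download note above, a minimal sketch (not part of the diff) of pointing the constructor at a custom directory; `model_path`, `allow_download` and `n_threads` are constructor arguments shown in this PR, while the directory path itself is a placeholder.

```python
from gpt4all import GPT4All

# Placeholder directory; substitute a real path on your machine.
model = GPT4All(
    "orca-mini-3b.ggmlv3.q4_0.bin",
    model_path="/data/llm-models",  # look for (and store) the model file here instead of ~/.cache/gpt4all/
    allow_download=False,           # skip downloading; per retrieve_model in this diff, a missing file then raises
    n_threads=8,                    # optional: override the backend thread count
)
print(model.generate("The capital of France is ", max_tokens=3))
```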
### Chatting with GPT4All
Local LLMs can be optimized for chat conversions by reusing previous computational history.

```python
import gpt4all
gptj = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
messages = [{"role": "user", "content": "Name 3 colors"}]
gptj.chat_completion(messages)
```
Use the GPT4All `chat_session` context manager to hold chat conversations with the model.

## Give it a try!
[Google Colab Tutorial](https://colab.research.google.com/drive/1QRFHV5lj1Kb7_tGZZGZ-E6BfX6izpeMI?usp=sharing)
=== "GPT4All Example"
    ``` py
    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
    with model.chat_session():
        response = model.generate(prompt='hello', top_k=1)
        response = model.generate(prompt='write me a short poem', top_k=1)
        response = model.generate(prompt='thank you', top_k=1)
        print(model.current_chat_session)
    ```
=== "Output"
    ``` json
    [
        {
            'role': 'user',
            'content': 'hello'
        },
        {
            'role': 'assistant',
            'content': 'What is your name?'
        },
        {
            'role': 'user',
            'content': 'write me a short poem'
        },
        {
            'role': 'assistant',
            'content': "I would love to help you with that! Here's a short poem I came up with:\nBeneath the autumn leaves,\nThe wind whispers through the trees.\nA gentle breeze, so at ease,\nAs if it were born to play.\nAnd as the sun sets in the sky,\nThe world around us grows still."
        },
        {
            'role': 'user',
            'content': 'thank you'
        },
        {
            'role': 'assistant',
            'content': "You're welcome! I hope this poem was helpful or inspiring for you. Let me know if there is anything else I can assist you with."
        }
    ]
    ```

## Supported Models
Python bindings support the following ggml architectures: `gptj`, `llama`, `mpt`. See API reference for more details.
When using GPT4All models in the chat_session context:

## Best Practices
- The model is given a prompt template which makes it chatty.
- Internal K/V caches are preserved from previous conversation history speeding up inference.

There are two methods to interface with the underlying language model, `chat_completion()` and `generate()`. Chat completion formats a user-provided message dictionary into a prompt template (see API documentation for more details and options). This will usually produce much better results and is the approach we recommend. You may also prompt the model with `generate()` which will just pass the raw input string to the model.

## API Reference
### Generation Parameters

::: gpt4all.gpt4all.GPT4All.generate


### Streaming Generations
To interact with GPT4All responses as the model generates, use the `streaming = True` flag during generation.

=== "GPT4All Example"
    ``` py
    from gpt4all import GPT4All
    model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
    tokens = []
    for token in model.generate("The capital of France is", max_tokens=20, streaming=True):
        tokens.append(token)
    print(tokens)
    ```
=== "Output"
    ```
    [' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
    ```

::: gpt4all.gpt4all.GPT4All
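To make the documented generation parameters concrete, a brief illustrative sketch (not part of the diff) that combines several of the `generate()` arguments introduced in this PR; the values are arbitrary examples, not recommended settings.

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")

# All keyword arguments below are documented parameters of the new generate() API;
# the chosen values are only examples.
output = model.generate(
    "Write a one-sentence summary of what quantization does to a neural network.",
    max_tokens=64,        # cap the completion length (n_predict is kept as an alias for backwards compatibility)
    temp=0.7,             # higher values increase creativity
    top_k=40,             # sample from the 40 most likely tokens; 1 means greedy decoding
    top_p=0.4,            # nucleus sampling threshold
    repeat_penalty=1.18,  # discourage repetition
)
print(output)
```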
@@ -6,6 +6,19 @@ Nomic AI oversees contributions to the open-source ecosystem ensuring quality, s

GPT4All software is optimized to run inference of 7-13 billion parameter large language models on the CPUs of laptops, desktops and servers.

=== "GPT4All Example"
    ``` py
    from gpt4all import GPT4All
    model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
    output = model.generate("The capital of France is ", max_tokens=3)
    print(output)
    ```
=== "Output"
    ```
    1. Paris
    ```
See [Python Bindings](gpt4all_python.md) to use GPT4All.

### Navigating the Documentation
In an effort to ensure cross-operating system and cross-language compatibility, the [GPT4All software ecosystem](https://github.com/nomic-ai/gpt4all)
is organized as a monorepo with the following structure:
@@ -18,31 +31,31 @@ This C API is then bound to any higher level programming language such as C++, P

Explore detailed documentation for the backend, bindings and chat client in the sidebar.
## Models
The GPT4All software ecosystem is compatible with the following Transformer architectures:

- `Falcon`
- `LLaMA` (including `OpenLLaMA`)
- `MPT` (including `Replit`)
- `GPTJ`

You can find an exhaustive list of supported models on the [website](https://gpt4all.io) or in the [models directory](https://raw.githubusercontent.com/nomic-ai/gpt4all/main/gpt4all-chat/metadata/models.json)


GPT4All models are artifacts produced through a process known as neural network quantization.
A multi-billion parameter transformer decoder usually takes 30+ GB of VRAM to execute a forward pass.
Most people do not have such a powerful computer or access to GPU hardware. By running trained LLMs through quantization algorithms,
GPT4All models can run on your laptop using only 4-8GB of RAM enabling their wide-spread utility.

The GPT4All software ecosystem is currently compatible with three variants of the Transformer neural network architecture:

- LLaMa

- GPT-J

- MPT
GPT4All models can run on your laptop using only 4-8GB of RAM enabling their wide-spread usage.

Any model trained with one of these architectures can be quantized and run locally with all GPT4All bindings and in the
chat client. You can add new variants by contributing the gpt4all-backend.

You can find an exhaustive list of pre-quantized models on the [website](https://gpt4all.io) or in the download pane of the chat client.

## Frequently Asked Questions
Find answers to frequently asked questions by searching the [Github issues](https://github.com/nomic-ai/gpt4all/issues) or in the [documentation FAQ](gpt4all_faq.md).

## Getting the most of your local LLM

**Inference Speed**
Inference speed of a local LLM depends on two factors: model size and the number of tokens given as input.
of a local LLM depends on two factors: model size and the number of tokens given as input.
It is not advised to prompt local LLMs with large chunks of context as their inference speed will heavily degrade.
You will likely want to run GPT4All models on GPU if you would like to utilize context windows larger than 750 tokens. Native GPU support for GPT4All models is planned.
@@ -1,2 +1,2 @@
from .pyllmodel import LLModel # noqa
from .gpt4all import GPT4All # noqa
from .pyllmodel import LLModel # noqa
@@ -2,9 +2,10 @@
Python only API for running all GPT4All models.
"""
import os
from pathlib import Path
import time
from typing import Dict, List
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterable, List, Union

import requests
from tqdm import tqdm
@@ -15,14 +16,14 @@ from . import pyllmodel
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")


class GPT4All():
    """Python API for retrieving and interacting with GPT4All models.

    Attributes:
        model: Pointer to underlying C model.
class GPT4All:
    """
    Python class that handles instantiation, downloading, generation and chat with GPT4All models.
    """

    def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True, n_threads = None):
    def __init__(
        self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
    ):
        """
        Constructor
@@ -41,11 +42,14 @@ class GPT4All():
        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
        self.model.load_model(model_dest)
        # Set n_threads
        if n_threads != None:
        if n_threads is not None:
            self.model.set_thread_count(n_threads)

        self._is_chat_session_activated = False
        self.current_chat_session = []

    @staticmethod
    def list_models():
    def list_models() -> Dict:
        """
        Fetch model list from https://gpt4all.io/models/models.json.
@@ -55,8 +59,9 @@
        return requests.get("https://gpt4all.io/models/models.json").json()

    @staticmethod
    def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True,
                       verbose: bool = True) -> str:
    def retrieve_model(
        model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
    ) -> str:
        """
        Find model file, and if it doesn't exist, download the model.
@@ -78,8 +83,10 @@
            try:
                os.makedirs(DEFAULT_MODEL_DIRECTORY, exist_ok=True)
            except OSError as exc:
                raise ValueError(f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. "
                                 "Please specify model_path.")
                raise ValueError(
                    f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. "
                    "Please specify model_path."
                )
            model_path = DEFAULT_MODEL_DIRECTORY
        else:
            model_path = model_path.replace("\\", "\\\\")
@@ -108,7 +115,7 @@
            raise ValueError(f"Model filename not in model list: {model_filename}")
            url = selected_model.pop('url', None)

            return GPT4All.download_model(model_filename, model_path, verbose = verbose, url=url)
            return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
        else:
            raise ValueError("Failed to retrieve model")
@@ -126,6 +133,7 @@
        Returns:
            Model file destination.
        """

        def get_download_url(model_filename):
            if url:
                return url
@@ -137,7 +145,7 @@

        response = requests.get(download_url, stream=True)
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 2 ** 20  # 1 MB
        block_size = 2**20  # 1 MB

        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
            try:
@@ -154,9 +162,7 @@

        # Validate download was successful
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            raise RuntimeError(
                "An error occurred during download. Downloaded file may not work."
            )
            raise RuntimeError("An error occurred during download. Downloaded file may not work.")

        # Sleep for a little bit so OS can remove file lock
        time.sleep(2)
@@ -165,101 +171,83 @@
            print("Model downloaded at: ", download_path)
        return download_path

    # TODO: this naming is just confusing now and needs to be deprecated now that we have generator
    # Need to better consolidate all these different model response methods
    def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
    def generate(
        self,
        prompt: str,
        max_tokens: int = 200,
        temp: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.1,
        repeat_penalty: float = 1.18,
        repeat_last_n: int = 64,
        n_batch: int = 8,
        n_predict: int = None,
        streaming: bool = False,
    ) -> Union[str, Iterable]:
        """
        Surfaced method of running generate without accessing model object.
        Generate outputs from any GPT4All model.

        Args:
            prompt: Raw string to be passed to model.
            streaming: True if want output streamed to stdout.
            **generate_kwargs: Optional kwargs to pass to prompt context.
            prompt: The prompt for the model the complete.
            max_tokens: The maximum number of tokens to generate.
            temp: The model temperature. Larger values increase creativity but decrease factuality.
            top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.
            top_p: Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.
            repeat_penalty: Penalize the model for repetition. Higher values result in less repetition.
            repeat_last_n: How far in the models generation history to apply the repeat penalty.
            n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
            n_predict: Equivalent to max_tokens, exists for backwards compatability.
            streaming: If True, this method will instead return a generator that yields tokens as the model generates them.

        Returns:
            Raw string of generated model response.
            Either the entire completion or a generator that yields the completion token by token.
        """
        return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs)
        generate_kwargs = locals()
        generate_kwargs.pop('self')
        generate_kwargs.pop('max_tokens')
        generate_kwargs.pop('streaming')
        generate_kwargs['n_predict'] = max_tokens
        if n_predict is not None:
            generate_kwargs['n_predict'] = n_predict

    def generator(self, prompt: str, **generate_kwargs) -> str:
        """
        Surfaced method of running generate without accessing model object.
        if streaming and self._is_chat_session_activated:
            raise NotImplementedError("Streaming tokens in a chat session is not currently supported.")

        Args:
            prompt: Raw string to be passed to model.
            streaming: True if want output streamed to stdout.
            **generate_kwargs: Optional kwargs to pass to prompt context.
        if self._is_chat_session_activated:
            self.current_chat_session.append({"role": "user", "content": prompt})
            generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session)
            generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
        else:
            generate_kwargs['reset_context'] = True

        Returns:
            Raw string of generated model response.
        """
        return self.model.generator(prompt, **generate_kwargs)
        if streaming:
            return self.model.prompt_model_streaming(**generate_kwargs)

    def chat_completion(self,
                        messages: List[Dict],
                        default_prompt_header: bool = True,
                        default_prompt_footer: bool = True,
                        verbose: bool = True,
                        streaming: bool = True,
                        **generate_kwargs) -> dict:
        """
        Format list of message dictionaries into a prompt and call model
        generate on prompt. Returns a response dictionary with metadata and
        generated content.
        output = self.model.prompt_model(**generate_kwargs)

        Args:
            messages: List of dictionaries. Each dictionary should have a "role" key
                with value of "system", "assistant", or "user" and a "content" key with a
                string value. Messages are organized such that "system" messages are at top of prompt,
                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
                "Response: {content}".
            default_prompt_header: If True (default), add default prompt header after any system role messages and
                before user/assistant role messages.
            default_prompt_footer: If True (default), add default footer at end of prompt.
            verbose: If True (default), print full prompt and generated response.
            streaming: True if want output streamed to stdout.
            **generate_kwargs: Optional kwargs to pass to prompt context.
        if self._is_chat_session_activated:
            self.current_chat_session.append({"role": "assistant", "content": output})

        Returns:
            Response dictionary with:
                "model": name of model.
                "usage": a dictionary with number of full prompt tokens, number of
                    generated tokens in response, and total tokens.
                "choices": List of message dictionary where "content" is generated response and "role" is set
                    as "assistant". Right now, only one choice is returned by model.
        """
        full_prompt = self._build_prompt(messages,
                                         default_prompt_header=default_prompt_header,
                                         default_prompt_footer=default_prompt_footer)
        if verbose:
            print(full_prompt)
        return output

        response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs)
    @contextmanager
    def chat_session(self):
        '''
        Context manager to hold an inference optimized chat session with a GPT4All model.
        '''
        # Code to acquire resource, e.g.:
        self._is_chat_session_activated = True
        self._current_chat_session = []
        try:
            yield self
        finally:
            # Code to release resource, e.g.:
            self._is_chat_session_activated = False
            self._current_chat_session = []

        if verbose and not streaming:
            print(response)

        response_dict = {
            "model": self.model.model_name,
            "usage": {"prompt_tokens": len(full_prompt),
                      "completion_tokens": len(response),
                      "total_tokens": len(full_prompt) + len(response)},
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": response
                    }
                }
            ]
        }

        return response_dict

    @staticmethod
    def _build_prompt(messages: List[Dict],
                      default_prompt_header=True,
                      default_prompt_footer=True) -> str:
    def _format_chat_prompt_template(
        self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
    ) -> str:
        """
        Helper method for building a prompt using template from list of messages.
@@ -269,37 +257,20 @@
                string value. Messages are organized such that "system" messages are at top of prompt,
                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
                "Response: {content}".
            default_prompt_header: If True (default), add default prompt header after any system role messages and
                before user/assistant role messages.
            default_prompt_footer: If True (default), add default footer at end of prompt.

        Returns:
            Formatted prompt.
        """
        full_prompt = ""

        for message in messages:
            if message["role"] == "system":
                system_message = message["content"] + "\n"
                full_prompt += system_message

        if default_prompt_header:
            full_prompt += """### Instruction:
            The prompt below is a question to answer, a task to complete, or a conversation
            to respond to; decide which and write an appropriate response.
            \n### Prompt: """

        for message in messages:
            if message["role"] == "user":
                user_message = "\n" + message["content"]
                user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
                full_prompt += user_message
            if message["role"] == "assistant":
                assistant_message = "\n### Response: " + message["content"]
                assistant_message = message["content"] + '\n'
                full_prompt += assistant_message

        if default_prompt_footer:
            full_prompt += "\n### Response:"

        return full_prompt
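For context on the hunk above, a hedged re-implementation of the user/assistant formatting rules visible in the new code; `format_chat_prompt` below is an illustrative stand-in, not the library method itself, and it only covers the two message roles whose new formatting is explicit in this diff.

```python
from typing import Dict, List


def format_chat_prompt(messages: List[Dict]) -> str:
    # Illustrative sketch of the formatting shown above: user turns become
    # "### Human: ... ### Assistant:" blocks, assistant turns are appended verbatim.
    full_prompt = ""
    for message in messages:
        if message["role"] == "user":
            full_prompt += "### Human: \n" + message["content"] + "\n### Assistant:\n"
        if message["role"] == "assistant":
            full_prompt += message["content"] + "\n"
    return full_prompt


print(format_chat_prompt([
    {"role": "user", "content": "Name 3 colors"},
    {"role": "assistant", "content": "Red, green, blue."},
    {"role": "user", "content": "Name one more"},
]))
# Prints alternating "### Human:" / "### Assistant:" blocks ending with an
# open "### Assistant:" line for the model to complete.
```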
@@ -1,4 +1,3 @@
import pkg_resources
import ctypes
import os
import platform
@@ -7,6 +6,10 @@ import re
import subprocess
import sys
import threading
from typing import Iterable

import pkg_resources


class DualStreamProcessor:
    def __init__(self, stream=None):
@@ -19,10 +22,12 @@ class DualStreamProcessor:
            self.stream.flush()
        self.output += text


# TODO: provide a config file to make this more robust
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")


def load_llmodel_library():
    system = platform.system()
@@ -40,21 +45,25 @@ def load_llmodel_library():

    llmodel_file = "libllmodel" + '.' + c_lib_ext

    llmodel_dir = str(pkg_resources.resource_filename('gpt4all', \
        os.path.join(LLMODEL_PATH, llmodel_file))).replace("\\", "\\\\")
    llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
        "\\", "\\\\"
    )

    llmodel_lib = ctypes.CDLL(llmodel_dir)

    return llmodel_lib


llmodel = load_llmodel_library()


class LLModelError(ctypes.Structure):
    _fields_ = [("message", ctypes.c_char_p),
                ("code", ctypes.c_int32)]
    _fields_ = [("message", ctypes.c_char_p), ("code", ctypes.c_int32)]


class LLModelPromptContext(ctypes.Structure):
    _fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("logits_size", ctypes.c_size_t),
        ("tokens", ctypes.POINTER(ctypes.c_int32)),
        ("tokens_size", ctypes.c_size_t),
@@ -67,7 +76,9 @@ class LLModelPromptContext(ctypes.Structure):
        ("n_batch", ctypes.c_int32),
        ("repeat_penalty", ctypes.c_float),
        ("repeat_last_n", ctypes.c_int32),
        ("context_erase", ctypes.c_float)]
        ("context_erase", ctypes.c_float),
    ]


# Define C function signatures using ctypes
llmodel.llmodel_model_create.argtypes = [ctypes.c_char_p]
@@ -90,12 +101,14 @@ PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)

llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
llmodel.llmodel_prompt.argtypes = [
    ctypes.c_void_p,
    ctypes.c_char_p,
    PromptCallback,
    ResponseCallback,
    RecalculateCallback,
    ctypes.POINTER(LLModelPromptContext)]
    ctypes.POINTER(LLModelPromptContext),
]

llmodel.llmodel_prompt.restype = None
@@ -142,7 +155,6 @@ class LLModel:
        else:
            raise ValueError("Unable to instantiate model")


    def load_model(self, model_path: str) -> bool:
        """
        Load model from a file.
@@ -182,21 +194,59 @@
            raise Exception("Model not loaded")
        return llmodel.llmodel_threadCount(self.model)

    def prompt_model(self,
                     prompt: str,
                     logits_size: int = 0,
                     tokens_size: int = 0,
                     n_past: int = 0,
                     n_ctx: int = 1024,
                     n_predict: int = 128,
    def _set_context(
        self,
        n_predict: int = 4096,
        top_k: int = 40,
        top_p: float = .9,
        temp: float = .1,
        top_p: float = 0.9,
        temp: float = 0.1,
        n_batch: int = 8,
        repeat_penalty: float = 1.2,
        repeat_last_n: int = 10,
        context_erase: float = .5,
        streaming: bool = True) -> str:
        context_erase: float = 0.75,
        reset_context: bool = False,
    ):
        if self.context is None:
            self.context = LLModelPromptContext(
                logits_size=0,
                tokens_size=0,
                n_past=0,
                n_ctx=0,
                n_predict=n_predict,
                top_k=top_k,
                top_p=top_p,
                temp=temp,
                n_batch=n_batch,
                repeat_penalty=repeat_penalty,
                repeat_last_n=repeat_last_n,
                context_erase=context_erase,
            )
        elif reset_context:
            self.context.n_past = 0

        self.context.n_predict = n_predict
        self.context.top_k = top_k
        self.context.top_p = top_p
        self.context.temp = temp
        self.context.n_batch = n_batch
        self.context.repeat_penalty = repeat_penalty
        self.context.repeat_last_n = repeat_last_n
        self.context.context_erase = context_erase

    def prompt_model(
        self,
        prompt: str,
        n_predict: int = 4096,
        top_k: int = 40,
        top_p: float = 0.9,
        temp: float = 0.1,
        n_batch: int = 8,
        repeat_penalty: float = 1.2,
        repeat_last_n: int = 10,
        context_erase: float = 0.75,
        reset_context: bool = False,
        streaming=False,
    ) -> str:
        """
        Generate response from model from a prompt.
@@ -224,13 +274,7 @@

        sys.stdout = stream_processor


        if self.context is None:
            self.context = LLModelPromptContext(
                logits_size=logits_size,
                tokens_size=tokens_size,
                n_past=n_past,
                n_ctx=n_ctx,
        self._set_context(
            n_predict=n_predict,
            top_k=top_k,
            top_p=top_p,
@@ -238,37 +282,37 @@
            n_batch=n_batch,
            repeat_penalty=repeat_penalty,
            repeat_last_n=repeat_last_n,
            context_erase=context_erase
            context_erase=context_erase,
            reset_context=reset_context,
        )

        llmodel.llmodel_prompt(self.model,
        llmodel.llmodel_prompt(
            self.model,
            prompt,
            PromptCallback(self._prompt_callback),
            ResponseCallback(self._response_callback),
            RecalculateCallback(self._recalculate_callback),
            self.context)
            self.context,
        )

        # Revert to old stdout
        sys.stdout = old_stdout
        # Force new line
        print()
        return stream_processor.output

    def generator(self,
    def prompt_model_streaming(
        self,
        prompt: str,
        logits_size: int = 0,
        tokens_size: int = 0,
        n_past: int = 0,
        n_ctx: int = 1024,
        n_predict: int = 128,
        n_predict: int = 4096,
        top_k: int = 40,
        top_p: float = .9,
        temp: float = .1,
        top_p: float = 0.9,
        temp: float = 0.1,
        n_batch: int = 8,
        repeat_penalty: float = 1.2,
        repeat_last_n: int = 10,
        context_erase: float = .5) -> str:

        context_erase: float = 0.75,
        reset_context: bool = False,
    ) -> Iterable:
        # Symbol to terminate from generator
        TERMINATING_SYMBOL = "#TERMINATE#"
@@ -277,12 +321,7 @@
        prompt = prompt.encode('utf-8')
        prompt = ctypes.c_char_p(prompt)

        if self.context is None:
            self.context = LLModelPromptContext(
                logits_size=logits_size,
                tokens_size=tokens_size,
                n_past=n_past,
                n_ctx=n_ctx,
        self._set_context(
            n_predict=n_predict,
            top_k=top_k,
            top_p=top_p,
@@ -290,7 +329,8 @@
            n_batch=n_batch,
            repeat_penalty=repeat_penalty,
            repeat_last_n=repeat_last_n,
            context_erase=context_erase
            context_erase=context_erase,
            reset_context=reset_context,
        )

        # Put response tokens into an output queue
@@ -298,30 +338,23 @@
            output_queue.put(response.decode('utf-8', 'replace'))
            return True

        def run_llmodel_prompt(model,
                               prompt,
                               prompt_callback,
                               response_callback,
                               recalculate_callback,
                               context):
            llmodel.llmodel_prompt(model,
                                   prompt,
                                   prompt_callback,
                                   response_callback,
                                   recalculate_callback,
                                   context)
        def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
            llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
            output_queue.put(TERMINATING_SYMBOL)

        # Kick off llmodel_prompt in separate thread so we can return generator
        # immediately
        thread = threading.Thread(target=run_llmodel_prompt,
                                  args=(self.model,
        thread = threading.Thread(
            target=run_llmodel_prompt,
            args=(
                self.model,
                prompt,
                PromptCallback(self._prompt_callback),
                ResponseCallback(_generator_response_callback),
                RecalculateCallback(self._recalculate_callback),
                self.context))
                self.context,
            ),
        )
        thread.start()

        # Generator
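As background for `prompt_model_streaming` above, a self-contained sketch (not part of the diff) of the same thread-plus-queue pattern with a terminating sentinel; `produce_tokens` below is a stand-in for the C backend invoking the per-token response callback.

```python
import queue
import threading

TERMINATING_SYMBOL = "#TERMINATE#"


def stream_from_worker():
    output_queue: queue.Queue = queue.Queue()

    def produce_tokens():
        # Stand-in for the backend pushing tokens via the response callback.
        for token in ["The", " capital", " of", " France"]:
            output_queue.put(token)
        output_queue.put(TERMINATING_SYMBOL)  # sentinel tells the consumer to stop

    # Run the producer in a background thread so a generator can be returned immediately.
    threading.Thread(target=produce_tokens, daemon=True).start()

    def generator():
        while True:
            token = output_queue.get()
            if token == TERMINATING_SYMBOL:
                break
            yield token

    return generator()


print(list(stream_from_worker()))  # ['The', ' capital', ' of', ' France']
```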
gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import sys
from io import StringIO

from gpt4all import GPT4All


def test_inference():
    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
    output_1 = model.generate('hello', top_k=1)

    with model.chat_session():
        response = model.generate(prompt='hello', top_k=1)
        response = model.generate(prompt='write me a short poem', top_k=1)
        response = model.generate(prompt='thank you', top_k=1)
        print(model.current_chat_session)

    output_2 = model.generate('hello', top_k=1)

    assert output_1 == output_2

    tokens = []
    for token in model.generate('hello', streaming=True):
        tokens.append(token)

    assert len(tokens) > 0

    with model.chat_session():
        try:
            response = model.generate(prompt='hello', top_k=1, streaming=True)
            assert False
        except NotImplementedError:
            assert True


def test_inference_hparams():
    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')

    output = model.generate("The capital of france is ", max_tokens=3)
    assert 'Paris' in output


def test_inference_falcon():
    model = GPT4All(model_name='ggml-model-gpt4all-falcon-q4_0.bin')
    prompt = 'hello'
    output = model.generate(prompt)

    assert len(output) > 0


def test_inference_mpt():
    model = GPT4All(model_name='ggml-mpt-7b-chat.bin')
    prompt = 'hello'
    output = model.generate(prompt)
    assert len(output) > 0
@@ -14,3 +14,12 @@ wheel:
clean:
    rm -rf {.pytest_cache,env,gpt4all.egg-info}
    find . | grep -E "(__pycache__|\.pyc|\.pyo$\)" | xargs rm -rf

black:
    source env/bin/activate; black -l 120 -S --target-version py36 gpt4all

isort:
    source env/bin/activate; isort --ignore-whitespace --atomic -w 120 gpt4all

test:
    source env/bin/activate; pytest -s gpt4all/tests
@@ -9,10 +9,12 @@ use_directory_urls: false

nav:
  - 'index.md'
  - 'gpt4all_chat.md'
  - 'gpt4all_python.md'
  - 'Tutorials':
    - 'gpt4all_modal.md'
  - 'Bindings':
    - 'GPT4All in Python': 'gpt4all_python.md'
    - 'GPT4All Chat Client': 'gpt4all_chat.md'
    - 'gpt4all_cli.md'
  # - 'Tutorials':
  #   - 'gpt4all_modal.md'
  - 'Wiki':
    - 'gpt4all_faq.md'
@@ -61,7 +61,7 @@ copy_prebuilt_C_lib(SRC_CLIB_DIRECtORY,

setup(
    name=package_name,
    version="0.3.6",
    version="1.0.0",
    description="Python bindings for GPT4All",
    author="Richard Guo",
    author_email="richard@nomic.ai",
@@ -83,7 +85,9 @@ setup(
        'mkdocs-material',
        'mkautodoc',
        'mkdocstrings[python]',
        'mkdocs-jupyter'
        'mkdocs-jupyter',
        'black',
        'isort'
        ]
    },
    package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},
@@ -1,62 +0,0 @@
import pytest

from gpt4all.gpt4all import GPT4All

def test_invalid_model_type():
    model_type = "bad_type"
    with pytest.raises(ValueError):
        GPT4All.get_model_from_type(model_type)

def test_valid_model_type():
    model_type = "gptj"
    assert GPT4All.get_model_from_type(model_type).model_type == model_type

def test_invalid_model_name():
    model_name = "bad_filename.bin"
    with pytest.raises(ValueError):
        GPT4All.get_model_from_name(model_name)

def test_valid_model_name():
    model_name = "ggml-gpt4all-l13b-snoozy"
    model_type = "llama"
    assert GPT4All.get_model_from_name(model_name).model_type == model_type
    model_name += ".bin"
    assert GPT4All.get_model_from_name(model_name).model_type == model_type

def test_build_prompt():
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello there."
        },
        {
            "role": "assistant",
            "content": "Hi, how can I help you?"
        },
        {
            "role": "user",
            "content": "Reverse a list in Python."
        }
    ]

    expected_prompt = """You are a helpful assistant.\
\n### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.\
### Prompt:\
Hello there.\
Response: Hi, how can I help you?\
Reverse a list in Python.\
### Response:"""

    print(expected_prompt)

    full_prompt = GPT4All._build_prompt(messages, default_prompt_footer=True, default_prompt_header=True)

    print("\n\n\n")
    print(full_prompt)
    assert len(full_prompt) == len(expected_prompt)
@@ -1,61 +0,0 @@
from io import StringIO
import sys

from gpt4all import pyllmodel

# TODO: Integration test for loadmodel and prompt.
# # Right now, too slow b/c it requires file download.

def test_create_gptj():
    gptj = pyllmodel.GPTJModel()
    assert gptj.model_type == "gptj"

def test_create_llama():
    llama = pyllmodel.LlamaModel()
    assert llama.model_type == "llama"

def test_create_mpt():
    mpt = pyllmodel.MPTModel()
    assert mpt.model_type == "mpt"

def prompt_unloaded_mpt():
    mpt = pyllmodel.MPTModel()
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    mpt.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "MPT ERROR: prompt won't work with an unloaded model!"

def prompt_unloaded_gptj():
    gptj = pyllmodel.GPTJModel()
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    gptj.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "GPT-J ERROR: prompt won't work with an unloaded model!"

def prompt_unloaded_llama():
    llama = pyllmodel.LlamaModel()
    old_stdout = sys.stdout
    collect_response = StringIO()
    sys.stdout = collect_response

    llama.prompt("hello there")

    response = collect_response.getvalue()
    sys.stdout = old_stdout

    response = response.strip()
    assert response == "LLAMA ERROR: prompt won't work with an unloaded model!"