Python bindings: Custom callbacks, chat session improvement, refactoring (#1145)

* Added the following features:
  1) prompt_model now uses the positional argument callback to return the response tokens.
  2) Because prompt_model takes a callback, prompt_model_streaming now only manages the queue and threading, which reduces code duplication.
  3) Added an optional verbose argument to prompt_model which prints out the prompt that is passed to the model.
  4) Chat sessions can now have a header, i.e. an instruction placed before the transcript of the conversation. The header is set when the chat session context is created.
  5) The generate function now accepts an optional callback.
  6) When streaming in a chat session, the user no longer needs to save the assistant's messages manually; this is done automatically.
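For orientation, a minimal usage sketch of these features as they appear in the diffs below. The model filename comes from the existing docs; the system prompt and the stop-after-20-tokens rule are illustrative only, not part of this commit.

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")  # example model from the docs

# Optional callback: called with (token_id, response) for every generated token;
# returning False stops generation (the last token is then ignored).
tokens_seen = 0

def stop_after_20_tokens(token_id: int, response: str) -> bool:
    global tokens_seen
    tokens_seen += 1
    return tokens_seen < 20

# The session header (an instruction placed before the conversation transcript) and
# the per-message prompt template are set when the chat session context is created.
with model.chat_session(
    system_prompt="You are a concise assistant.",  # illustrative instruction
    prompt_template="### Human: \n{0}\n### Assistant:\n",
):
    answer = model.generate("hello", callback=stop_after_20_tokens)
    print(answer)
    print(model.current_chat_session)  # assistant reply is recorded automatically
```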

* added _empty_response_callback so I don't have to check if callback is None

* added docs

* now if the callback stops generation, the last token is ignored

* fixed type hints, reimplemented chat session header as a system prompt, minor refactoring, docs: removed section about manual update of chat session for streaming

* forgot to add some type hints!

* keep the model's config, taken from models.json when downloads are allowed, in the GPT4All class

* During chat sessions, the model-specific systemPrompt and promptTemplate are applied.

* implemented the changes

* Fixed typing. Now the user can set a prompt template that will be applied even outside of a chat session. The template can also have multiple placeholders that can be filled by passing a dictionary to the generate function

* reverted some changes concerning the prompt templates and their functionality

* fixed some type hints, changed list[float] to List[Float]

* fixed type hints, changed List[Float] to List[float]

* fix typo in the comment: Pepare => Prepare

---------

Signed-off-by: 385olt <385olt@gmail.com>
385olt 2023-07-20 00:36:49 +02:00 committed by GitHub
parent 5f0aaf8bdb
commit b4dbbd1485
3 changed files with 213 additions and 146 deletions


@@ -91,22 +91,4 @@ To interact with GPT4All responses as the model generates, use the `streaming =
[' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
```
-#### Streaming and Chat Sessions
-When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
-```python
-from gpt4all import GPT4All
-model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-
-with model.chat_session():
-    tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
-    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
-
-    tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
-    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
-
-print(model.current_chat_session)
-```
-
### API documentation
::: gpt4all.gpt4all.GPT4All
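With this commit the removed section above is obsolete: when streaming inside a chat session, the assistant's reply is collected into `current_chat_session` automatically once the generator is consumed. A minimal sketch of the new behavior, reusing the model from the removed example:

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")

with model.chat_session():
    # Consuming the stream is enough; no manual append to model.current_chat_session.
    tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
    tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
    print(model.current_chat_session)
```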


@@ -5,7 +5,7 @@ import os
import time
from contextlib import contextmanager
from pathlib import Path
-from typing import Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Union, Optional

import requests
from tqdm import tqdm
@@ -13,7 +13,17 @@ from tqdm import tqdm
from . import pyllmodel

# TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
+    "\\", "\\\\"
+)
+
+DEFAULT_MODEL_CONFIG = {
+    "systemPrompt": "",
+    "promptTemplate": "### Human: \n{0}\n### Assistant:\n",
+}
+
+ConfigType = Dict[str,str]
+MessageType = Dict[str, str]

class Embed4All:
    """
@@ -34,7 +44,7 @@ class Embed4All:
    def embed(
        self,
        text: str
-    ) -> list[float]:
+    ) -> List[float]:
        """
        Generate an embedding.
@@ -74,17 +84,20 @@ class GPT4All:
        self.model_type = model_type
        self.model = pyllmodel.LLModel()
        # Retrieve model and download if allowed
-        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
-        self.model.load_model(model_dest)
+        self.config: ConfigType = self.retrieve_model(
+            model_name, model_path=model_path, allow_download=allow_download
+        )
+        self.model.load_model(self.config["path"])

        # Set n_threads
        if n_threads is not None:
            self.model.set_thread_count(n_threads)

-        self._is_chat_session_activated = False
-        self.current_chat_session = []
+        self._is_chat_session_activated: bool = False
+        self.current_chat_session: List[MessageType] = empty_chat_session()
+        self._current_prompt_template: str = "{0}"

    @staticmethod
-    def list_models() -> Dict:
+    def list_models() -> List[ConfigType]:
        """
        Fetch model list from https://gpt4all.io/models/models.json.
@@ -95,8 +108,11 @@
    @staticmethod
    def retrieve_model(
-        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
-    ) -> str:
+        model_name: str,
+        model_path: Optional[str] = None,
+        allow_download: bool = True,
+        verbose: bool = True,
+    ) -> ConfigType:
        """
        Find model file, and if it doesn't exist, download the model.
@@ -108,11 +124,25 @@
            verbose: If True (default), print debug messages.

        Returns:
-            Model file destination.
+            Model config.
        """
        model_filename = append_bin_suffix_if_missing(model_name)

+        # get the config for the model
+        config: ConfigType = DEFAULT_MODEL_CONFIG
+        if allow_download:
+            available_models = GPT4All.list_models()
+            for m in available_models:
+                if model_filename == m["filename"]:
+                    config.update(m)
+                    config["systemPrompt"] = config["systemPrompt"].strip()
+                    config["promptTemplate"] = config["promptTemplate"].replace(
+                        "%1", "{0}", 1
+                    )  # change to Python-style formatting
+                    break
+
        # Validate download directory
        if model_path is None:
            try:
@@ -131,31 +161,34 @@
        model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
        if os.path.exists(model_dest):
+            config.pop("url", None)
+            config["path"] = model_dest
            if verbose:
                print("Found model file at ", model_dest)
-            return model_dest

        # If model file does not exist, download
        elif allow_download:
            # Make sure valid model filename before attempting download
-            available_models = GPT4All.list_models()
-            selected_model = None
-            for m in available_models:
-                if model_filename == m['filename']:
-                    selected_model = m
-                    break
-
-            if selected_model is None:
+            if "url" not in config:
                raise ValueError(f"Model filename not in model list: {model_filename}")
-
-            url = selected_model.pop('url', None)
-
-            return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
+            url = config.pop("url", None)
+            config["path"] = GPT4All.download_model(
+                model_filename, model_path, verbose=verbose, url=url
+            )
        else:
            raise ValueError("Failed to retrieve model")

+        return config
+
    @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
+    def download_model(
+        model_filename: str,
+        model_path: str,
+        verbose: bool = True,
+        url: Optional[str] = None,
+    ) -> str:
        """
        Download model from https://gpt4all.io.
@@ -191,7 +224,7 @@ class GPT4All:
        except Exception:
            if os.path.exists(download_path):
                if verbose:
-                    print('Cleaning up the interrupted download...')
+                    print("Cleaning up the interrupted download...")
                os.remove(download_path)
            raise
@@ -218,7 +251,8 @@ class GPT4All:
        n_batch: int = 8,
        n_predict: Optional[int] = None,
        streaming: bool = False,
-    ) -> Union[str, Iterable]:
+        callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback,
+    ) -> Union[str, Iterable[str]]:
        """
        Generate outputs from any GPT4All model.
@@ -233,12 +267,14 @@
            n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
            n_predict: Equivalent to max_tokens, exists for backwards compatibility.
            streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
+            callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.

        Returns:
            Either the entire completion or a generator that yields the completion token by token.
        """
-        generate_kwargs = dict(
-            prompt=prompt,
+
+        # Preparing the model request
+        generate_kwargs: Dict[str, Any] = dict(
            temp=temp,
            top_k=top_k,
            top_p=top_p,
@@ -249,42 +285,87 @@
        )

        if self._is_chat_session_activated:
+            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1  # check if there is only one message, i.e. system prompt
            self.current_chat_session.append({"role": "user", "content": prompt})
-            generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:])
-            generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
+            prompt = self._format_chat_prompt_template(
+                messages = self.current_chat_session[-1:],
+                default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
+            )
        else:
-            generate_kwargs['reset_context'] = True
+            generate_kwargs["reset_context"] = True

-        if streaming:
-            return self.model.prompt_model_streaming(**generate_kwargs)
+        # Prepare the callback, process the model response
+        output_collector: List[MessageType]
+        output_collector = [{"content": ""}]  # placeholder for the self.current_chat_session if chat session is not activated

-        output = self.model.prompt_model(**generate_kwargs)
-
        if self._is_chat_session_activated:
-            self.current_chat_session.append({"role": "assistant", "content": output})
+            self.current_chat_session.append({"role": "assistant", "content": ""})
+            output_collector = self.current_chat_session

-        return output
+        def _callback_wrapper(
+            callback: pyllmodel.ResponseCallbackType,
+            output_collector: List[MessageType],
+        ) -> pyllmodel.ResponseCallbackType:
+            def _callback(token_id: int, response: str) -> bool:
+                nonlocal callback, output_collector
+
+                output_collector[-1]["content"] += response
+
+                return callback(token_id, response)
+
+            return _callback
+
+        # Send the request to the model
+        if streaming:
+            return self.model.prompt_model_streaming(
+                prompt=prompt,
+                callback=_callback_wrapper(callback, output_collector),
+                **generate_kwargs,
+            )
+
+        self.model.prompt_model(
+            prompt=prompt,
+            callback=_callback_wrapper(callback, output_collector),
+            **generate_kwargs,
+        )
+
+        return output_collector[-1]["content"]

    @contextmanager
-    def chat_session(self):
-        '''
+    def chat_session(
+        self,
+        system_prompt: str = "",
+        prompt_template: str = "",
+    ):
+        """
        Context manager to hold an inference optimized chat session with a GPT4All model.
-        '''
+
+        Args:
+            system_prompt: An initial instruction for the model.
+            prompt_template: Template for the prompts with {0} being replaced by the user message.
+        """
        # Code to acquire resource, e.g.:
        self._is_chat_session_activated = True
-        self.current_chat_session = []
+        self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
+        self._current_prompt_template = prompt_template or self.config["promptTemplate"]
        try:
            yield self
        finally:
            # Code to release resource, e.g.:
            self._is_chat_session_activated = False
-            self.current_chat_session = []
+            self.current_chat_session = empty_chat_session()
+            self._current_prompt_template = "{0}"

    def _format_chat_prompt_template(
-        self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
+        self,
+        messages: List[MessageType],
+        default_prompt_header: str = "",
+        default_prompt_footer: str = "",
    ) -> str:
        """
-        Helper method for building a prompt using template from list of messages.
+        Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.

        Args:
            messages: List of dictionaries. Each dictionary should have a "role" key
@@ -296,19 +377,44 @@
        Returns:
            Formatted prompt.
        """
-        full_prompt = ""
+        if isinstance(default_prompt_header, bool):
+            import warnings
+
+            warnings.warn(
+                "Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
+                DeprecationWarning,
+            )
+            default_prompt_header = ""
+
+        if isinstance(default_prompt_footer, bool):
+            import warnings
+
+            warnings.warn(
+                "Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
+                DeprecationWarning,
+            )
+            default_prompt_footer = ""
+
+        full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""

        for message in messages:
            if message["role"] == "user":
-                user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
+                user_message = self._current_prompt_template.format(message["content"])
                full_prompt += user_message
            if message["role"] == "assistant":
-                assistant_message = message["content"] + '\n'
+                assistant_message = message["content"] + "\n"
                full_prompt += assistant_message

+        full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
+
        return full_prompt


+def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
+    return [{"role": "system", "content": system_prompt}]
+
+
def append_bin_suffix_if_missing(model_name):
    if not model_name.endswith(".bin"):
        model_name += ".bin"


@@ -6,26 +6,19 @@ import re
import subprocess
import sys
import threading
-from typing import Iterable
+import logging
+from typing import Iterable, Callable, List

import pkg_resources

+logger: logging.Logger = logging.getLogger(__name__)

-class DualStreamProcessor:
-    def __init__(self, stream=None):
-        self.stream = stream
-        self.output = ""
-
-    def write(self, text):
-        if self.stream is not None:
-            self.stream.write(text)
-            self.stream.flush()
-        self.output += text

# TODO: provide a config file to make this more robust
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
-MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
+MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
+    "\\", "\\\\"
+)

def load_llmodel_library():
@@ -43,9 +36,9 @@ def load_llmodel_library():
    c_lib_ext = get_c_shared_lib_extension()

-    llmodel_file = "libllmodel" + '.' + c_lib_ext
+    llmodel_file = "libllmodel" + "." + c_lib_ext

-    llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
+    llmodel_dir = str(pkg_resources.resource_filename("gpt4all", os.path.join(LLMODEL_PATH, llmodel_file))).replace(
        "\\", "\\\\"
    )
@@ -134,7 +127,15 @@ llmodel.llmodel_set_implementation_search_path.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32

-llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode('utf-8'))
+llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode("utf-8"))
+
+ResponseCallbackType = Callable[[int, str], bool]
+RawResponseCallbackType = Callable[[int, bytes], bool]
+
+
+def empty_response_callback(token_id: int, response: str) -> bool:
+    return True


class LLModel:
@@ -250,9 +251,10 @@ class LLModel:
    def generate_embedding(
        self,
        text: str
-    ) -> list[float]:
+    ) -> List[float]:
        if not text:
            raise ValueError("Text must not be None or empty")

        embedding_size = ctypes.c_size_t()
        c_text = ctypes.c_char_p(text.encode('utf-8'))
        embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
@@ -263,6 +265,7 @@ class LLModel:
    def prompt_model(
        self,
        prompt: str,
+        callback: ResponseCallbackType,
        n_predict: int = 4096,
        top_k: int = 40,
        top_p: float = 0.9,
@@ -272,8 +275,7 @@ class LLModel:
        repeat_last_n: int = 10,
        context_erase: float = 0.75,
        reset_context: bool = False,
-        streaming=False,
-    ) -> str:
+    ):
        """
        Generate response from model from a prompt.
@@ -281,26 +283,24 @@ class LLModel:
        ----------
        prompt: str
            Question, task, or conversation for model to respond to
-        streaming: bool
-            Stream response to stdout
+        callback(token_id:int, response:str): bool
+            The model sends response tokens to callback

        Returns
        -------
-        Model response str
+        None
        """

-        prompt_bytes = prompt.encode('utf-8')
+        logger.info(
+            "LLModel.prompt_model -- prompt:\n"
+            + "%s\n"
+            + "===/LLModel.prompt_model -- prompt/===",
+            prompt,
+        )
+
+        prompt_bytes = prompt.encode("utf-8")
        prompt_ptr = ctypes.c_char_p(prompt_bytes)

-        old_stdout = sys.stdout
-
-        stream_processor = DualStreamProcessor()
-
-        if streaming:
-            stream_processor.stream = sys.stdout
-            sys.stdout = stream_processor
-
        self._set_context(
            n_predict=n_predict,
            top_k=top_k,
@@ -317,56 +317,37 @@ class LLModel:
            self.model,
            prompt_ptr,
            PromptCallback(self._prompt_callback),
-            ResponseCallback(self._response_callback),
+            ResponseCallback(self._callback_decoder(callback)),
            RecalculateCallback(self._recalculate_callback),
            self.context,
        )

-        # Revert to old stdout
-        sys.stdout = old_stdout
-
-        # Force new line
-        return stream_processor.output
-
    def prompt_model_streaming(
        self,
        prompt: str,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-    ) -> Iterable:
+        callback: ResponseCallbackType = empty_response_callback,
+        **kwargs
+    ) -> Iterable[str]:
        # Symbol to terminate from generator
        TERMINATING_SYMBOL = object()

        output_queue = queue.Queue()

-        prompt_bytes = prompt.encode('utf-8')
-        prompt_ptr = ctypes.c_char_p(prompt_bytes)
-
-        self._set_context(
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase,
-            reset_context=reset_context,
-        )
-
        # Put response tokens into an output queue
-        def _generator_response_callback(token_id, response):
-            output_queue.put(response.decode('utf-8', 'replace'))
-            return True
+        def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
+            def _generator_callback(token_id: int, response: str):
+                nonlocal callback
+
+                if callback(token_id, response):
+                    output_queue.put(response)
+                    return True
+
+                return False
+
+            return _generator_callback

-        def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
-            llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
+        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, callback, **kwargs)
            output_queue.put(TERMINATING_SYMBOL)

        # Kick off llmodel_prompt in separate thread so we can return generator
@@ -374,13 +355,10 @@ class LLModel:
        thread = threading.Thread(
            target=run_llmodel_prompt,
            args=(
-                self.model,
-                prompt_ptr,
-                PromptCallback(self._prompt_callback),
-                ResponseCallback(_generator_response_callback),
-                RecalculateCallback(self._recalculate_callback),
-                self.context,
+                prompt,
+                _generator_callback_wrapper(callback)
            ),
+            kwargs=kwargs,
        )
        thread.start()
@@ -391,18 +369,19 @@ class LLModel:
                break
            yield response

+    def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
+        def _raw_callback(token_id: int, response: bytes) -> bool:
+            nonlocal callback
+            return callback(token_id, response.decode("utf-8", "replace"))
+
+        return _raw_callback
+
    # Empty prompt callback
    @staticmethod
-    def _prompt_callback(token_id):
-        return True
-
-    # Empty response callback method that just prints response to be collected
-    @staticmethod
-    def _response_callback(token_id, response):
-        sys.stdout.write(response.decode('utf-8', 'replace'))
+    def _prompt_callback(token_id: int) -> bool:
        return True

    # Empty recalculate callback
    @staticmethod
-    def _recalculate_callback(is_recalculating):
+    def _recalculate_callback(is_recalculating: bool) -> bool:
        return is_recalculating
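A sketch of how the new callback plumbing in pyllmodel can be used directly, assuming a model file is available locally; the model path and the stop-at-newline rule are illustrative, not part of this commit.

```python
from gpt4all import pyllmodel

# ResponseCallbackType is Callable[[int, str], bool]: return True to keep generating,
# False to stop; _callback_decoder handles UTF-8 decoding of the raw bytes first.
collected = []

def collect_until_newline(token_id: int, response: str) -> bool:
    collected.append(response)
    return "\n" not in response  # stop generation at the first newline

llmodel = pyllmodel.LLModel()
llmodel.load_model("/path/to/model.bin")  # placeholder path for illustration

# prompt_model now takes the callback positionally and returns None; the response
# is accumulated through the callback instead of being captured from stdout.
llmodel.prompt_model("### Human: \nhello\n### Assistant:\n", collect_until_newline)
print("".join(collected))
```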