mirror of https://github.com/nomic-ai/gpt4all.git
Remove binary state from high-level API and use Jinja templates (#3147)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
@@ -9,7 +9,7 @@ import textwrap
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -23,7 +23,9 @@ else:
     from typing import TypedDict
 
 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
+    from typing_extensions import ParamSpec, TypeAlias
+    T = TypeVar("T")
+    P = ParamSpec("P")
 
 EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
 
@@ -31,7 +33,7 @@ cuda_found: bool = False
 
 
 # TODO(jared): use operator.call after we drop python 3.10 support
-def _operator_call(obj, /, *args, **kwargs):
+def _operator_call(obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
     return obj(*args, **kwargs)
 
 
@@ -116,16 +118,15 @@ llmodel = load_llmodel_library()
 
 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [
-        ("n_past", ctypes.c_int32),
         ("n_predict", ctypes.c_int32),
         ("top_k", ctypes.c_int32),
         ("top_p", ctypes.c_float),
         ("min_p", ctypes.c_float),
         ("temp", ctypes.c_float),
         ("n_batch", ctypes.c_int32),
         ("repeat_penalty", ctypes.c_float),
         ("repeat_last_n", ctypes.c_int32),
         ("context_erase", ctypes.c_float),
     ]
 
 
@@ -157,23 +158,21 @@ llmodel.llmodel_required_mem.restype = ctypes.c_size_t
 llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
 
-PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
+PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
+SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)
 
 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
     ctypes.c_char_p,
-    ctypes.c_char_p,
     PromptCallback,
     ResponseCallback,
-    ctypes.c_bool,
     ctypes.POINTER(LLModelPromptContext),
-    ctypes.c_bool,
-    ctypes.c_char_p,
+    ctypes.POINTER(ctypes.c_char_p),
 ]
 
-llmodel.llmodel_prompt.restype = None
+llmodel.llmodel_prompt.restype = ctypes.c_bool
 
 llmodel.llmodel_embed.argtypes = [
     ctypes.c_void_p,
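The hunk above changes PromptCallback from a per-token callback into one that receives a whole batch of token IDs plus a cached flag. The following is a hedged, self-contained sketch (not part of the commit) of how a Python function with that ctypes signature sees its arguments; the function and buffer names are invented for illustration.

    import ctypes

    # Same CFUNCTYPE shape as the new PromptCallback above.
    PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)

    def on_prompt_tokens(token_ids, n_token_ids, cached):
        # token_ids is a POINTER(c_int32); index it up to the reported length.
        tokens = [token_ids[i] for i in range(n_token_ids)]
        print(f"saw {len(tokens)} prompt tokens (cached={cached})")
        return True  # returning False is presumably how a caller cancels processing

    c_callback = PromptCallback(on_prompt_tokens)

    # Drive the callback from Python just to show the calling convention.
    buf = (ctypes.c_int32 * 3)(101, 102, 103)
    c_callback(ctypes.cast(buf, ctypes.POINTER(ctypes.c_int32)), 3, False)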
@@ -222,6 +221,12 @@ llmodel.llmodel_model_backend_name.restype = ctypes.c_char_p
 llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p
 
+llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char_p)]
+llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
+
+llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
+llmodel.llmodel_model_foreach_special_token.restype = None
+
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
 EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@@ -266,7 +271,6 @@ class LLModel:
         self.model_path = model_path.encode()
         self.n_ctx = n_ctx
         self.ngl = ngl
-        self.context: LLModelPromptContext | None = None
         self.buffer = bytearray()
         self.buff_expecting_cont_bytes: int = 0
 
@@ -286,6 +290,10 @@ class LLModel:
 
             raise RuntimeError(f"Unable to instantiate model: {errmsg}")
         self.model: ctypes.c_void_p | None = model
+        self.special_tokens_map: dict[str, str] = {}
+        llmodel.llmodel_model_foreach_special_token(
+            self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
+        )
 
     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
@@ -312,6 +320,19 @@ class LLModel:
         dev = llmodel.llmodel_model_gpu_device_name(self.model)
         return None if dev is None else dev.decode()
 
+    def count_prompt_tokens(self, prompt: str) -> int:
+        if self.model is None:
+            self._raise_closed()
+        err = ctypes.c_char_p()
+        n_tok = llmodel.llmodel_count_prompt_tokens(self.model, prompt, ctypes.byref(err))
+        if n_tok < 0:
+            s = err.value
+            errmsg = 'null' if s is None else s.decode()
+            raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
+        return n_tok
+
+    llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+
     @staticmethod
     def list_gpus(mem_required: int = 0) -> list[str]:
         """
@@ -375,48 +396,6 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
 
-    def _set_context(
-        self,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        min_p: float = 0.0,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-    ):
-        if self.context is None:
-            context = LLModelPromptContext(
-                n_past=0,
-                n_predict=n_predict,
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                temp=temp,
-                n_batch=n_batch,
-                repeat_penalty=repeat_penalty,
-                repeat_last_n=repeat_last_n,
-                context_erase=context_erase,
-            )
-            self.context = context
-        else:
-            context = self.context
-            if reset_context:
-                self.context.n_past = 0
-
-        self.context.n_predict = n_predict
-        self.context.top_k = top_k
-        self.context.top_p = top_p
-        self.context.min_p = min_p
-        self.context.temp = temp
-        self.context.n_batch = n_batch
-        self.context.repeat_penalty = repeat_penalty
-        self.context.repeat_last_n = repeat_last_n
-        self.context.context_erase = context_erase
-
     @overload
     def generate_embeddings(
         self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
@@ -486,20 +465,18 @@ class LLModel:
 
     def prompt_model(
         self,
-        prompt: str,
-        prompt_template: str,
-        callback: ResponseCallbackType,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        min_p: float = 0.0,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-        special: bool = False,
+        prompt         : str,
+        callback       : ResponseCallbackType,
+        n_predict      : int   = 4096,
+        top_k          : int   = 40,
+        top_p          : float = 0.9,
+        min_p          : float = 0.0,
+        temp           : float = 0.1,
+        n_batch        : int   = 8,
+        repeat_penalty : float = 1.2,
+        repeat_last_n  : int   = 10,
+        context_erase  : float = 0.75,
+        reset_context  : bool  = False,
     ):
         """
         Generate response from model from a prompt.
@@ -522,34 +499,38 @@ class LLModel:
         self.buffer.clear()
         self.buff_expecting_cont_bytes = 0
 
-        self._set_context(
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            min_p=min_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase,
-            reset_context=reset_context,
+        context = LLModelPromptContext(
+            n_predict      = n_predict,
+            top_k          = top_k,
+            top_p          = top_p,
+            min_p          = min_p,
+            temp           = temp,
+            n_batch        = n_batch,
+            repeat_penalty = repeat_penalty,
+            repeat_last_n  = repeat_last_n,
+            context_erase  = context_erase,
         )
 
-        llmodel.llmodel_prompt(
+        error_msg: bytes | None = None
+        def error_callback(msg: bytes) -> None:
+            nonlocal error_msg
+            error_msg = msg
+
+        err = ctypes.c_char_p()
+        if not llmodel.llmodel_prompt(
             self.model,
             ctypes.c_char_p(prompt.encode()),
-            ctypes.c_char_p(prompt_template.encode()),
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._callback_decoder(callback)),
-            True,
-            self.context,
-            special,
-            ctypes.c_char_p(),
-        )
+            context,
+            ctypes.byref(err),
+        ):
+            s = err.value
+            raise RuntimeError(f"prompt error: {'null' if s is None else s.decode()}")
 
     def prompt_model_streaming(
-        self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
-    ) -> Iterable[str]:
+        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs: Any,
+    ) -> Iterator[str]:
         if self.model is None:
             self._raise_closed()
 
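The rewritten prompt_model above replaces the old void-returning C call with one that returns a bool and reports failures through a char** out-parameter. As a small aside (not part of the diff), this is the standard ctypes idiom it relies on: a NULL-initialized c_char_p passed via byref can be inspected and decoded after the call.

    import ctypes

    err = ctypes.c_char_p()          # starts out as NULL
    assert err.value is None

    # A real call would look like: ok = llmodel.llmodel_prompt(..., ctypes.byref(err))
    # Here we only simulate the C side writing an error message.
    err.value = b"example failure"
    print('null' if err.value is None else err.value.decode())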
@@ -568,15 +549,15 @@ class LLModel:
 
             return _generator_callback
 
-        def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
-            self.prompt_model(prompt, prompt_template, callback, **kwargs)
+        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, callback, **kwargs)
             output_queue.put(Sentinel.TERMINATING_SYMBOL)
 
         # Kick off llmodel_prompt in separate thread so we can return generator
         # immediately
         thread = threading.Thread(
             target=run_llmodel_prompt,
-            args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
+            args=(prompt, _generator_callback_wrapper(callback)),
             kwargs=kwargs,
         )
         thread.start()
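prompt_model_streaming keeps the same worker-thread pattern as before, just without the prompt_template argument: a thread runs the blocking prompt call and pushes chunks into a queue, and a sentinel marks the end of the stream. A hedged, standalone sketch of that pattern (names invented, no gpt4all dependency):

    import threading
    from queue import Queue
    from typing import Iterator

    _SENTINEL = object()

    def stream_numbers() -> Iterator[str]:
        output_queue: Queue = Queue()

        def worker() -> None:
            for i in range(3):
                output_queue.put(f"token-{i} ")
            output_queue.put(_SENTINEL)   # terminating symbol, like Sentinel.TERMINATING_SYMBOL above

        threading.Thread(target=worker, daemon=True).start()
        while (item := output_queue.get()) is not _SENTINEL:
            yield item

    print("".join(stream_numbers()))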
@@ -631,5 +612,5 @@ class LLModel:
 
     # Empty prompt callback
     @staticmethod
-    def _prompt_callback(token_id: int) -> bool:
+    def _prompt_callback(token_ids: ctypes._Pointer[ctypes.c_int32], n_token_ids: int, cached: bool) -> bool:
         return True
@@ -4,37 +4,66 @@ Python only API for running all GPT4All models.
 from __future__ import annotations
 
 import hashlib
+import json
 import os
 import platform
 import re
 import sys
 import warnings
 from contextlib import contextmanager
+from datetime import datetime
 from pathlib import Path
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, NoReturn, Protocol, TypedDict, overload
 
+import jinja2
 import requests
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 from requests.exceptions import ChunkedEncodingError
 from tqdm import tqdm
 from urllib3.exceptions import IncompleteRead, ProtocolError
 
 from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
-                         LLModel, ResponseCallbackType, empty_response_callback)
+                         LLModel, ResponseCallbackType, _operator_call, empty_response_callback)
 
 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
 
-if sys.platform == 'darwin':
+if sys.platform == "darwin":
     import fcntl
 
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
 
-DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
+ConfigType: TypeAlias = "dict[str, Any]"
 
-ConfigType: TypeAlias = 'dict[str, Any]'
-MessageType: TypeAlias = 'dict[str, str]'
+# Environment setup adapted from HF transformers
+@_operator_call
+def _jinja_env() -> ImmutableSandboxedEnvironment:
+    def raise_exception(message: str) -> NoReturn:
+        raise jinja2.exceptions.TemplateError(message)
+
+    def tojson(obj: Any, indent: int | None = None) -> str:
+        return json.dumps(obj, ensure_ascii=False, indent=indent)
+
+    def strftime_now(fmt: str) -> str:
+        return datetime.now().strftime(fmt)
+
+    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+    env.filters["tojson"        ] = tojson
+    env.globals["raise_exception"] = raise_exception
+    env.globals["strftime_now"   ] = strftime_now
+    return env
+
+
+class MessageType(TypedDict):
+    role: str
+    content: str
+
+
+class ChatSession(NamedTuple):
+    template: jinja2.Template
+    history: list[MessageType]
+
 
 class Embed4All:
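The new _jinja_env and ChatSession types are what replace the old %1/%2 prompt templates: a model's chat template is compiled once and rendered against the message history. Below is a hedged sketch of what such a render looks like; the ChatML-style template string is an invented example, not one shipped by this commit or by any particular model.

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    EXAMPLE_TEMPLATE = (
        "{% for message in messages %}"
        "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    template = env.from_string(EXAMPLE_TEMPLATE)
    print(template.render(
        messages=[{"role": "system", "content": "Be brief."},
                  {"role": "user", "content": "Hello!"}],
        add_generation_prompt=True,
    ))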
@@ -54,7 +83,7 @@ class Embed4All:
             kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor.
         """
         if model_name is None:
-            model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
+            model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
         self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs)
 
     def __enter__(self) -> Self:
@@ -145,18 +174,18 @@ class Embed4All:
             dimensionality = -1
         else:
             if dimensionality <= 0:
-                raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}')
+                raise ValueError(f"Dimensionality must be None or a positive integer, got {dimensionality}")
             if dimensionality < self.MIN_DIMENSIONALITY:
                 warnings.warn(
-                    f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.'
-                    ' Performance may be degraded.'
+                    f"Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}."
+                    " Performance may be degraded."
                 )
         try:
             do_mean = {"mean": True, "truncate": False}[long_text_mode]
         except KeyError:
             raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
         result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
-        return result if return_dict else result['embeddings']
+        return result if return_dict else result["embeddings"]
 
 
 class GPT4All:
@@ -204,8 +233,7 @@ class GPT4All:
         """
 
         self.model_type = model_type
-        self._history: list[MessageType] | None = None
-        self._current_prompt_template: str = "{0}"
+        self._chat_session: ChatSession | None = None
 
         device_init = None
         if sys.platform == "darwin":
@@ -264,7 +292,13 @@ class GPT4All:
 
     @property
     def current_chat_session(self) -> list[MessageType] | None:
-        return None if self._history is None else list(self._history)
+        return None if self._chat_session is None else self._chat_session.history
 
+    @current_chat_session.setter
+    def current_chat_session(self, history: list[MessageType]) -> None:
+        if self._chat_session is None:
+            raise ValueError("current_chat_session may only be set when there is an active chat session")
+        self._chat_session.history[:] = history
+
     @staticmethod
     def list_models() -> list[ConfigType]:
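With the history now stored on a ChatSession, current_chat_session becomes a live view that can also be assigned while a session is open. A hedged usage sketch (the model file name is only an example; any installed GGUF chat model would do):

    from gpt4all import GPT4All

    model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # example model file
    with model.chat_session():
        model.generate("Hello!", max_tokens=16)
        print(model.current_chat_session)      # list of {"role": ..., "content": ...} dicts
        model.current_chat_session = [         # only valid inside an active session
            {"role": "system", "content": "Answer briefly."},
        ]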
@@ -276,7 +310,7 @@ class GPT4All:
         """
         resp = requests.get("https://gpt4all.io/models/models3.json")
         if resp.status_code != 200:
-            raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
+            raise ValueError(f"Request failed: HTTP {resp.status_code} {resp.reason}")
         return resp.json()
 
     @classmethod
@@ -306,15 +340,9 @@ class GPT4All:
         # get the config for the model
         config: ConfigType = {}
         if allow_download:
-            available_models = cls.list_models()
-
-            for m in available_models:
-                if model_filename == m["filename"]:
-                    tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
-                    # change to Python-style formatting
-                    m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
-                    config.update(m)
-                    break
+            models = cls.list_models()
+            if (model := next((m for m in models if m["filename"] == model_filename), None)) is not None:
+                config.update(model)
 
         # Validate download directory
         if model_path is None:
@@ -378,13 +406,13 @@ class GPT4All:
         headers = {}
         if offset:
             print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
-            headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+            headers["Range"] = f"bytes={offset}-"  # resume incomplete response
             headers["Accept-Encoding"] = "identity"  # Content-Encoding changes meaning of ranges
         response = requests.get(url, stream=True, headers=headers)
         if response.status_code not in (200, 206):
-            raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
-        if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
-            raise ValueError('Connection was interrupted and server does not support range requests')
+            raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}")
+        if offset and (response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", "")):
+            raise ValueError("Connection was interrupted and server does not support range requests")
         if (enc := response.headers.get("Content-Encoding")) is not None:
             raise ValueError(f"Expected identity Content-Encoding, got {enc}")
         return response
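The quoting changes above sit inside the resumable-download logic: re-request from the saved byte offset and force an identity Content-Encoding so the range applies to the raw file. A hedged, generic sketch of that request shape (the URL is a placeholder, not a real download endpoint):

    import requests

    offset = 1024  # bytes already on disk
    headers = {"Range": f"bytes={offset}-", "Accept-Encoding": "identity"}
    response = requests.get("https://example.com/model.gguf", stream=True, headers=headers)
    if response.status_code not in (200, 206):
        raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}")
    if response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", ""):
        raise ValueError("Server does not support resuming from an offset")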
@@ -483,19 +511,19 @@ class GPT4All:
 
     def generate(
         self,
-        prompt: str,
+        prompt          : str,
         *,
-        max_tokens: int = 200,
-        temp: float = 0.7,
-        top_k: int = 40,
-        top_p: float = 0.4,
-        min_p: float = 0.0,
-        repeat_penalty: float = 1.18,
-        repeat_last_n: int = 64,
-        n_batch: int = 8,
-        n_predict: int | None = None,
-        streaming: bool = False,
-        callback: ResponseCallbackType = empty_response_callback,
+        max_tokens      : int   = 200,
+        temp            : float = 0.7,
+        top_k           : int   = 40,
+        top_p           : float = 0.4,
+        min_p           : float = 0.0,
+        repeat_penalty  : float = 1.18,
+        repeat_last_n   : int   = 64,
+        n_batch         : int   = 8,
+        n_predict       : int | None = None,
+        streaming       : bool  = False,
+        callback        : ResponseCallbackType = empty_response_callback,
     ) -> Any:
         """
         Generate outputs from any GPT4All model.
@@ -520,122 +548,94 @@ class GPT4All:
 
         # Preparing the model request
         generate_kwargs: dict[str, Any] = dict(
-            temp=temp,
-            top_k=top_k,
-            top_p=top_p,
-            min_p=min_p,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            n_batch=n_batch,
-            n_predict=n_predict if n_predict is not None else max_tokens,
+            temp            = temp,
+            top_k           = top_k,
+            top_p           = top_p,
+            min_p           = min_p,
+            repeat_penalty  = repeat_penalty,
+            repeat_last_n   = repeat_last_n,
+            n_batch         = n_batch,
+            n_predict       = n_predict if n_predict is not None else max_tokens,
         )
 
-        if self._history is not None:
-            # check if there is only one message, i.e. system prompt:
-            reset = len(self._history) == 1
-            self._history.append({"role": "user", "content": prompt})
-
-            fct_func = self._format_chat_prompt_template.__func__  # type: ignore[attr-defined]
-            if fct_func is GPT4All._format_chat_prompt_template:
-                if reset:
-                    # ingest system prompt
-                    # use "%1%2" and not "%1" to avoid implicit whitespace
-                    self.model.prompt_model(self._history[0]["content"], "%1%2",
-                                            empty_response_callback,
-                                            n_batch=n_batch, n_predict=0, reset_context=True, special=True)
-                prompt_template = self._current_prompt_template.format("%1", "%2")
-            else:
-                warnings.warn(
-                    "_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
-                    DeprecationWarning,
-                )
-                # special tokens won't be processed
-                prompt = self._format_chat_prompt_template(
-                    self._history[-1:],
-                    self._history[0]["content"] if reset else "",
-                )
-                prompt_template = "%1"
-            generate_kwargs["reset_context"] = reset
-        else:
-            prompt_template = "%1"
-            generate_kwargs["reset_context"] = True
-
         # Prepare the callback, process the model response
-        output_collector: list[MessageType]
-        output_collector = [
-            {"content": ""}
-        ]  # placeholder for the self._history if chat session is not activated
+        full_response = ""
 
-        if self._history is not None:
-            self._history.append({"role": "assistant", "content": ""})
-            output_collector = self._history
+        def _callback_wrapper(token_id: int, response: str) -> bool:
+            nonlocal full_response
+            full_response += response
+            return callback(token_id, response)
 
-        def _callback_wrapper(
-            callback: ResponseCallbackType,
-            output_collector: list[MessageType],
-        ) -> ResponseCallbackType:
-            def _callback(token_id: int, response: str) -> bool:
-                nonlocal callback, output_collector
+        last_msg_rendered = prompt
+        if self._chat_session is not None:
+            session = self._chat_session
+            def render(messages: list[MessageType]) -> str:
+                return session.template.render(
+                    messages=messages,
+                    add_generation_prompt=True,
+                    **self.model.special_tokens_map,
+                )
+            session.history.append(MessageType(role="user", content=prompt))
+            prompt = render(session.history)
+            if len(session.history) > 1:
+                last_msg_rendered = render(session.history[-1:])
 
-                output_collector[-1]["content"] += response
-
-                return callback(token_id, response)
-
-            return _callback
+        # Check request length
+        last_msg_len = self.model.count_prompt_tokens(last_msg_rendered)
+        if last_msg_len > (limit := self.model.n_ctx - 4):
+            raise ValueError(f"Your message was too long and could not be processed ({last_msg_len} > {limit}).")
 
         # Send the request to the model
         if streaming:
-            return self.model.prompt_model_streaming(
-                prompt,
-                prompt_template,
-                _callback_wrapper(callback, output_collector),
-                **generate_kwargs,
-            )
+            def stream() -> Iterator[str]:
+                yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
+                if self._chat_session is not None:
+                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            return stream()
 
-        self.model.prompt_model(
-            prompt,
-            prompt_template,
-            _callback_wrapper(callback, output_collector),
-            **generate_kwargs,
-        )
-
-        return output_collector[-1]["content"]
+        self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
+        if self._chat_session is not None:
+            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+        return full_response
 
     @contextmanager
     def chat_session(
         self,
-        system_prompt: str | None = None,
-        prompt_template: str | None = None,
+        system_message: str | Literal[False] | None = None,
+        chat_template: str | None = None,
     ):
         """
         Context manager to hold an inference optimized chat session with a GPT4All model.
 
         Args:
-            system_prompt: An initial instruction for the model.
-            prompt_template: Template for the prompts with {0} being replaced by the user message.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
         """
 
-        if system_prompt is None:
-            system_prompt = self.config.get("systemPrompt", "")
+        if system_message is None:
+            system_message = self.config.get("systemMessage", False)
 
-        if prompt_template is None:
-            if (tmpl := self.config.get("promptTemplate")) is None:
-                warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template "
-                              "is deprecated. Defaulting to Alpaca.", DeprecationWarning)
-                tmpl = DEFAULT_PROMPT_TEMPLATE
-            prompt_template = tmpl
+        if chat_template is None:
+            if "name" not in self.config:
+                raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
+            if "chatTemplate" not in self.config:
+                raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
+                                          "currently implemented. Please pass a template to chat_session() directly.")
+            if (tmpl := self.config["chatTemplate"]) is None:
+                raise ValueError(f"The model {self.config['name']!r} does not support chat.")
+            chat_template = tmpl
 
-        if re.search(r"%1(?![0-9])", prompt_template):
-            raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
-                             "placeholder, please use '{0}' instead.")
-
-        self._history = [{"role": "system", "content": system_prompt}]
-        self._current_prompt_template = prompt_template
+        history = []
+        if system_message is not False:
+            history.append(MessageType(role="system", content=system_message))
+        self._chat_session = ChatSession(
+            template=_jinja_env.from_string(chat_template),
+            history=history,
+        )
         try:
             yield self
         finally:
-            self._history = None
-            self._current_prompt_template = "{0}"
+            self._chat_session = None
 
     @staticmethod
     def list_gpus() -> list[str]:
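Taken together, the generate() and chat_session() changes above define the new high-level flow from user code. A hedged usage sketch follows; the model file name is an example, and chat_template only needs to be passed for sideloaded models without a known config:

    from gpt4all import GPT4All

    model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # example model file

    with model.chat_session(system_message="You are a concise assistant."):
        # Non-streaming: returns the full response string.
        print(model.generate("Name three primes.", max_tokens=64))

        # Streaming: returns an iterator of response fragments.
        for fragment in model.generate("And one more?", max_tokens=32, streaming=True):
            print(fragment, end="")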
@@ -647,43 +647,6 @@ class GPT4All:
         """
         return LLModel.list_gpus()
 
-    def _format_chat_prompt_template(
-        self,
-        messages: list[MessageType],
-        default_prompt_header: str = "",
-        default_prompt_footer: str = "",
-    ) -> str:
-        """
-        Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
-
-        Warning:
-            This function was deprecated in version 2.3.0, and will be removed in a future release.
-
-        Args:
-            messages: List of dictionaries. Each dictionary should have a "role" key
-                with value of "system", "assistant", or "user" and a "content" key with a
-                string value. Messages are organized such that "system" messages are at top of prompt,
-                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
-                "Response: {content}".
-
-        Returns:
-            Formatted prompt.
-        """
-
-        full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
-
-        for message in messages:
-            if message["role"] == "user":
-                user_message = self._current_prompt_template.format(message["content"])
-                full_prompt += user_message
-            if message["role"] == "assistant":
-                assistant_message = message["content"] + "\n"
-                full_prompt += assistant_message
-
-        full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
-
-        return full_prompt
-
 
 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
@@ -696,7 +659,7 @@ class _HasFileno(Protocol):
 
 
 def _fsync(fd: int | _HasFileno) -> None:
-    if sys.platform == 'darwin':
+    if sys.platform == "darwin":
         # Apple's fsync does not flush the drive write cache
         try:
             fcntl.fcntl(fd, fcntl.F_FULLFSYNC)