Remove binary state from high-level API and use Jinja templates (#3147)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Adam Treat <treat.adam@gmail.com>
Co-authored-by: Adam Treat <treat.adam@gmail.com>
Jared Van Bortel
2024-11-25 10:04:17 -05:00
committed by GitHub
parent 3320094d29
commit 225bf6be93
54 changed files with 3423 additions and 2224 deletions
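
For orientation, a minimal sketch of the high-level usage this commit moves to: the conversation is kept as a plain message list and rendered through a Jinja chat template instead of binary prompt-context state. The model filename below is a placeholder, not taken from this diff.

from gpt4all import GPT4All

model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # placeholder model file
with model.chat_session(system_message="You are a helpful assistant.") as chat:
    # History lives in chat.current_chat_session and is re-rendered on each call.
    print(chat.generate("Name three colors.", max_tokens=64))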


@@ -9,7 +9,7 @@ import textwrap
import threading
from enum import Enum
from queue import Queue
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Literal, NoReturn, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
@@ -23,7 +23,9 @@ else:
from typing import TypedDict
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import ParamSpec, TypeAlias
T = TypeVar("T")
P = ParamSpec("P")
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
@@ -31,7 +33,7 @@ cuda_found: bool = False
# TODO(jared): use operator.call after we drop python 3.10 support
def _operator_call(obj, /, *args, **kwargs):
def _operator_call(obj: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
return obj(*args, **kwargs)
@@ -116,16 +118,15 @@ llmodel = load_llmodel_library()
class LLModelPromptContext(ctypes.Structure):
_fields_ = [
("n_past", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
("min_p", ctypes.c_float),
("temp", ctypes.c_float),
("n_batch", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
("min_p", ctypes.c_float),
("temp", ctypes.c_float),
("n_batch", ctypes.c_int32),
("repeat_penalty", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float),
]
@@ -157,23 +158,21 @@ llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)
llmodel.llmodel_prompt.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
ctypes.c_char_p,
PromptCallback,
ResponseCallback,
ctypes.c_bool,
ctypes.POINTER(LLModelPromptContext),
ctypes.c_bool,
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_char_p),
]
llmodel.llmodel_prompt.restype = None
llmodel.llmodel_prompt.restype = ctypes.c_bool
llmodel.llmodel_embed.argtypes = [
ctypes.c_void_p,
@@ -222,6 +221,12 @@ llmodel.llmodel_model_backend_name.restype = ctypes.c_char_p
llmodel.llmodel_model_gpu_device_name.argtypes = [ctypes.c_void_p]
llmodel.llmodel_model_gpu_device_name.restype = ctypes.c_char_p
llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
llmodel.llmodel_model_foreach_special_token.restype = None
ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool]
EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
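
Illustrative only: Python callables shaped to match the new callback types, re-declared here so the sketch stands alone. The helper names are assumptions, not part of this diff.

import ctypes

PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t, ctypes.c_bool)
SpecialTokenCallback = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p)

def on_prompt_tokens(token_ids, n_token_ids, cached):
    ids = token_ids[:n_token_ids]  # copy the batch of token ids out of the C buffer
    return True                    # returning False cancels prompt processing

def on_special_token(name, value):
    print(name.decode(), "=", value.decode())  # both arguments arrive as bytes

prompt_cb = PromptCallback(on_prompt_tokens)
special_cb = SpecialTokenCallback(on_special_token)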
@@ -266,7 +271,6 @@ class LLModel:
self.model_path = model_path.encode()
self.n_ctx = n_ctx
self.ngl = ngl
self.context: LLModelPromptContext | None = None
self.buffer = bytearray()
self.buff_expecting_cont_bytes: int = 0
@@ -286,6 +290,10 @@ class LLModel:
raise RuntimeError(f"Unable to instantiate model: {errmsg}")
self.model: ctypes.c_void_p | None = model
self.special_tokens_map: dict[str, str] = {}
llmodel.llmodel_model_foreach_special_token(
self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
)
def __del__(self, llmodel=llmodel):
if hasattr(self, 'model'):
@@ -312,6 +320,19 @@ class LLModel:
dev = llmodel.llmodel_model_gpu_device_name(self.model)
return None if dev is None else dev.decode()
def count_prompt_tokens(self, prompt: str) -> int:
if self.model is None:
self._raise_closed()
err = ctypes.c_char_p()
n_tok = llmodel.llmodel_count_prompt_tokens(self.model, prompt, ctypes.byref(err))
if n_tok < 0:
s = err.value
errmsg = 'null' if s is None else s.decode()
raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
return n_tok
@staticmethod
def list_gpus(mem_required: int = 0) -> list[str]:
"""
@@ -375,48 +396,6 @@ class LLModel:
raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model)
def _set_context(
self,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
):
if self.context is None:
context = LLModelPromptContext(
n_past=0,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
)
self.context = context
else:
context = self.context
if reset_context:
self.context.n_past = 0
self.context.n_predict = n_predict
self.context.top_k = top_k
self.context.top_p = top_p
self.context.min_p = min_p
self.context.temp = temp
self.context.n_batch = n_batch
self.context.repeat_penalty = repeat_penalty
self.context.repeat_last_n = repeat_last_n
self.context.context_erase = context_erase
@overload
def generate_embeddings(
self, text: str, prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
@@ -486,20 +465,18 @@ class LLModel:
def prompt_model(
self,
prompt: str,
prompt_template: str,
callback: ResponseCallbackType,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
min_p: float = 0.0,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
special: bool = False,
prompt : str,
callback : ResponseCallbackType,
n_predict : int = 4096,
top_k : int = 40,
top_p : float = 0.9,
min_p : float = 0.0,
temp : float = 0.1,
n_batch : int = 8,
repeat_penalty : float = 1.2,
repeat_last_n : int = 10,
context_erase : float = 0.75,
reset_context : bool = False,
):
"""
Generate response from model from a prompt.
@@ -522,34 +499,38 @@ class LLModel:
self.buffer.clear()
self.buff_expecting_cont_bytes = 0
self._set_context(
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
reset_context=reset_context,
context = LLModelPromptContext(
n_predict = n_predict,
top_k = top_k,
top_p = top_p,
min_p = min_p,
temp = temp,
n_batch = n_batch,
repeat_penalty = repeat_penalty,
repeat_last_n = repeat_last_n,
context_erase = context_erase,
)
llmodel.llmodel_prompt(
error_msg: bytes | None = None
def error_callback(msg: bytes) -> None:
nonlocal error_msg
error_msg = msg
err = ctypes.c_char_p()
if not llmodel.llmodel_prompt(
self.model,
ctypes.c_char_p(prompt.encode()),
ctypes.c_char_p(prompt_template.encode()),
PromptCallback(self._prompt_callback),
ResponseCallback(self._callback_decoder(callback)),
True,
self.context,
special,
ctypes.c_char_p(),
)
context,
ctypes.byref(err),
):
s = err.value
raise RuntimeError(f"prompt error: {'null' if s is None else s.decode()}")
def prompt_model_streaming(
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs: Any,
) -> Iterator[str]:
if self.model is None:
self._raise_closed()
@@ -568,15 +549,15 @@ class LLModel:
return _generator_callback
def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, prompt_template, callback, **kwargs)
def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs)
output_queue.put(Sentinel.TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator
# immediately
thread = threading.Thread(
target=run_llmodel_prompt,
args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
args=(prompt, _generator_callback_wrapper(callback)),
kwargs=kwargs,
)
thread.start()
@@ -631,5 +612,5 @@ class LLModel:
# Empty prompt callback
@staticmethod
def _prompt_callback(token_id: int) -> bool:
def _prompt_callback(token_ids: ctypes._Pointer[ctypes.c_int32], n_token_ids: int, cached: bool) -> bool:
return True


@@ -4,37 +4,66 @@ Python only API for running all GPT4All models.
from __future__ import annotations
import hashlib
import json
import os
import platform
import re
import sys
import warnings
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, NamedTuple, NoReturn, Protocol, TypedDict, overload
import jinja2
import requests
from jinja2.sandbox import ImmutableSandboxedEnvironment
from requests.exceptions import ChunkedEncodingError
from tqdm import tqdm
from urllib3.exceptions import IncompleteRead, ProtocolError
from ._pyllmodel import (CancellationError as CancellationError, EmbCancelCallbackType, EmbedResult as EmbedResult,
LLModel, ResponseCallbackType, empty_response_callback)
LLModel, ResponseCallbackType, _operator_call, empty_response_callback)
if TYPE_CHECKING:
from typing_extensions import Self, TypeAlias
if sys.platform == 'darwin':
if sys.platform == "darwin":
import fcntl
# TODO: move to config
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
ConfigType: TypeAlias = "dict[str, Any]"
ConfigType: TypeAlias = 'dict[str, Any]'
MessageType: TypeAlias = 'dict[str, str]'
# Environment setup adapted from HF transformers
@_operator_call
def _jinja_env() -> ImmutableSandboxedEnvironment:
def raise_exception(message: str) -> NoReturn:
raise jinja2.exceptions.TemplateError(message)
def tojson(obj: Any, indent: int | None = None) -> str:
return json.dumps(obj, ensure_ascii=False, indent=indent)
def strftime_now(fmt: str) -> str:
return datetime.now().strftime(fmt)
env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.filters["tojson" ] = tojson
env.globals["raise_exception"] = raise_exception
env.globals["strftime_now" ] = strftime_now
return env
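
As a standalone illustration of how templates behave under an environment configured like the one above, the following renders an assumed (not model-specific) chat template with the same messages/add_generation_prompt contract used later in generate():

from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
tmpl = env.from_string(
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>{{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}"
)
print(tmpl.render(messages=[{"role": "user", "content": "Hi"}], add_generation_prompt=True))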
class MessageType(TypedDict):
role: str
content: str
class ChatSession(NamedTuple):
template: jinja2.Template
history: list[MessageType]
class Embed4All:
@@ -54,7 +83,7 @@ class Embed4All:
kwargs: Remaining keyword arguments are passed to the `GPT4All` constructor.
"""
if model_name is None:
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
self.gpt4all = GPT4All(model_name, n_threads=n_threads, device=device, **kwargs)
def __enter__(self) -> Self:
@@ -145,18 +174,18 @@ class Embed4All:
dimensionality = -1
else:
if dimensionality <= 0:
raise ValueError(f'Dimensionality must be None or a positive integer, got {dimensionality}')
raise ValueError(f"Dimensionality must be None or a positive integer, got {dimensionality}")
if dimensionality < self.MIN_DIMENSIONALITY:
warnings.warn(
f'Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}.'
' Performance may be degraded.'
f"Dimensionality {dimensionality} is less than the suggested minimum of {self.MIN_DIMENSIONALITY}."
" Performance may be degraded."
)
try:
do_mean = {"mean": True, "truncate": False}[long_text_mode]
except KeyError:
raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas, cancel_cb)
return result if return_dict else result['embeddings']
return result if return_dict else result["embeddings"]
class GPT4All:
@@ -204,8 +233,7 @@ class GPT4All:
"""
self.model_type = model_type
self._history: list[MessageType] | None = None
self._current_prompt_template: str = "{0}"
self._chat_session: ChatSession | None = None
device_init = None
if sys.platform == "darwin":
@@ -264,7 +292,13 @@ class GPT4All:
@property
def current_chat_session(self) -> list[MessageType] | None:
return None if self._history is None else list(self._history)
return None if self._chat_session is None else self._chat_session.history
@current_chat_session.setter
def current_chat_session(self, history: list[MessageType]) -> None:
if self._chat_session is None:
raise ValueError("current_chat_session may only be set when there is an active chat session")
self._chat_session.history[:] = history
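
An assumed usage sketch of the new setter (the model filename is a placeholder): a previously saved history can be replayed into an active session, since the setter replaces the session's message list in place.

from gpt4all import GPT4All

saved = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi!"},
]
model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # placeholder model file
with model.chat_session() as chat:
    chat.current_chat_session = saved  # only valid inside an active chat session
    print(chat.generate("Summarize our conversation.", max_tokens=48))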
@staticmethod
def list_models() -> list[ConfigType]:
@@ -276,7 +310,7 @@ class GPT4All:
"""
resp = requests.get("https://gpt4all.io/models/models3.json")
if resp.status_code != 200:
raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
raise ValueError(f"Request failed: HTTP {resp.status_code} {resp.reason}")
return resp.json()
@classmethod
@@ -306,15 +340,9 @@ class GPT4All:
# get the config for the model
config: ConfigType = {}
if allow_download:
available_models = cls.list_models()
for m in available_models:
if model_filename == m["filename"]:
tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
# change to Python-style formatting
m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
config.update(m)
break
models = cls.list_models()
if (model := next((m for m in models if m["filename"] == model_filename), None)) is not None:
config.update(model)
# Validate download directory
if model_path is None:
@@ -378,13 +406,13 @@ class GPT4All:
headers = {}
if offset:
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
headers['Range'] = f'bytes={offset}-' # resume incomplete response
headers["Range"] = f"bytes={offset}-" # resume incomplete response
headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges
response = requests.get(url, stream=True, headers=headers)
if response.status_code not in (200, 206):
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
raise ValueError('Connection was interrupted and server does not support range requests')
raise ValueError(f"Request failed: HTTP {response.status_code} {response.reason}")
if offset and (response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", "")):
raise ValueError("Connection was interrupted and server does not support range requests")
if (enc := response.headers.get("Content-Encoding")) is not None:
raise ValueError(f"Expected identity Content-Encoding, got {enc}")
return response
@@ -483,19 +511,19 @@ class GPT4All:
def generate(
self,
prompt: str,
prompt : str,
*,
max_tokens: int = 200,
temp: float = 0.7,
top_k: int = 40,
top_p: float = 0.4,
min_p: float = 0.0,
repeat_penalty: float = 1.18,
repeat_last_n: int = 64,
n_batch: int = 8,
n_predict: int | None = None,
streaming: bool = False,
callback: ResponseCallbackType = empty_response_callback,
max_tokens : int = 200,
temp : float = 0.7,
top_k : int = 40,
top_p : float = 0.4,
min_p : float = 0.0,
repeat_penalty : float = 1.18,
repeat_last_n : int = 64,
n_batch : int = 8,
n_predict : int | None = None,
streaming : bool = False,
callback : ResponseCallbackType = empty_response_callback,
) -> Any:
"""
Generate outputs from any GPT4All model.
@@ -520,122 +548,94 @@ class GPT4All:
# Preparing the model request
generate_kwargs: dict[str, Any] = dict(
temp=temp,
top_k=top_k,
top_p=top_p,
min_p=min_p,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
n_batch=n_batch,
n_predict=n_predict if n_predict is not None else max_tokens,
temp = temp,
top_k = top_k,
top_p = top_p,
min_p = min_p,
repeat_penalty = repeat_penalty,
repeat_last_n = repeat_last_n,
n_batch = n_batch,
n_predict = n_predict if n_predict is not None else max_tokens,
)
if self._history is not None:
# check if there is only one message, i.e. system prompt:
reset = len(self._history) == 1
self._history.append({"role": "user", "content": prompt})
fct_func = self._format_chat_prompt_template.__func__ # type: ignore[attr-defined]
if fct_func is GPT4All._format_chat_prompt_template:
if reset:
# ingest system prompt
# use "%1%2" and not "%1" to avoid implicit whitespace
self.model.prompt_model(self._history[0]["content"], "%1%2",
empty_response_callback,
n_batch=n_batch, n_predict=0, reset_context=True, special=True)
prompt_template = self._current_prompt_template.format("%1", "%2")
else:
warnings.warn(
"_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
DeprecationWarning,
)
# special tokens won't be processed
prompt = self._format_chat_prompt_template(
self._history[-1:],
self._history[0]["content"] if reset else "",
)
prompt_template = "%1"
generate_kwargs["reset_context"] = reset
else:
prompt_template = "%1"
generate_kwargs["reset_context"] = True
# Prepare the callback, process the model response
output_collector: list[MessageType]
output_collector = [
{"content": ""}
] # placeholder for the self._history if chat session is not activated
full_response = ""
if self._history is not None:
self._history.append({"role": "assistant", "content": ""})
output_collector = self._history
def _callback_wrapper(token_id: int, response: str) -> bool:
nonlocal full_response
full_response += response
return callback(token_id, response)
def _callback_wrapper(
callback: ResponseCallbackType,
output_collector: list[MessageType],
) -> ResponseCallbackType:
def _callback(token_id: int, response: str) -> bool:
nonlocal callback, output_collector
last_msg_rendered = prompt
if self._chat_session is not None:
session = self._chat_session
def render(messages: list[MessageType]) -> str:
return session.template.render(
messages=messages,
add_generation_prompt=True,
**self.model.special_tokens_map,
)
session.history.append(MessageType(role="user", content=prompt))
prompt = render(session.history)
if len(session.history) > 1:
last_msg_rendered = render(session.history[-1:])
output_collector[-1]["content"] += response
return callback(token_id, response)
return _callback
# Check request length
last_msg_len = self.model.count_prompt_tokens(last_msg_rendered)
if last_msg_len > (limit := self.model.n_ctx - 4):
raise ValueError(f"Your message was too long and could not be processed ({last_msg_len} > {limit}).")
# Send the request to the model
if streaming:
return self.model.prompt_model_streaming(
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)
def stream() -> Iterator[str]:
yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
if self._chat_session is not None:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
return stream()
self.model.prompt_model(
prompt,
prompt_template,
_callback_wrapper(callback, output_collector),
**generate_kwargs,
)
return output_collector[-1]["content"]
self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
if self._chat_session is not None:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
return full_response
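
For illustration, the streaming path now returns a generator that appends the assistant turn to the session history once it is exhausted; a hedged usage sketch (placeholder model name):

from gpt4all import GPT4All

model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # placeholder model file
with model.chat_session() as chat:
    for token in chat.generate("Tell me a short joke.", max_tokens=64, streaming=True):
        print(token, end="", flush=True)
    print()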
@contextmanager
def chat_session(
self,
system_prompt: str | None = None,
prompt_template: str | None = None,
system_message: str | Literal[False] | None = None,
chat_template: str | None = None,
):
"""
Context manager to hold an inference optimized chat session with a GPT4All model.
Args:
system_prompt: An initial instruction for the model.
prompt_template: Template for the prompts with {0} being replaced by the user message.
system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
"""
if system_prompt is None:
system_prompt = self.config.get("systemPrompt", "")
if system_message is None:
system_message = self.config.get("systemMessage", False)
if prompt_template is None:
if (tmpl := self.config.get("promptTemplate")) is None:
warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template "
"is deprecated. Defaulting to Alpaca.", DeprecationWarning)
tmpl = DEFAULT_PROMPT_TEMPLATE
prompt_template = tmpl
if chat_template is None:
if "name" not in self.config:
raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
if "chatTemplate" not in self.config:
raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
"currently implemented. Please pass a template to chat_session() directly.")
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
if re.search(r"%1(?![0-9])", prompt_template):
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
"placeholder, please use '{0}' instead.")
self._history = [{"role": "system", "content": system_prompt}]
self._current_prompt_template = prompt_template
history = []
if system_message is not False:
history.append(MessageType(role="system", content=system_message))
self._chat_session = ChatSession(
template=_jinja_env.from_string(chat_template),
history=history,
)
try:
yield self
finally:
self._history = None
self._current_prompt_template = "{0}"
self._chat_session = None
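
A hedged example of the new chat_template parameter for a sideloaded model, where chat_session() has no built-in template to fall back on; the filename and template text are assumptions for illustration.

from gpt4all import GPT4All

template = (
    "{% for message in messages %}"
    "### {{ message['role'] }}:\n{{ message['content'] }}\n\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}### assistant:\n{% endif %}"
)
model = GPT4All("my-local-model.gguf", allow_download=False)  # hypothetical local file
with model.chat_session(system_message="You are terse.", chat_template=template) as chat:
    print(chat.generate("Hello!", max_tokens=32))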
@staticmethod
def list_gpus() -> list[str]:
@@ -647,43 +647,6 @@ class GPT4All:
"""
return LLModel.list_gpus()
def _format_chat_prompt_template(
self,
messages: list[MessageType],
default_prompt_header: str = "",
default_prompt_footer: str = "",
) -> str:
"""
Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
Warning:
This function was deprecated in version 2.3.0, and will be removed in a future release.
Args:
messages: List of dictionaries. Each dictionary should have a "role" key
with value of "system", "assistant", or "user" and a "content" key with a
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Response: {content}".
Returns:
Formatted prompt.
"""
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
for message in messages:
if message["role"] == "user":
user_message = self._current_prompt_template.format(message["content"])
full_prompt += user_message
if message["role"] == "assistant":
assistant_message = message["content"] + "\n"
full_prompt += assistant_message
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
return full_prompt
def append_extension_if_missing(model_name):
if not model_name.endswith((".bin", ".gguf")):
@@ -696,7 +659,7 @@ class _HasFileno(Protocol):
def _fsync(fd: int | _HasFileno) -> None:
if sys.platform == 'darwin':
if sys.platform == "darwin":
# Apple's fsync does not flush the drive write cache
try:
fcntl.fcntl(fd, fcntl.F_FULLFSYNC)