Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-08-13)

Merge 260ad4b163 into cd70db29ed
This commit is contained in: 9fb7aedbc6

@@ -312,6 +312,8 @@ int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, const char **error);
 
 void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
 
+const char *llmodel_model_chat_template(const char *model_path, const char **error);
+
 #ifdef __cplusplus
 }
 #endif

@@ -34,11 +34,11 @@ llmodel_model llmodel_model_create(const char *model_path)
     return fres;
 }
 
-static void llmodel_set_error(const char **errptr, const char *message)
+static void llmodel_set_error(const char **errptr, std::string message)
 {
     thread_local static std::string last_error_message;
     if (errptr) {
-        last_error_message = message;
+        last_error_message = std::move(message);
         *errptr = last_error_message.c_str();
     }
 }

@@ -318,3 +318,15 @@ void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback)
     for (auto &[name, token] : wrapper->llModel->specialTokens())
         callback(name.c_str(), token.c_str());
 }
+
+const char *llmodel_model_chat_template(const char *model_path, const char **error)
+{
+    static std::string s_chatTemplate;
+    auto res = LLModel::Implementation::chatTemplate(model_path);
+    if (res) {
+        s_chatTemplate = *res;
+        return s_chatTemplate.c_str();
+    }
+    llmodel_set_error(error, std::move(res.error()));
+    return nullptr;
+}

@@ -9,7 +9,7 @@ import textwrap
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources

@@ -227,6 +227,9 @@ llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
 llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
 llmodel.llmodel_model_foreach_special_token.restype = None
 
+llmodel.llmodel_model_chat_template.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
+llmodel.llmodel_model_chat_template.restype = ctypes.c_char_p
+
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
 EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
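
The two bindings above are all the ctypes setup the new entry point needs. Below is a minimal sketch of calling it directly, assuming `llmodel` is the CDLL handle configured in this module and a hypothetical model path; the regular API goes through the `builtin_chat_template` property added further down.

import ctypes

model_path = b"/path/to/model.gguf"  # hypothetical path; c_char_p expects bytes
err = ctypes.c_char_p()
tmpl = llmodel.llmodel_model_chat_template(model_path, ctypes.byref(err))
if tmpl is None:
    # on failure the C side stores its message in a thread-local string and returns NULL
    raise ValueError((err.value or b"unknown error").decode())
print(tmpl.decode())  # restype=c_char_p converts the returned char* to bytes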

@@ -290,10 +293,7 @@ class LLModel:
 
             raise RuntimeError(f"Unable to instantiate model: {errmsg}")
         self.model: ctypes.c_void_p | None = model
-        self.special_tokens_map: dict[str, str] = {}
-        llmodel.llmodel_model_foreach_special_token(
-            self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
-        )
+        self._special_tokens_map: dict[str, str] | None = None
 
     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):

@@ -320,6 +320,26 @@ class LLModel:
         dev = llmodel.llmodel_model_gpu_device_name(self.model)
         return None if dev is None else dev.decode()
 
+    @property
+    def builtin_chat_template(self) -> str:
+        err = ctypes.c_char_p()
+        tmpl = llmodel.llmodel_model_chat_template(self.model_path, ctypes.byref(err))
+        if tmpl is not None:
+            return tmpl.decode()
+        s = err.value
+        raise ValueError('Failed to get chat template', 'null' if s is None else s.decode())
+
+    @property
+    def special_tokens_map(self) -> dict[str, str]:
+        if self.model is None:
+            self._raise_closed()
+        if self._special_tokens_map is None:
+            tokens: dict[str, str] = {}
+            cb = SpecialTokenCallback(lambda n, t: tokens.__setitem__(n.decode(), t.decode()))
+            llmodel.llmodel_model_foreach_special_token(self.model, cb)
+            self._special_tokens_map = tokens
+        return self._special_tokens_map
+
     def count_prompt_tokens(self, prompt: str) -> int:
         if self.model is None:
             self._raise_closed()
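
A hedged usage sketch of the two new LLModel properties. The constructor arguments follow the call used elsewhere in this diff (`LLModel(self.config["path"], n_ctx, ngl, backend)`); the import path, file path, and backend value here are assumptions.

from gpt4all._pyllmodel import LLModel  # module location assumed

model = LLModel("/path/to/model.gguf", 2048, 0, "cpu")  # path, n_ctx, ngl, backend (values hypothetical)
print(model.builtin_chat_template)  # raises ValueError with the C-side message if the GGUF has no template
print(model.special_tokens_map)     # built lazily on first access, then cached in _special_tokens_map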

@@ -331,8 +351,6 @@ class LLModel:
             raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
         return n_tok
 
-    llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-
     @staticmethod
     def list_gpus(mem_required: int = 0) -> list[str]:
         """

@@ -37,9 +37,9 @@ DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
 
 ConfigType: TypeAlias = "dict[str, Any]"
 
-# Environment setup adapted from HF transformers
 @_operator_call
 def _jinja_env() -> ImmutableSandboxedEnvironment:
+    # Environment setup adapted from HF transformers
     def raise_exception(message: str) -> NoReturn:
         raise jinja2.exceptions.TemplateError(message)
 

@@ -56,14 +56,17 @@ def _jinja_env() -> ImmutableSandboxedEnvironment:
     return env
 
 
-class MessageType(TypedDict):
+class Message(TypedDict):
+    """A message in a chat with a GPT4All model."""
+
     role: str
     content: str
 
 
-class ChatSession(NamedTuple):
+class _ChatSession(NamedTuple):
     template: jinja2.Template
-    history: list[MessageType]
+    template_source: str
+    history: list[Message]
 
 
 class Embed4All:
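
The rename keeps only `Message` public; `_ChatSession` is an implementation detail. A short sketch of the message shape a session history now stores (the text is illustrative):

history: list[Message] = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="Hello!"),
]
# _ChatSession additionally records template_source, the raw Jinja text behind the
# compiled template, which backs the new current_chat_template property below.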

@@ -193,6 +196,17 @@ class GPT4All:
     Python class that handles instantiation, downloading, generation and chat with GPT4All models.
     """
 
+    RE_LEGACY_SYSPROMPT = re.compile(
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
+        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        re.MULTILINE,
+    )
+
+    RE_JINJA_LIKE = re.compile(
+        r"\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}",
+        re.DOTALL,
+    )
+
     def __init__(
         self,
         model_name: str,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self._chat_session: ChatSession | None = None
|
self._chat_session: _ChatSession | None = None
|
||||||
|
|
||||||
device_init = None
|
device_init = None
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||

@@ -260,6 +274,7 @@ class GPT4All:
 
         # Retrieve model and download if allowed
         self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
+        self._was_allow_download = allow_download
         self.model = LLModel(self.config["path"], n_ctx, ngl, backend)
         if device_init is not None:
             self.model.init_gpu(device_init)

@@ -291,15 +306,20 @@ class GPT4All:
         return self.model.device
 
     @property
-    def current_chat_session(self) -> list[MessageType] | None:
+    def current_chat_session(self) -> list[Message] | None:
+        """The message history of the current chat session."""
         return None if self._chat_session is None else self._chat_session.history
 
     @current_chat_session.setter
-    def current_chat_session(self, history: list[MessageType]) -> None:
+    def current_chat_session(self, history: list[Message]) -> None:
         if self._chat_session is None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
 
+    @property
+    def current_chat_template(self) -> str | None:
+        return None if self._chat_session is None else self._chat_session.template_source
+
     @staticmethod
     def list_models() -> list[ConfigType]:
         """
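
A brief sketch of reading session state through the renamed and added properties; the model name is hypothetical.

model = GPT4All("some-model.gguf")          # hypothetical model name
print(model.current_chat_template)          # None while no session is active
with model.chat_session():
    model.generate("Hello!", max_tokens=16)
    print(model.current_chat_session)       # list[Message], including the assistant reply
    print(model.current_chat_template)      # Jinja source backing the active session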

@@ -569,13 +589,13 @@ class GPT4All:
         last_msg_rendered = prompt
         if self._chat_session is not None:
             session = self._chat_session
-            def render(messages: list[MessageType]) -> str:
+            def render(messages: list[Message]) -> str:
                 return session.template.render(
                     messages=messages,
                     add_generation_prompt=True,
                     **self.model.special_tokens_map,
                 )
-            session.history.append(MessageType(role="user", content=prompt))
+            session.history.append(Message(role="user", content=prompt))
             prompt = render(session.history)
             if len(session.history) > 1:
                 last_msg_rendered = render(session.history[-1:])
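
render() feeds the running history plus the model's special tokens into the session's Jinja template. A minimal sketch of that rendering step, using the module's sandboxed environment and a toy ChatML-style template (the template text is illustrative, not taken from any model):

toy_template = (
    "{% for m in messages %}<|im_start|>{{ m['role'] }}\n{{ m['content'] }}<|im_end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
tmpl = _jinja_env.from_string(toy_template)
print(tmpl.render(messages=[Message(role="user", content="Hi")], add_generation_prompt=True))
# prints the ChatML-wrapped conversation followed by the assistant generation header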

@@ -590,46 +610,73 @@ class GPT4All:
             def stream() -> Iterator[str]:
                 yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
                 if self._chat_session is not None:
-                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+                    self._chat_session.history.append(Message(role="assistant", content=full_response))
             return stream()
 
         self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
         if self._chat_session is not None:
-            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            self._chat_session.history.append(Message(role="assistant", content=full_response))
         return full_response
 
     @contextmanager
     def chat_session(
         self,
+        *,
         system_message: str | Literal[False] | None = None,
         chat_template: str | None = None,
+        warn_legacy: bool = True,
     ):
         """
         Context manager to hold an inference optimized chat session with a GPT4All model.
 
         Args:
-            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable.
+                Defaults to None.
             chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
-        """
+            warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
+
+        Raises:
+            ValueError: If no valid chat template was found.
+        """
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
+        elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
+            print(
+                "Warning: chat_session() was passed a system message that is not plain text. System messages "
+                f"containing {m.group()!r} or with any special prefix/suffix are no longer supported.\nTo disable this "
+                "warning, pass warn_legacy=False.",
+                file=sys.stderr,
+            )
 
         if chat_template is None:
-            if "name" not in self.config:
-                raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
-            if "chatTemplate" not in self.config:
-                raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
-                                          "currently implemented. Please pass a template to chat_session() directly.")
-            if (tmpl := self.config["chatTemplate"]) is None:
-                raise ValueError(f"The model {self.config['name']!r} does not support chat.")
-            chat_template = tmpl
+            if "chatTemplate" in self.config:
+                if (tmpl := self.config["chatTemplate"]) is None:
+                    raise ValueError(f"The model {self.config['name']!r} does not support chat.")
+                chat_template = tmpl
+            else:
+                try:
+                    chat_template = self.model.builtin_chat_template
+                except ValueError as e:
+                    if len(e.args) >= 2 and isinstance(err := e.args[1], str):
+                        msg = (f"Failed to load default chat template from model: {err}\n"
+                               "Please pass a template to chat_session() directly.")
+                        if not self._was_allow_download:
+                            msg += " If this is a built-in model, consider setting allow_download to True."
+                        raise ValueError(msg) from None
+                    raise
+        elif warn_legacy and self._is_legacy_chat_template(chat_template):
+            print(
+                "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
+                "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
+                file=sys.stderr,
+            )
 
         history = []
         if system_message is not False:
-            history.append(MessageType(role="system", content=system_message))
-        self._chat_session = ChatSession(
+            history.append(Message(role="system", content=system_message))
+        self._chat_session = _ChatSession(
             template=_jinja_env.from_string(chat_template),
+            template_source=chat_template,
             history=history,
         )
         try:
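
A hedged usage sketch of the reworked chat_session() signature: the parameters are now keyword-only, and warn_legacy=False silences the two compatibility warnings. The model name and template are illustrative.

custom_template = (
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)
model = GPT4All("some-model.gguf")  # hypothetical model name
with model.chat_session(system_message="Answer in one sentence.",
                        chat_template=custom_template, warn_legacy=False):
    print(model.generate("What is GPT4All?", max_tokens=64))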

@@ -647,6 +694,12 @@ class GPT4All:
         """
         return LLModel.list_gpus()
 
+    @classmethod
+    def _is_legacy_chat_template(cls, tmpl: str) -> bool:
+        # check if tmpl does not look like a Jinja template
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+
 
 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
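
A hedged illustration of the new heuristic: old-style %1/%2 prompt templates, or anything that does not look like a Jinja template referencing message content, are flagged as legacy.

GPT4All._is_legacy_chat_template("### Human:\n%1\n### Assistant:\n%2")        # True: %1/%2 placeholders
GPT4All._is_legacy_chat_template(
    "{% for m in messages %}{{ m['content'] }}{% endfor %}")                  # False: Jinja-like, mentions content
GPT4All._is_legacy_chat_template("plain text with no template syntax")        # True: not Jinja-like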
|
@ -68,7 +68,7 @@ def get_long_description():
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=package_name,
|
name=package_name,
|
||||||
version="2.8.3.dev0",
|
version="3.0.0.dev0",
|
||||||
description="Python bindings for GPT4All",
|
description="Python bindings for GPT4All",
|
||||||
long_description=get_long_description(),
|
long_description=get_long_description(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||