commit 260ad4b163 (parent b3cc860233)

resolve some flake8 complaints about new code

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
@@ -9,7 +9,7 @@ import textwrap
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -37,9 +37,9 @@ DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
 
 ConfigType: TypeAlias = "dict[str, Any]"
 
-# Environment setup adapted from HF transformers
 @_operator_call
 def _jinja_env() -> ImmutableSandboxedEnvironment:
+    # Environment setup adapted from HF transformers
     def raise_exception(message: str) -> NoReturn:
         raise jinja2.exceptions.TemplateError(message)
 
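
For readers unfamiliar with the pattern: `_jinja_env` builds a sandboxed Jinja environment and exposes `raise_exception` so chat templates can abort rendering, a setup the comment credits to HF transformers. A minimal standalone sketch of that pattern (the environment options shown are assumptions, not taken from this file):

```python
# Sketch of the sandboxed-environment pattern used above; assumes jinja2 is
# installed. Not the exact gpt4all implementation.
from typing import NoReturn

import jinja2
from jinja2.sandbox import ImmutableSandboxedEnvironment


def make_env() -> ImmutableSandboxedEnvironment:
    def raise_exception(message: str) -> NoReturn:
        raise jinja2.exceptions.TemplateError(message)

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    env.globals["raise_exception"] = raise_exception
    return env


# Templates can now abort rendering with a clear error:
env = make_env()
tmpl = env.from_string("{{ raise_exception('unsupported role') if bad else 'ok' }}")
print(tmpl.render(bad=False))  # -> ok
```
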
@@ -56,15 +56,17 @@ def _jinja_env() -> ImmutableSandboxedEnvironment:
     return env
 
 
-class MessageType(TypedDict):
+class Message(TypedDict):
+    """A message in a chat with a GPT4All model."""
+
     role: str
     content: str
 
 
-class ChatSession(NamedTuple):
+class _ChatSession(NamedTuple):
     template: jinja2.Template
     template_source: str
-    history: list[MessageType]
+    history: list[Message]
 
 
 class Embed4All:
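
The rename also clarifies the semantics: `Message` is a TypedDict (a plain dict at runtime), while `_ChatSession` is a NamedTuple whose `history` list is shared and mutated in place. A sketch of those semantics (the `jinja2.Template` field is omitted here so the snippet stays self-contained):

```python
from typing import NamedTuple, TypedDict


class Message(TypedDict):
    """A message in a chat with a GPT4All model."""
    role: str
    content: str


class _ChatSession(NamedTuple):  # template field omitted for brevity
    template_source: str
    history: list[Message]


msg = Message(role="user", content="Hello!")
assert msg == {"role": "user", "content": "Hello!"}  # TypedDict is a dict at runtime

session = _ChatSession(template_source="...", history=[msg])
session.history.append(Message(role="assistant", content="Hi."))
assert len(session.history) == 2  # the tuple is immutable, its list is not
```
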
@@ -195,7 +197,8 @@ class GPT4All:
     """
 
     RE_LEGACY_SYSPROMPT = re.compile(
-        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
+        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
         re.MULTILINE,
     )
 
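
The split is purely cosmetic (flake8 line-length): adjacent string literals concatenate at compile time, so the pattern is unchanged. A quick check of what the heuristic flags as a legacy system prompt:

```python
import re

# Same pattern as above, reassembled to show the concatenation is lossless.
RE_LEGACY_SYSPROMPT = re.compile(
    r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
    r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
    re.MULTILINE,
)

for text in ("### System\nBe helpful.", "<|im_start|>system", "<<SYS>>", "plain instruction"):
    print(bool(RE_LEGACY_SYSPROMPT.search(text)), repr(text))
# -> True, True, True, False
```
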
@@ -244,7 +247,7 @@ class GPT4All:
         """
 
         self.model_type = model_type
-        self._chat_session: ChatSession | None = None
+        self._chat_session: _ChatSession | None = None
 
         device_init = None
         if sys.platform == "darwin":
@@ -303,11 +306,12 @@ class GPT4All:
         return self.model.device
 
     @property
-    def current_chat_session(self) -> list[MessageType] | None:
+    def current_chat_session(self) -> list[Message] | None:
+        """The message history of the current chat session."""
         return None if self._chat_session is None else self._chat_session.history
 
     @current_chat_session.setter
-    def current_chat_session(self, history: list[MessageType]) -> None:
+    def current_chat_session(self, history: list[Message]) -> None:
         if self._chat_session is None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
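
The new docstring pins down the property's contract: reading returns the live history list (or None outside a session), and assignment replaces the history in place via the setter above. A hedged usage sketch (the model file name is a placeholder):

```python
from gpt4all import GPT4All

model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")  # placeholder model name

print(model.current_chat_session)  # -> None outside a session

with model.chat_session():
    model.generate("Name a color.", max_tokens=16)
    print(model.current_chat_session)  # live list of Message dicts
    # The setter replaces the history in place, e.g. to drop older turns:
    model.current_chat_session = model.current_chat_session[-2:]
```
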
@@ -585,13 +589,13 @@ class GPT4All:
         last_msg_rendered = prompt
         if self._chat_session is not None:
             session = self._chat_session
-            def render(messages: list[MessageType]) -> str:
+            def render(messages: list[Message]) -> str:
                 return session.template.render(
                     messages=messages,
                     add_generation_prompt=True,
                     **self.model.special_tokens_map,
                 )
-            session.history.append(MessageType(role="user", content=prompt))
+            session.history.append(Message(role="user", content=prompt))
             prompt = render(session.history)
             if len(session.history) > 1:
                 last_msg_rendered = render(session.history[-1:])
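
The `render` closure just feeds the running history into the session's chat template with `add_generation_prompt=True`, so the rendered prompt ends at the assistant's turn. A standalone illustration with a toy ChatML-style template (the template text is illustrative, not from this file):

```python
from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(
    "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

history = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "Hi!"},
]
# Mirrors render(session.history): full history plus a trailing assistant header.
print(template.render(messages=history, add_generation_prompt=True))
```
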
@@ -606,20 +610,14 @@ class GPT4All:
             def stream() -> Iterator[str]:
                 yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
                 if self._chat_session is not None:
-                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+                    self._chat_session.history.append(Message(role="assistant", content=full_response))
             return stream()
 
         self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
         if self._chat_session is not None:
-            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            self._chat_session.history.append(Message(role="assistant", content=full_response))
         return full_response
 
-    @classmethod
-    def is_legacy_chat_template(cls, tmpl: str) -> bool:
-        """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
-        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
-                    or not re.search(r"\bcontent\b", tmpl))
-
     @contextmanager
     def chat_session(
         self,
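
Both the streaming and blocking paths now append a `Message` once the reply is complete; with `streaming=True` the append happens when the returned iterator is exhausted. A sketch (placeholder model name):

```python
from gpt4all import GPT4All

model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")  # placeholder model name

with model.chat_session():
    # streaming=True returns the Iterator[str] built by stream() above
    for token in model.generate("Count to three.", max_tokens=32, streaming=True):
        print(token, end="", flush=True)
    print()
    # Once the iterator is exhausted, the reply is in the session history:
    print(model.current_chat_session[-1]["role"])  # -> assistant
```
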
@@ -632,10 +630,14 @@ class GPT4All:
         Context manager to hold an inference optimized chat session with a GPT4All model.
 
         Args:
-            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable.
+                Defaults to None.
             chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
             warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
+
+        Raises:
+            ValueError: If no valid chat template was found.
         """
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
         elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
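
A sketch exercising the documented parameters, passing an explicit Jinja chat template and disabling the system turn (model name and template text are illustrative):

```python
from gpt4all import GPT4All

model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")  # placeholder model name

# A ChatML-style template; any Jinja template over `messages` would do.
chatml = (
    "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

with model.chat_session(system_message=False, chat_template=chatml):
    print(model.generate("Say hi.", max_tokens=16))
```
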
@@ -662,7 +664,7 @@ class GPT4All:
                     msg += " If this is a built-in model, consider setting allow_download to True."
                     raise ValueError(msg) from None
                 raise
-        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+        elif warn_legacy and self._is_legacy_chat_template(chat_template):
             print(
                 "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
                 "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
@@ -671,8 +673,8 @@ class GPT4All:
 
         history = []
         if system_message is not False:
-            history.append(MessageType(role="system", content=system_message))
-        self._chat_session = ChatSession(
+            history.append(Message(role="system", content=system_message))
+        self._chat_session = _ChatSession(
             template=_jinja_env.from_string(chat_template),
             template_source=chat_template,
             history=history,
@@ -692,6 +694,12 @@ class GPT4All:
         """
         return LLModel.list_gpus()
 
+    @classmethod
+    def _is_legacy_chat_template(cls, tmpl: str) -> bool:
+        # check if tmpl does not look like a Jinja template
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+
 
 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
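
For reference, the relocated heuristic treats a template as legacy if it uses old-style %1/%2 placeholders, lacks Jinja-like syntax, or never references `content`. A standalone sketch; RE_JINJA_LIKE is defined elsewhere in this file, so the pattern below is an assumed stand-in:

```python
import re

# Assumed stand-in for gpt4all's RE_JINJA_LIKE (matches {{ ... }}, {% ... %}, {# ... #}).
RE_JINJA_LIKE = re.compile(r"\{\{|\{%|\{#")


def is_legacy_chat_template(tmpl: str) -> bool:
    return bool(re.search(r"%[12]\b", tmpl) or not RE_JINJA_LIKE.search(tmpl)
                or not re.search(r"\bcontent\b", tmpl))


print(is_legacy_chat_template("### Human:\n%1\n### Assistant:\n"))  # True (old-style %1)
print(is_legacy_chat_template("{% for m in messages %}{{ m.content }}{% endfor %}"))  # False
```
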