python: load templates from model files, and add legacy template warning

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-12-05 14:35:48 -05:00
parent 2db59f0092
commit d6638b5064
4 changed files with 95 additions and 16 deletions

View File

@@ -312,6 +312,8 @@ int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, con
void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
const char *llmodel_model_chat_template(const char *model_path, const char **error);
#ifdef __cplusplus
}
#endif

View File

@@ -34,11 +34,11 @@ llmodel_model llmodel_model_create(const char *model_path)
return fres;
}
static void llmodel_set_error(const char **errptr, const char *message)
static void llmodel_set_error(const char **errptr, std::string message)
{
thread_local static std::string last_error_message;
if (errptr) {
last_error_message = message;
last_error_message = std::move(message);
*errptr = last_error_message.c_str();
}
}
@@ -318,3 +318,15 @@ void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_to
for (auto &[name, token] : wrapper->llModel->specialTokens())
callback(name.c_str(), token.c_str());
}
const char *llmodel_model_chat_template(const char *model_path, const char **error)
{
static std::string s_chatTemplate;
auto res = LLModel::Implementation::chatTemplate(model_path);
if (res) {
s_chatTemplate = *res;
return s_chatTemplate.c_str();
}
llmodel_set_error(error, std::move(res.error()));
return nullptr;
}

View File

@@ -227,6 +227,9 @@ llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
llmodel.llmodel_model_foreach_special_token.restype = None
llmodel.llmodel_model_chat_template.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
llmodel.llmodel_model_chat_template.restype = ctypes.c_char_p
ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool]
EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@@ -290,10 +293,7 @@ class LLModel:
raise RuntimeError(f"Unable to instantiate model: {errmsg}")
self.model: ctypes.c_void_p | None = model
self.special_tokens_map: dict[str, str] = {}
llmodel.llmodel_model_foreach_special_token(
self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
)
self._special_tokens_map: dict[str, str] | None = None
def __del__(self, llmodel=llmodel):
if hasattr(self, 'model'):
@@ -320,6 +320,26 @@ class LLModel:
dev = llmodel.llmodel_model_gpu_device_name(self.model)
return None if dev is None else dev.decode()
@property
def builtin_chat_template(self) -> str:
err = ctypes.c_char_p()
tmpl = llmodel.llmodel_model_chat_template(self.model_path, ctypes.byref(err))
if tmpl is not None:
return tmpl.decode()
s = err.value
raise ValueError('Failed to get chat template', 'null' if s is None else s.decode())
@property
def special_tokens_map(self) -> dict[str, str]:
if self.model is None:
self._raise_closed()
if self._special_tokens_map is None:
tokens: dict[str, str] = {}
cb = SpecialTokenCallback(lambda n, t: tokens.__setitem__(n.decode(), t.decode()))
llmodel.llmodel_model_foreach_special_token(self.model, cb)
self._special_tokens_map = tokens
return self._special_tokens_map
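
A minimal usage sketch (not part of this commit) of how the two new accessors above are meant to be read; the GGUF path and the n_ctx/ngl/backend values are placeholders:

from gpt4all._pyllmodel import LLModel

# Hypothetical model path and load parameters, for illustration only.
model = LLModel("/path/to/model.gguf", 2048, 100, "cpu")
print(model.builtin_chat_template)   # template embedded in the model file; raises ValueError if none is usable
print(model.special_tokens_map)      # built once via the special-token callback, then cached on the instance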
def count_prompt_tokens(self, prompt: str) -> int:
if self.model is None:
self._raise_closed()

View File

@@ -62,8 +62,9 @@ class MessageType(TypedDict):
class ChatSession(NamedTuple):
template: jinja2.Template
history: list[MessageType]
template: jinja2.Template
template_source: str
history: list[MessageType]
class Embed4All:
@@ -193,6 +194,16 @@ class GPT4All:
Python class that handles instantiation, downloading, generation and chat with GPT4All models.
"""
RE_LEGACY_SYSPROMPT = re.compile(
r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
re.MULTILINE,
)
RE_JINJA_LIKE = re.compile(
r"\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}",
re.DOTALL,
)
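
As an illustration (not part of the diff): the first pattern flags system messages that still carry model-specific legacy wrappers, while the second only matches text containing a Jinja block, an expression, and another block in that order. The sample strings below are made up:

from gpt4all import GPT4All

GPT4All.RE_LEGACY_SYSPROMPT.search("### System\nYou are a helpful assistant.")   # match: legacy wrapper
GPT4All.RE_LEGACY_SYSPROMPT.search("You are a helpful assistant.")               # None: plain text is fine
GPT4All.RE_JINJA_LIKE.search("{% for m in messages %}{{ m['content'] }}{% endfor %}")  # match: looks like Jinja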
def __init__(
self,
model_name: str,
@@ -260,6 +271,7 @@ class GPT4All:
# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
self._was_allow_download = allow_download
self.model = LLModel(self.config["path"], n_ctx, ngl, backend)
if device_init is not None:
self.model.init_gpu(device_init)
@@ -300,6 +312,10 @@ class GPT4All:
raise ValueError("current_chat_session may only be set when there is an active chat session")
self._chat_session.history[:] = history
@property
def current_chat_template(self) -> str | None:
return None if self._chat_session is None else self._chat_session.template_source
@staticmethod
def list_models() -> list[ConfigType]:
"""
@@ -598,11 +614,19 @@ class GPT4All:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
return full_response
@classmethod
def is_legacy_chat_template(cls, tmpl: str) -> bool:
"""A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
or not re.search(r"\bcontent\b", tmpl))
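
For example (illustrative, not part of the diff), the heuristic treats old %1/%2-style prompt templates as legacy and lets a typical Jinja chat template through:

from gpt4all import GPT4All

GPT4All.is_legacy_chat_template("### Human:\n%1\n### Assistant:\n")   # True: positional placeholders
GPT4All.is_legacy_chat_template(
    "{% for message in messages %}{{ message['content'] }}{% endfor %}"
)                                                                      # False: Jinja-like and renders content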
@contextmanager
def chat_session(
self,
*,
system_message: str | Literal[False] | None = None,
chat_template: str | None = None,
warn_legacy: bool = True,
):
"""
Context manager to hold an inference optimized chat session with a GPT4All model.
@@ -614,22 +638,43 @@ class GPT4All:
if system_message is None:
system_message = self.config.get("systemMessage", False)
elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
print(
"Warning: chat_session() was passed a system message that is not plain text. System messages "
f"containing {m.group()!r} or with any special prefix/suffix are no longer supported.\nTo disable this "
"warning, pass warn_legacy=False.",
file=sys.stderr,
)
if chat_template is None:
if "name" not in self.config:
raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
if "chatTemplate" not in self.config:
raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
"currently implemented. Please pass a template to chat_session() directly.")
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
if "chatTemplate" in self.config:
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
else:
try:
chat_template = self.model.builtin_chat_template
except ValueError as e:
if len(e.args) >= 2 and isinstance(err := e.args[1], str):
msg = (f"Failed to load default chat template from model: {err}\n"
"Please pass a template to chat_session() directly.")
if not self._was_allow_download:
msg += " If this is a built-in model, consider setting allow_download to True."
raise ValueError(msg) from None
raise
elif warn_legacy and self.is_legacy_chat_template(chat_template):
print(
"Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
"templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
file=sys.stderr,
)
history = []
if system_message is not False:
history.append(MessageType(role="system", content=system_message))
self._chat_session = ChatSession(
template=_jinja_env.from_string(chat_template),
template_source=chat_template,
history=history,
)
try:
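
Taken together, a hedged usage sketch of the new session behaviour (the model name and templates are illustrative; the built-in GGUF template is only consulted when the models3.json config has no chatTemplate entry):

from gpt4all import GPT4All

model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")   # hypothetical model name

# No chat_template argument: the session falls back to the config's template,
# then to the one embedded in the model file, and raises with a hint otherwise.
with model.chat_session(system_message="You are a terse assistant."):
    print(model.current_chat_template)               # Jinja source actually in use
    print(model.generate("Hello!", max_tokens=64))

# Passing an old-style template prints the legacy warning unless warn_legacy=False.
with model.chat_session(chat_template="### Human:\n%1\n### Assistant:\n", warn_legacy=False):
    ...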