diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
index 271475ba..3f7b0851 100644
--- a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
@@ -312,6 +312,8 @@ int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, con
 
 void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
 
+const char *llmodel_model_chat_template(const char *model_path, const char **error);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/gpt4all-backend/src/llmodel_c.cpp b/gpt4all-backend/src/llmodel_c.cpp
index a8c5554d..cb44da71 100644
--- a/gpt4all-backend/src/llmodel_c.cpp
+++ b/gpt4all-backend/src/llmodel_c.cpp
@@ -34,11 +34,11 @@ llmodel_model llmodel_model_create(const char *model_path)
     return fres;
 }
 
-static void llmodel_set_error(const char **errptr, const char *message)
+static void llmodel_set_error(const char **errptr, std::string message)
 {
     thread_local static std::string last_error_message;
     if (errptr) {
-        last_error_message = message;
+        last_error_message = std::move(message);
         *errptr = last_error_message.c_str();
     }
 }
@@ -318,3 +318,15 @@ void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_to
     for (auto &[name, token] : wrapper->llModel->specialTokens())
         callback(name.c_str(), token.c_str());
 }
+
+const char *llmodel_model_chat_template(const char *model_path, const char **error)
+{
+    static std::string s_chatTemplate;
+    auto res = LLModel::Implementation::chatTemplate(model_path);
+    if (res) {
+        s_chatTemplate = *res;
+        return s_chatTemplate.c_str();
+    }
+    llmodel_set_error(error, std::move(res.error()));
+    return nullptr;
+}
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 616ce80a..c59f1dc5 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -227,6 +227,9 @@ llmodel.llmodel_count_prompt_tokens.restype = ctypes.c_int32
 llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
 llmodel.llmodel_model_foreach_special_token.restype = None
 
+llmodel.llmodel_model_chat_template.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
+llmodel.llmodel_model_chat_template.restype = ctypes.c_char_p
+
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
 EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@@ -290,10 +293,7 @@ class LLModel:
             raise RuntimeError(f"Unable to instantiate model: {errmsg}")
 
         self.model: ctypes.c_void_p | None = model
-        self.special_tokens_map: dict[str, str] = {}
-        llmodel.llmodel_model_foreach_special_token(
-            self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
-        )
+        self._special_tokens_map: dict[str, str] | None = None
 
     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
@@ -320,6 +320,26 @@ class LLModel:
         dev = llmodel.llmodel_model_gpu_device_name(self.model)
         return None if dev is None else dev.decode()
 
+    @property
+    def builtin_chat_template(self) -> str:
+        err = ctypes.c_char_p()
+        tmpl = llmodel.llmodel_model_chat_template(self.model_path, ctypes.byref(err))
+        if tmpl is not None:
+            return tmpl.decode()
+        s = err.value
+        raise ValueError('Failed to get chat template', 'null' if s is None else s.decode())
+
+    @property
+    def special_tokens_map(self) -> dict[str, str]:
+        if self.model is None:
+            self._raise_closed()
+        if self._special_tokens_map is None:
+            tokens: dict[str, str] = {}
+            cb = SpecialTokenCallback(lambda n, t: tokens.__setitem__(n.decode(), t.decode()))
+            llmodel.llmodel_model_foreach_special_token(self.model, cb)
+            self._special_tokens_map = tokens
+        return self._special_tokens_map
+
     def count_prompt_tokens(self, prompt: str) -> int:
         if self.model is None:
             self._raise_closed()
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 84b236c9..390eb410 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -62,8 +62,9 @@ class MessageType(TypedDict):
 
 
 class ChatSession(NamedTuple):
-    template: jinja2.Template
-    history: list[MessageType]
+    template: jinja2.Template
+    template_source: str
+    history: list[MessageType]
 
 
 class Embed4All:
@@ -193,6 +194,16 @@ class GPT4All:
     Python class that handles instantiation, downloading, generation and chat with GPT4All models.
     """
 
+    RE_LEGACY_SYSPROMPT = re.compile(
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        re.MULTILINE,
+    )
+
+    RE_JINJA_LIKE = re.compile(
+        r"\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}",
+        re.DOTALL,
+    )
+
     def __init__(
         self,
         model_name: str,
@@ -260,6 +271,7 @@ class GPT4All:
         # Retrieve model and download if allowed
         self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
+        self._was_allow_download = allow_download
 
         self.model = LLModel(self.config["path"], n_ctx, ngl, backend)
         if device_init is not None:
             self.model.init_gpu(device_init)
@@ -300,6 +312,10 @@ class GPT4All:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
 
+    @property
+    def current_chat_template(self) -> str | None:
+        return None if self._chat_session is None else self._chat_session.template_source
+
     @staticmethod
     def list_models() -> list[ConfigType]:
         """
@@ -598,11 +614,19 @@ class GPT4All:
                 self._chat_session.history.append(MessageType(role="assistant", content=full_response))
             return full_response
 
+    @classmethod
+    def is_legacy_chat_template(cls, tmpl: str) -> bool:
+        """A fairly reliable heuristic for detecting chat templates that are not in Jinja format."""
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+
     @contextmanager
     def chat_session(
         self,
+        *,
         system_message: str | Literal[False] | None = None,
         chat_template: str | None = None,
+        warn_legacy: bool = True,
     ):
         """
         Context manager to hold an inference optimized chat session with a GPT4All model.
@@ -614,22 +638,43 @@ class GPT4All:
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
+        elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
+            print(
+                "Warning: chat_session() was passed a system message that is not plain text. System messages "
+                f"containing {m.group()!r} or with any special prefix/suffix are no longer supported.\nTo disable this "
+                "warning, pass warn_legacy=False.",
+                file=sys.stderr,
+            )
 
         if chat_template is None:
-            if "name" not in self.config:
-                raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
-            if "chatTemplate" not in self.config:
-                raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
-                                          "currently implemented. Please pass a template to chat_session() directly.")
-            if (tmpl := self.config["chatTemplate"]) is None:
-                raise ValueError(f"The model {self.config['name']!r} does not support chat.")
-            chat_template = tmpl
+            if "chatTemplate" in self.config:
+                if (tmpl := self.config["chatTemplate"]) is None:
+                    raise ValueError(f"The model {self.config['name']!r} does not support chat.")
+                chat_template = tmpl
+            else:
+                try:
+                    chat_template = self.model.builtin_chat_template
+                except ValueError as e:
+                    if len(e.args) >= 2 and isinstance(err := e.args[1], str):
+                        msg = (f"Failed to load default chat template from model: {err}\n"
+                               "Please pass a template to chat_session() directly.")
+                        if not self._was_allow_download:
+                            msg += " If this is a built-in model, consider setting allow_download to True."
+                        raise ValueError(msg) from None
+                    raise
+        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+            print(
+                "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
+                "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
+                file=sys.stderr,
+            )
 
         history = []
         if system_message is not False:
             history.append(MessageType(role="system", content=system_message))
         self._chat_session = ChatSession(
             template=_jinja_env.from_string(chat_template),
+            template_source=chat_template,
             history=history,
         )
         try:
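
Below is a quick usage sketch of how the new Python-level behavior fits together. It is not part of the patch; the model file name is only a placeholder, and any locally available GGUF chat model would work the same way.

from gpt4all import GPT4All

# Placeholder model name, used here only for illustration.
model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")

# The heuristic added in this diff can be queried directly; old-style
# %1/%2 prompt templates are flagged as legacy.
assert GPT4All.is_legacy_chat_template("### Human:\n%1\n### Assistant:\n%2")

# system_message and chat_template are now keyword-only. With no
# chat_template argument, chat_session() first checks the model config
# ("chatTemplate") and otherwise falls back to the template embedded in
# the GGUF file, surfaced through the new llmodel_model_chat_template().
with model.chat_session(system_message="You are a concise assistant."):
    print(model.current_chat_template)             # Jinja source actually in use
    print(model.generate("Hello!", max_tokens=64))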