diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index b54a8696..06c9223e 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -115,7 +115,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_reloadingToChangeVariant(false)
     , m_processedSystemPrompt(false)
     , m_restoreStateFromText(false)
+    , m_checkToolCall(false)
     , m_maybeToolCall(false)
+    , m_foundToolCall(false)
 {
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
@@ -705,34 +707,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
         return false;
     }
 
-    // Only valid for llama 3.1 instruct
-    if (m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct")) {
-        // Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling
-        // For brave_search and wolfram_alpha ipython is always used
-
-        // <|python_tag|>
-        // brave_search.call(query="...")
-        // <|eom_id|>
-        const int eom_id = 128008;
-        const int python_tag = 128010;
-
-        // If we have a built-in tool call, then it should be the first token
-        const bool isFirstResponseToken = m_promptResponseTokens == m_promptTokens;
-        Q_ASSERT(token != python_tag || isFirstResponseToken);
-        if (isFirstResponseToken && token == python_tag) {
-            m_maybeToolCall = true;
-            ++m_promptResponseTokens;
-            return !m_stopGenerating;
-        }
-
-        // Check for end of built-in tool call
-        Q_ASSERT(token != eom_id || !m_maybeToolCall);
-        if (token == eom_id) {
-            ++m_promptResponseTokens;
-            return false;
-        }
-    }
-
     // m_promptResponseTokens is related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
     ++m_promptResponseTokens;
@@ -740,7 +714,25 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     Q_ASSERT(!response.empty());
     m_response.append(response);
 
-    if (!m_maybeToolCall)
+    // If we're checking for a tool call and the response has reached at least 11 characters
+    // (the length of "<tool_call>"), check whether it begins with the tool call start tag
+    if (m_checkToolCall && m_response.size() >= 11) {
+        if (m_response.starts_with("<tool_call>")) {
+            m_maybeToolCall = true;
+            m_response.erase(0, 11);
+        }
+        m_checkToolCall = false;
+    }
+
+    // Check if we're at the end of a tool call and erase the end tag
+    if (m_maybeToolCall && m_response.ends_with("</tool_call>")) {
+        m_foundToolCall = true;
+        m_response.erase(m_response.length() - 12);
+        return false;
+    }
+
+    // If we're not checking for a tool call and haven't detected one, send along the response
+    if (!m_checkToolCall && !m_maybeToolCall)
         emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
 
     return !m_stopGenerating;
@@ -822,8 +814,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
             /*allowContextShift*/ true, m_ctx);
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }
+
+    m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now
     m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
         /*allowContextShift*/ true, m_ctx);
+    m_checkToolCall = false;
+    m_maybeToolCall = false;
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -831,8 +827,8 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     m_timer->stop();
     qint64 elapsed = totalTime.elapsed();
     std::string trimmed = trim_whitespace(m_response);
-    if (m_maybeToolCall) {
-        m_maybeToolCall = false;
+    if (m_foundToolCall) {
+        m_foundToolCall = false;
         m_ctx.n_past = std::max(0, m_ctx.n_past);
         m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
         m_promptResponseTokens = 0;
@@ -860,25 +856,46 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
 bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p,
     float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
 {
-    Q_ASSERT(m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct"));
-    emit toolCalled(tr("searching web..."));
+    QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
+    if (toolTemplate.isEmpty()) {
+        // FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for
+        // us to process it. We should probably not even attempt further generation and just show an
+        // error in the chat somehow?
+        qWarning() << "WARNING: The model attempted a tool call, but there is no valid tool template for this model" << toolCall;
+        return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/,
+            MySettings::globalInstance()->modelPromptTemplate(m_modelInfo),
+            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
+    }
 
-    // Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling
-    // For brave_search and wolfram_alpha ipython is always used
+    QJsonParseError err;
+    QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
 
-    static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))");
-    QRegularExpressionMatch match = re.match(toolCall);
+    if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
+        qWarning() << "WARNING: The tool call had null or invalid json " << toolCall;
+        return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
+            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
+    }
 
-    QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
-    QString query;
-    if (match.hasMatch()) {
-        query = match.captured(1);
-    } else {
-        qWarning() << "WARNING: Could not find the tool for " << toolCall;
-        return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, promptTemplate,
+    QJsonObject rootObject = toolCallDoc.object();
+    if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
+        qWarning() << "WARNING: The tool call did not have the required name and arguments objects " << toolCall;
+        return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
             n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
     }
 
+    const QString tool = toolCallDoc["name"].toString();
+    const QJsonObject args = toolCallDoc["arguments"].toObject();
+
+    if (tool != "brave_search" || !args.contains("query")) {
+        qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall;
+        return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
+            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
+    }
+
+    const QString query = args["query"].toString();
+
+    emit toolCalled(tr("searching web..."));
+
     const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
     Q_ASSERT(apiKey != "");
@@ -887,7 +904,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
 
     emit sourceExcerptsChanged(braveResponse.second);
 
-    return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, promptTemplate,
+    return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
         n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
 }
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index b4874b70..d9d47ae9 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -242,7 +242,9 @@ private:
     bool m_reloadingToChangeVariant;
     bool m_processedSystemPrompt;
    bool m_restoreStateFromText;
+    bool m_checkToolCall;
     bool m_maybeToolCall;
+    bool m_foundToolCall;
     // m_pristineLoadedState is set if saveSate is unnecessary, either because:
     // - an unload was queued during LLModel::restoreState()
     // - the chat will be restored from text and hasn't been interacted with yet
diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp
index 580b615f..6ed91150 100644
--- a/gpt4all-chat/modellist.cpp
+++ b/gpt4all-chat/modellist.cpp
@@ -323,6 +323,17 @@ void ModelInfo::setPromptTemplate(const QString &t)
     m_promptTemplate = t;
 }
 
+QString ModelInfo::toolTemplate() const
+{
+    return MySettings::globalInstance()->modelToolTemplate(*this);
+}
+
+void ModelInfo::setToolTemplate(const QString &t)
+{
+    if (shouldSaveMetadata()) MySettings::globalInstance()->setModelToolTemplate(*this, t, true /*force*/);
+    m_toolTemplate = t;
+}
+
 QString ModelInfo::systemPrompt() const
 {
     return MySettings::globalInstance()->modelSystemPrompt(*this);
@@ -385,6 +396,7 @@ QVariantMap ModelInfo::getFields() const
         { "repeatPenalty", m_repeatPenalty },
         { "repeatPenaltyTokens", m_repeatPenaltyTokens },
         { "promptTemplate", m_promptTemplate },
+        { "toolTemplate", m_toolTemplate },
         { "systemPrompt", m_systemPrompt },
         { "chatNamePrompt", m_chatNamePrompt },
         { "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt },
@@ -504,6 +516,7 @@ ModelList::ModelList()
     connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);;
     connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
+    connect(MySettings::globalInstance(), &MySettings::toolTemplateChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings);
 
     connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors);
@@ -776,6 +789,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
         return info->repeatPenaltyTokens();
     case PromptTemplateRole:
         return info->promptTemplate();
+    case ToolTemplateRole:
+        return info->toolTemplate();
     case SystemPromptRole:
         return info->systemPrompt();
     case ChatNamePromptRole:
@@ -952,6 +967,8 @@ void ModelList::updateData(const QString &id, const QVector
             info->setRepeatPenaltyTokens(value.toInt()); break;
         case PromptTemplateRole:
             info->setPromptTemplate(value.toString()); break;
+        case ToolTemplateRole:
+            info->setToolTemplate(value.toString()); break;
         case SystemPromptRole:
             info->setSystemPrompt(value.toString()); break;
         case ChatNamePromptRole:
@@ -1107,6 +1124,7 @@ QString ModelList::clone(const ModelInfo &model)
         { ModelList::RepeatPenaltyRole, model.repeatPenalty() },
         { ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() },
         { ModelList::PromptTemplateRole, model.promptTemplate() },
+        { ModelList::ToolTemplateRole, model.toolTemplate() },
         { ModelList::SystemPromptRole, model.systemPrompt() },
         { ModelList::ChatNamePromptRole, model.chatNamePrompt() },
         { ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() },
@@ -1551,6 +1569,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
             data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() });
         if (obj.contains("promptTemplate"))
             data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() });
+        if (obj.contains("toolTemplate"))
+            data.append({ ModelList::ToolTemplateRole, obj["toolTemplate"].toString() });
         if (obj.contains("systemPrompt"))
             data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() });
         updateData(id, data);
@@ -1852,6 +1872,10 @@ void ModelList::updateModelsFromSettings()
             const QString promptTemplate = settings.value(g + "/promptTemplate").toString();
             data.append({ ModelList::PromptTemplateRole, promptTemplate });
         }
+        if (settings.contains(g + "/toolTemplate")) {
+            const QString toolTemplate = settings.value(g + "/toolTemplate").toString();
+            data.append({ ModelList::ToolTemplateRole, toolTemplate });
+        }
         if (settings.contains(g + "/systemPrompt")) {
             const QString systemPrompt = settings.value(g + "/systemPrompt").toString();
             data.append({ ModelList::SystemPromptRole, systemPrompt });
diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h
index 7c13da8e..9e9f088f 100644
--- a/gpt4all-chat/modellist.h
+++ b/gpt4all-chat/modellist.h
@@ -68,6 +68,7 @@ struct ModelInfo {
     Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
     Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
     Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
+    Q_PROPERTY(QString toolTemplate READ toolTemplate WRITE setToolTemplate)
     Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt)
     Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
     Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
@@ -178,6 +179,8 @@ public:
     void setRepeatPenaltyTokens(int t);
     QString promptTemplate() const;
     void setPromptTemplate(const QString &t);
+    QString toolTemplate() const;
+    void setToolTemplate(const QString &t);
     QString systemPrompt() const;
     void setSystemPrompt(const QString &p);
     QString chatNamePrompt() const;
@@ -215,6 +218,7 @@ private:
     double m_repeatPenalty = 1.18;
     int m_repeatPenaltyTokens = 64;
     QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n";
+    QString m_toolTemplate = "";
     QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n";
     QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
     QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
@@ -339,6 +343,7 @@ public:
         RepeatPenaltyRole,
         RepeatPenaltyTokensRole,
         PromptTemplateRole,
+        ToolTemplateRole,
         SystemPromptRole,
         ChatNamePromptRole,
         SuggestedFollowUpPromptRole,
@@ -393,6 +398,7 @@ public:
         roles[RepeatPenaltyRole] = "repeatPenalty";
         roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
         roles[PromptTemplateRole] = "promptTemplate";
+        roles[ToolTemplateRole] = "toolTemplate";
         roles[SystemPromptRole] = "systemPrompt";
         roles[ChatNamePromptRole] = "chatNamePrompt";
         roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt";
diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp
index e1530a85..0b94989e 100644
--- a/gpt4all-chat/mysettings.cpp
+++ b/gpt4all-chat/mysettings.cpp
@@ -194,6 +194,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &info)
     setModelRepeatPenalty(info, info.m_repeatPenalty);
     setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
     setModelPromptTemplate(info, info.m_promptTemplate);
+    setModelToolTemplate(info, info.m_toolTemplate);
     setModelSystemPrompt(info, info.m_systemPrompt);
     setModelChatNamePrompt(info, info.m_chatNamePrompt);
     setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
@@ -296,6 +297,7 @@ int MySettings::modelGpuLayers (const ModelInfo &info) const
 double  MySettings::modelRepeatPenalty          (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); }
 int     MySettings::modelRepeatPenaltyTokens    (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
 QString MySettings::modelPromptTemplate         (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); }
+QString MySettings::modelToolTemplate           (const ModelInfo &info) const { return getModelSetting("toolTemplate", info).toString(); }
 QString MySettings::modelSystemPrompt           (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
 QString MySettings::modelChatNamePrompt         (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
 QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
@@ -405,6 +407,11 @@ void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &va
     setModelSetting("promptTemplate", info, value, force, true);
 }
 
+void MySettings::setModelToolTemplate(const ModelInfo &info, const QString &value, bool force)
+{
+    setModelSetting("toolTemplate", info, value, force, true);
+}
+
 void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force)
 {
     setModelSetting("systemPrompt", info, value, force, true);
diff --git a/gpt4all-chat/mysettings.h b/gpt4all-chat/mysettings.h
index 59301cf2..205c21d0 100644
--- a/gpt4all-chat/mysettings.h
+++ b/gpt4all-chat/mysettings.h
@@ -126,6 +126,8 @@ public:
     Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false);
     QString modelPromptTemplate(const ModelInfo &info) const;
     Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false);
+    QString modelToolTemplate(const ModelInfo &info) const;
+    Q_INVOKABLE void setModelToolTemplate(const ModelInfo &info, const QString &value, bool force = false);
     QString modelSystemPrompt(const ModelInfo &info) const;
     Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false);
     int modelContextLength(const ModelInfo &info) const;
@@ -217,6 +219,7 @@ Q_SIGNALS:
     void repeatPenaltyChanged(const ModelInfo &info);
     void repeatPenaltyTokensChanged(const ModelInfo &info);
     void promptTemplateChanged(const ModelInfo &info);
+    void toolTemplateChanged(const ModelInfo &info);
     void systemPromptChanged(const ModelInfo &info);
     void chatNamePromptChanged(const ModelInfo &info);
     void suggestedFollowUpPromptChanged(const ModelInfo &info);
diff --git a/gpt4all-chat/qml/ModelSettings.qml b/gpt4all-chat/qml/ModelSettings.qml
index 5e896eb1..99489349 100644
--- a/gpt4all-chat/qml/ModelSettings.qml
+++ b/gpt4all-chat/qml/ModelSettings.qml
@@ -209,7 +209,7 @@ MySettingsTab {
                     id: promptTemplateLabelHelp
                     text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.")
                     color: theme.textErrorColor
-                    visible: templateTextArea.text.indexOf("%1") === -1
+                    visible: promptTemplateTextArea.text.indexOf("%1") === -1
                     wrapMode: TextArea.Wrap
                 }
             }
@@ -220,27 +220,27 @@ MySettingsTab {
            Layout.column: 0
            Layout.columnSpan: 2
            Layout.fillWidth: true
-            Layout.minimumHeight: Math.max(100, templateTextArea.contentHeight + 20)
+            Layout.minimumHeight: Math.max(100, promptTemplateTextArea.contentHeight + 20)
            color: "transparent"
            clip: true
            MyTextArea {
-                id: templateTextArea
+                id: promptTemplateTextArea
                anchors.fill: parent
                text: root.currentModelInfo.promptTemplate
                Connections {
                    target: MySettings
                    function onPromptTemplateChanged() {
-                        templateTextArea.text = root.currentModelInfo.promptTemplate;
+                        promptTemplateTextArea.text = root.currentModelInfo.promptTemplate;
                    }
                }
                Connections {
                    target: root
                    function onCurrentModelInfoChanged() {
-                        templateTextArea.text = root.currentModelInfo.promptTemplate;
+                        promptTemplateTextArea.text = root.currentModelInfo.promptTemplate;
                    }
                }
                onTextChanged: {
-                    if (templateTextArea.text.indexOf("%1") !== -1) {
+                    if (promptTemplateTextArea.text.indexOf("%1") !== -1) {
                        MySettings.setModelPromptTemplate(root.currentModelInfo, text)
                    }
                }
@@ -250,18 +250,64 @@ MySettingsTab {
            }
        }
 
+        MySettingsLabel {
+            Layout.row: 11
+            Layout.column: 0
+            Layout.columnSpan: 2
+            Layout.topMargin: 15
+            id: toolTemplateLabel
+            text: qsTr("Tool Template")
+            helpText: qsTr("The template that allows tool calls to inject information into the context.")
+        }
+
+        Rectangle {
+            id: toolTemplate
+            Layout.row: 12
+            Layout.column: 0
+            Layout.columnSpan: 2
+            Layout.fillWidth: true
+            Layout.minimumHeight: Math.max(100, toolTemplateTextArea.contentHeight + 20)
+            color: "transparent"
+            clip: true
+            MyTextArea {
+                id: toolTemplateTextArea
+                anchors.fill: parent
+                text: root.currentModelInfo.toolTemplate
+                Connections {
+                    target: MySettings
+                    function onToolTemplateChanged() {
+                        toolTemplateTextArea.text = root.currentModelInfo.toolTemplate;
+                    }
+                }
+                Connections {
+                    target: root
+                    function onCurrentModelInfoChanged() {
+                        toolTemplateTextArea.text = root.currentModelInfo.toolTemplate;
+                    }
+                }
+                onTextChanged: {
+                    if (toolTemplateTextArea.text.indexOf("%1") !== -1) {
+                        MySettings.setModelToolTemplate(root.currentModelInfo, text)
+                    }
+                }
+                Accessible.role: Accessible.EditableText
+                Accessible.name: toolTemplateLabel.text
+                Accessible.description: toolTemplateLabel.text
+            }
+        }
+
         MySettingsLabel {
            id: chatNamePromptLabel
            text: qsTr("Chat Name Prompt")
            helpText: qsTr("Prompt used to automatically generate chat names.")
-            Layout.row: 11
+            Layout.row: 13
            Layout.column: 0
            Layout.topMargin: 15
        }
 
         Rectangle {
            id: chatNamePrompt
-            Layout.row: 12
+            Layout.row: 14
            Layout.column: 0
            Layout.columnSpan: 2
            Layout.fillWidth: true
@@ -297,14 +343,14 @@ MySettingsTab {
            id: suggestedFollowUpPromptLabel
            text: qsTr("Suggested FollowUp Prompt")
            helpText: qsTr("Prompt used to generate suggested follow-up questions.")
-            Layout.row: 13
+            Layout.row: 15
            Layout.column: 0
            Layout.topMargin: 15
        }
 
         Rectangle {
            id: suggestedFollowUpPrompt
-            Layout.row: 14
+            Layout.row: 16
            Layout.column: 0
            Layout.columnSpan: 2
            Layout.fillWidth: true
@@ -337,7 +383,7 @@ MySettingsTab {
        }
 
         GridLayout {
-            Layout.row: 15
+            Layout.row: 17
            Layout.column: 0
            Layout.columnSpan: 2
            Layout.topMargin: 15
@@ -833,7 +879,7 @@ MySettingsTab {
        }
 
         Rectangle {
-            Layout.row: 16
+            Layout.row: 18
            Layout.column: 0
            Layout.columnSpan: 2
            Layout.topMargin: 15