Stop hardcoding the tool call check and rely on the tool-calling format advocated by Ollama.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Adam Treat 2024-07-30 10:30:56 -04:00
parent b9684ff741
commit 7c7558eed3
7 changed files with 162 additions and 57 deletions
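
For context on the diff below: instead of watching for the Llama 3.1-specific <|python_tag|> and <|eom_id|> tokens, the response is now scanned for a generic <tool_call>...</tool_call> block whose body is JSON with "name" and "arguments" fields, which is the format Ollama advocates for tool calling. A minimal standalone sketch of validating and dispatching such a payload (not part of the commit; it only assumes the Qt JSON classes that chatllm.cpp already uses, and the example query string is made up):

#include <QJsonDocument>
#include <QJsonObject>
#include <QString>
#include <QDebug>

int main()
{
    // Example body of a <tool_call>...</tool_call> block, with the tags already stripped.
    const QString toolCall = R"({"name": "brave_search", "arguments": {"query": "GPT4All"}})";

    QJsonParseError err;
    const QJsonDocument doc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
    if (doc.isNull() || err.error != QJsonParseError::NoError || !doc.isObject()) {
        qWarning() << "invalid tool call json" << toolCall;
        return 1;
    }

    const QJsonObject root = doc.object();
    if (!root.contains("name") || !root.contains("arguments")) {
        qWarning() << "tool call missing name/arguments" << toolCall;
        return 1;
    }

    const QString tool = root["name"].toString();
    const QJsonObject args = root["arguments"].toObject();
    if (tool == "brave_search" && args.contains("query"))
        qInfo() << "would call brave_search with query:" << args["query"].toString(); // dispatch point
    return 0;
}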

View File

@@ -115,7 +115,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
, m_checkToolCall(false)
, m_maybeToolCall(false)
, m_foundToolCall(false)
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
@@ -705,34 +707,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
return false;
}
// Only valid for llama 3.1 instruct
if (m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct")) {
// Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling
// For brave_search and wolfram_alpha ipython is always used
// <|python_tag|>
// brave_search.call(query="...")
// <|eom_id|>
const int eom_id = 128008;
const int python_tag = 128010;
// If we have a built-in tool call, then it should be the first token
const bool isFirstResponseToken = m_promptResponseTokens == m_promptTokens;
Q_ASSERT(token != python_tag || isFirstResponseToken);
if (isFirstResponseToken && token == python_tag) {
m_maybeToolCall = true;
++m_promptResponseTokens;
return !m_stopGenerating;
}
// Check for end of built-in tool call
Q_ASSERT(token != eom_id || !m_maybeToolCall);
if (token == eom_id) {
++m_promptResponseTokens;
return false;
}
}
// m_promptResponseTokens is related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
@@ -740,7 +714,25 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
Q_ASSERT(!response.empty());
m_response.append(response);
if (!m_maybeToolCall)
// If we're checking for a tool call and the response has reached at least 11 chars,
// then check whether it starts with the opening tag
if (m_checkToolCall && m_response.size() >= 11) {
if (m_response.starts_with("<tool_call>")) {
m_maybeToolCall = true;
m_response.erase(0, 11);
}
m_checkToolCall = false;
}
// Check if we're at the end of a tool call and erase the closing tag
if (m_maybeToolCall && m_response.ends_with("</tool_call>")) {
m_foundToolCall = true;
m_response.erase(m_response.length() - 12);
return false;
}
// If we're not checking for a tool call and haven't detected one, then send along the response
if (!m_checkToolCall && !m_maybeToolCall)
emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
return !m_stopGenerating;
@@ -822,8 +814,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
/*allowContextShift*/ true, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
m_checkToolCall = false;
m_maybeToolCall = false;
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -831,8 +827,8 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_timer->stop();
qint64 elapsed = totalTime.elapsed();
std::string trimmed = trim_whitespace(m_response);
if (m_maybeToolCall) {
m_maybeToolCall = false;
if (m_foundToolCall) {
m_foundToolCall = false;
m_ctx.n_past = std::max(0, m_ctx.n_past);
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
@@ -860,25 +856,46 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p,
float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{
Q_ASSERT(m_modelInfo.filename().startsWith("Meta-Llama-3.1-8B-Instruct"));
emit toolCalled(tr("searching web..."));
QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
if (toolTemplate.isEmpty()) {
// FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for
// us to process it. We should probably not even attempt further generation and just show an
// error in the chat somehow?
qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/,
MySettings::globalInstance()->modelPromptTemplate(m_modelInfo),
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
}
// Based on https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling
// For brave_search and wolfram_alpha ipython is always used
QJsonParseError err;
QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))");
QRegularExpressionMatch match = re.match(toolCall);
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
qWarning() << "WARNING: The tool call had null or invalid json " << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
}
QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
QString query;
if (match.hasMatch()) {
query = match.captured(1);
} else {
qWarning() << "WARNING: Could not find the tool for " << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, promptTemplate,
QJsonObject rootObject = toolCallDoc.object();
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
}
const QString tool = toolCallDoc["name"].toString();
const QJsonObject args = toolCallDoc["arguments"].toObject();
if (tool != "brave_search" || !args.contains("query")) {
qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall;
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
}
const QString query = args["query"].toString();
emit toolCalled(tr("searching web..."));
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
Q_ASSERT(apiKey != "");
@@ -887,7 +904,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
emit sourceExcerptsChanged(braveResponse.second);
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, promptTemplate,
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
}
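
The detection added to handleResponse above boils down to a small state machine over the three new flags. A standalone sketch, not taken from the commit (it assumes C++20 for std::string::starts_with/ends_with; the 11 and 12 character offsets are the lengths of "<tool_call>" and "</tool_call>"):

#include <string>
#include <vector>
#include <iostream>

// Standalone sketch of the streaming tag detection added to ChatLLM::handleResponse.
// Member names mirror the new flags; the logic follows the diff above.
struct ToolCallDetector {
    bool checkToolCall = true;   // like m_checkToolCall: armed once per prompt
    bool maybeToolCall = false;  // like m_maybeToolCall: response began with <tool_call>
    bool foundToolCall = false;  // like m_foundToolCall: closing tag seen
    std::string response;

    // Returns false when generation should stop (a complete tool call was captured).
    bool onResponse(const std::string &piece) {
        response += piece;
        // Once at least 11 chars have arrived, decide whether the response starts
        // with the 11-char opening tag "<tool_call>" and strip it if so.
        if (checkToolCall && response.size() >= 11) {
            if (response.starts_with("<tool_call>")) {
                maybeToolCall = true;
                response.erase(0, 11);
            }
            checkToolCall = false;
        }
        // Strip the 12-char closing tag "</tool_call>" and stop generating.
        if (maybeToolCall && response.ends_with("</tool_call>")) {
            foundToolCall = true;
            response.erase(response.length() - 12);
            return false;
        }
        return true;
    }
};

int main()
{
    ToolCallDetector d;
    const std::vector<std::string> pieces = {
        "<tool_", "call>{\"name\":\"brave_search\",",
        "\"arguments\":{\"query\":\"GPT4All\"}}</tool_call>"
    };
    for (const std::string &piece : pieces)
        if (!d.onResponse(piece)) break;
    std::cout << (d.foundToolCall ? d.response : "no tool call detected") << "\n";
}

The same ordering applies in the real handler: the opening tag can only be decided once the first 11 characters exist, so m_checkToolCall is disarmed after that first check regardless of the outcome.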

View File

@@ -242,7 +242,9 @@ private:
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
bool m_checkToolCall;
bool m_maybeToolCall;
bool m_foundToolCall;
// m_pristineLoadedState is set if saveState is unnecessary, either because:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet

View File

@@ -323,6 +323,17 @@ void ModelInfo::setPromptTemplate(const QString &t)
m_promptTemplate = t;
}
QString ModelInfo::toolTemplate() const
{
return MySettings::globalInstance()->modelToolTemplate(*this);
}
void ModelInfo::setToolTemplate(const QString &t)
{
if (shouldSaveMetadata()) MySettings::globalInstance()->setModelToolTemplate(*this, t, true /*force*/);
m_toolTemplate = t;
}
QString ModelInfo::systemPrompt() const
{
return MySettings::globalInstance()->modelSystemPrompt(*this);
@@ -385,6 +396,7 @@ QVariantMap ModelInfo::getFields() const
{ "repeatPenalty", m_repeatPenalty },
{ "repeatPenaltyTokens", m_repeatPenaltyTokens },
{ "promptTemplate", m_promptTemplate },
{ "toolTemplate", m_toolTemplate },
{ "systemPrompt", m_systemPrompt },
{ "chatNamePrompt", m_chatNamePrompt },
{ "suggestedFollowUpPrompt", m_suggestedFollowUpPrompt },
@@ -504,6 +516,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::toolTemplateChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::systemPromptChanged, this, &ModelList::updateDataForSettings);
connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, &ModelList::handleSslErrors);
@@ -776,6 +789,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->repeatPenaltyTokens();
case PromptTemplateRole:
return info->promptTemplate();
case ToolTemplateRole:
return info->toolTemplate();
case SystemPromptRole:
return info->systemPrompt();
case ChatNamePromptRole:
@@ -952,6 +967,8 @@ void ModelList::updateData(const QString &id, const QVector<QPair<int, QVariant>
info->setRepeatPenaltyTokens(value.toInt()); break;
case PromptTemplateRole:
info->setPromptTemplate(value.toString()); break;
case ToolTemplateRole:
info->setToolTemplate(value.toString()); break;
case SystemPromptRole:
info->setSystemPrompt(value.toString()); break;
case ChatNamePromptRole:
@@ -1107,6 +1124,7 @@ QString ModelList::clone(const ModelInfo &model)
{ ModelList::RepeatPenaltyRole, model.repeatPenalty() },
{ ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens() },
{ ModelList::PromptTemplateRole, model.promptTemplate() },
{ ModelList::ToolTemplateRole, model.toolTemplate() },
{ ModelList::SystemPromptRole, model.systemPrompt() },
{ ModelList::ChatNamePromptRole, model.chatNamePrompt() },
{ ModelList::SuggestedFollowUpPromptRole, model.suggestedFollowUpPrompt() },
@@ -1551,6 +1569,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
data.append({ ModelList::RepeatPenaltyTokensRole, obj["repeatPenaltyTokens"].toInt() });
if (obj.contains("promptTemplate"))
data.append({ ModelList::PromptTemplateRole, obj["promptTemplate"].toString() });
if (obj.contains("toolTemplate"))
data.append({ ModelList::ToolTemplateRole, obj["toolTemplate"].toString() });
if (obj.contains("systemPrompt"))
data.append({ ModelList::SystemPromptRole, obj["systemPrompt"].toString() });
updateData(id, data);
@@ -1852,6 +1872,10 @@ void ModelList::updateModelsFromSettings()
const QString promptTemplate = settings.value(g + "/promptTemplate").toString();
data.append({ ModelList::PromptTemplateRole, promptTemplate });
}
if (settings.contains(g + "/toolTemplate")) {
const QString toolTemplate = settings.value(g + "/toolTemplate").toString();
data.append({ ModelList::ToolTemplateRole, toolTemplate });
}
if (settings.contains(g + "/systemPrompt")) {
const QString systemPrompt = settings.value(g + "/systemPrompt").toString();
data.append({ ModelList::SystemPromptRole, systemPrompt });

View File

@@ -68,6 +68,7 @@ struct ModelInfo {
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
Q_PROPERTY(QString toolTemplate READ toolTemplate WRITE setToolTemplate)
Q_PROPERTY(QString systemPrompt READ systemPrompt WRITE setSystemPrompt)
Q_PROPERTY(QString chatNamePrompt READ chatNamePrompt WRITE setChatNamePrompt)
Q_PROPERTY(QString suggestedFollowUpPrompt READ suggestedFollowUpPrompt WRITE setSuggestedFollowUpPrompt)
@@ -178,6 +179,8 @@ public:
void setRepeatPenaltyTokens(int t);
QString promptTemplate() const;
void setPromptTemplate(const QString &t);
QString toolTemplate() const;
void setToolTemplate(const QString &t);
QString systemPrompt() const;
void setSystemPrompt(const QString &p);
QString chatNamePrompt() const;
@@ -215,6 +218,7 @@ private:
double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n";
QString m_toolTemplate = "";
QString m_systemPrompt = "### System:\nYou are an AI assistant who gives a quality response to whatever humans ask of you.\n\n";
QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
@@ -339,6 +343,7 @@ public:
RepeatPenaltyRole,
RepeatPenaltyTokensRole,
PromptTemplateRole,
ToolTemplateRole,
SystemPromptRole,
ChatNamePromptRole,
SuggestedFollowUpPromptRole,
@@ -393,6 +398,7 @@ public:
roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate";
roles[ToolTemplateRole] = "toolTemplate";
roles[SystemPromptRole] = "systemPrompt";
roles[ChatNamePromptRole] = "chatNamePrompt";
roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt";

View File

@@ -194,6 +194,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &info)
setModelRepeatPenalty(info, info.m_repeatPenalty);
setModelRepeatPenaltyTokens(info, info.m_repeatPenaltyTokens);
setModelPromptTemplate(info, info.m_promptTemplate);
setModelToolTemplate(info, info.m_toolTemplate);
setModelSystemPrompt(info, info.m_systemPrompt);
setModelChatNamePrompt(info, info.m_chatNamePrompt);
setModelSuggestedFollowUpPrompt(info, info.m_suggestedFollowUpPrompt);
@@ -296,6 +297,7 @@ int MySettings::modelGpuLayers (const ModelInfo &info) const
double MySettings::modelRepeatPenalty (const ModelInfo &info) const { return getModelSetting("repeatPenalty", info).toDouble(); }
int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const { return getModelSetting("repeatPenaltyTokens", info).toInt(); }
QString MySettings::modelPromptTemplate (const ModelInfo &info) const { return getModelSetting("promptTemplate", info).toString(); }
QString MySettings::modelToolTemplate (const ModelInfo &info) const { return getModelSetting("toolTemplate", info).toString(); }
QString MySettings::modelSystemPrompt (const ModelInfo &info) const { return getModelSetting("systemPrompt", info).toString(); }
QString MySettings::modelChatNamePrompt (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
@@ -405,6 +407,11 @@ void MySettings::setModelPromptTemplate(const ModelInfo &info, const QString &va
setModelSetting("promptTemplate", info, value, force, true);
}
void MySettings::setModelToolTemplate(const ModelInfo &info, const QString &value, bool force)
{
setModelSetting("toolTemplate", info, value, force, true);
}
void MySettings::setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force)
{
setModelSetting("systemPrompt", info, value, force, true);

View File

@@ -126,6 +126,8 @@ public:
Q_INVOKABLE void setModelRepeatPenaltyTokens(const ModelInfo &info, int value, bool force = false);
QString modelPromptTemplate(const ModelInfo &info) const;
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &info, const QString &value, bool force = false);
QString modelToolTemplate(const ModelInfo &info) const;
Q_INVOKABLE void setModelToolTemplate(const ModelInfo &info, const QString &value, bool force = false);
QString modelSystemPrompt(const ModelInfo &info) const;
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &info, const QString &value, bool force = false);
int modelContextLength(const ModelInfo &info) const;
@@ -217,6 +219,7 @@ Q_SIGNALS:
void repeatPenaltyChanged(const ModelInfo &info);
void repeatPenaltyTokensChanged(const ModelInfo &info);
void promptTemplateChanged(const ModelInfo &info);
void toolTemplateChanged(const ModelInfo &info);
void systemPromptChanged(const ModelInfo &info);
void chatNamePromptChanged(const ModelInfo &info);
void suggestedFollowUpPromptChanged(const ModelInfo &info);
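
Illustration only, not shipped by this commit: the new per-model toolTemplate setting defaults to an empty string, but for a Llama 3.1 style model a value could mirror the hardcoded ipython template that toolCallInternal previously used. The %1/%2 placeholder semantics below are an assumption based on how promptInternal feeds braveResponse.first through the template (%1 = tool output, %2 = where the assistant continues):

#include <QString>
#include <QTextStream>

int main()
{
    // Hypothetical toolTemplate value, modeled on the hardcoded ipython template
    // removed from toolCallInternal; not a default provided by this commit.
    const QString exampleToolTemplate =
        "<|start_header_id|>ipython<|end_header_id|>\n\n"
        "%1<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n%2";

    // Simulate how the template would be filled: %1 with the tool's JSON output,
    // %2 left empty for the model to continue from.
    QTextStream out(stdout);
    out << exampleToolTemplate.arg(R"({"results": ["..."]})", QString()) << "\n";
    return 0;
}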

View File

@@ -209,7 +209,7 @@ MySettingsTab {
id: promptTemplateLabelHelp
text: qsTr("Must contain the string \"%1\" to be replaced with the user's input.")
color: theme.textErrorColor
visible: templateTextArea.text.indexOf("%1") === -1
visible: promptTemplateTextArea.text.indexOf("%1") === -1
wrapMode: TextArea.Wrap
}
}
@@ -220,27 +220,27 @@ MySettingsTab {
Layout.column: 0
Layout.columnSpan: 2
Layout.fillWidth: true
Layout.minimumHeight: Math.max(100, templateTextArea.contentHeight + 20)
Layout.minimumHeight: Math.max(100, promptTemplateTextArea.contentHeight + 20)
color: "transparent"
clip: true
MyTextArea {
id: templateTextArea
id: promptTemplateTextArea
anchors.fill: parent
text: root.currentModelInfo.promptTemplate
Connections {
target: MySettings
function onPromptTemplateChanged() {
templateTextArea.text = root.currentModelInfo.promptTemplate;
promptTemplateTextArea.text = root.currentModelInfo.promptTemplate;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
templateTextArea.text = root.currentModelInfo.promptTemplate;
promptTemplateTextArea.text = root.currentModelInfo.promptTemplate;
}
}
onTextChanged: {
if (templateTextArea.text.indexOf("%1") !== -1) {
if (promptTemplateTextArea.text.indexOf("%1") !== -1) {
MySettings.setModelPromptTemplate(root.currentModelInfo, text)
}
}
@@ -250,18 +250,64 @@ MySettingsTab {
}
}
MySettingsLabel {
Layout.row: 11
Layout.column: 0
Layout.columnSpan: 2
Layout.topMargin: 15
id: toolTemplateLabel
text: qsTr("Tool Template")
helpText: qsTr("The template that allows tool calls to inject information into the context.")
}
Rectangle {
id: toolTemplate
Layout.row: 12
Layout.column: 0
Layout.columnSpan: 2
Layout.fillWidth: true
Layout.minimumHeight: Math.max(100, toolTemplateTextArea.contentHeight + 20)
color: "transparent"
clip: true
MyTextArea {
id: toolTemplateTextArea
anchors.fill: parent
text: root.currentModelInfo.toolTemplate
Connections {
target: MySettings
function onToolTemplateChanged() {
toolTemplateTextArea.text = root.currentModelInfo.toolTemplate;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
toolTemplateTextArea.text = root.currentModelInfo.toolTemplate;
}
}
onTextChanged: {
if (toolTemplateTextArea.text.indexOf("%1") !== -1) {
MySettings.setModelToolTemplate(root.currentModelInfo, text)
}
}
Accessible.role: Accessible.EditableText
Accessible.name: toolTemplateLabel.text
Accessible.description: toolTemplateLabel.text
}
}
MySettingsLabel {
id: chatNamePromptLabel
text: qsTr("Chat Name Prompt")
helpText: qsTr("Prompt used to automatically generate chat names.")
Layout.row: 11
Layout.row: 13
Layout.column: 0
Layout.topMargin: 15
}
Rectangle {
id: chatNamePrompt
Layout.row: 12
Layout.row: 14
Layout.column: 0
Layout.columnSpan: 2
Layout.fillWidth: true
@@ -297,14 +343,14 @@ MySettingsTab {
id: suggestedFollowUpPromptLabel
text: qsTr("Suggested FollowUp Prompt")
helpText: qsTr("Prompt used to generate suggested follow-up questions.")
Layout.row: 13
Layout.row: 15
Layout.column: 0
Layout.topMargin: 15
}
Rectangle {
id: suggestedFollowUpPrompt
Layout.row: 14
Layout.row: 16
Layout.column: 0
Layout.columnSpan: 2
Layout.fillWidth: true
@@ -337,7 +383,7 @@ MySettingsTab {
}
GridLayout {
Layout.row: 15
Layout.row: 17
Layout.column: 0
Layout.columnSpan: 2
Layout.topMargin: 15
@@ -833,7 +879,7 @@ MySettingsTab {
}
Rectangle {
Layout.row: 16
Layout.row: 18
Layout.column: 0
Layout.columnSpan: 2
Layout.topMargin: 15