mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-04 02:58:04 +00:00
Refactor to handle errors in tool calling better and add source comments.
Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
parent
7c7558eed3
commit
fffd9f341a
@ -815,11 +815,17 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
|||||||
m_ctx.n_predict = old_n_predict; // now we are ready for a response
|
m_ctx.n_predict = old_n_predict; // now we are ready for a response
|
||||||
}
|
}
|
||||||
|
|
||||||
m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now
|
// We can't handle recursive tool calls right now otherwise we always try to check if we have a
|
||||||
|
// tool call
|
||||||
|
m_checkToolCall = !isToolCallResponse;
|
||||||
|
|
||||||
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
|
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
|
||||||
/*allowContextShift*/ true, m_ctx);
|
/*allowContextShift*/ true, m_ctx);
|
||||||
|
|
||||||
|
// After the response has been handled reset this state
|
||||||
m_checkToolCall = false;
|
m_checkToolCall = false;
|
||||||
m_maybeToolCall = false;
|
m_maybeToolCall = false;
|
||||||
|
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
printf("\n");
|
printf("\n");
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
@ -827,15 +833,66 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
|||||||
m_timer->stop();
|
m_timer->stop();
|
||||||
qint64 elapsed = totalTime.elapsed();
|
qint64 elapsed = totalTime.elapsed();
|
||||||
std::string trimmed = trim_whitespace(m_response);
|
std::string trimmed = trim_whitespace(m_response);
|
||||||
|
|
||||||
|
// If we found a tool call, then deal with it
|
||||||
if (m_foundToolCall) {
|
if (m_foundToolCall) {
|
||||||
m_foundToolCall = false;
|
m_foundToolCall = false;
|
||||||
|
|
||||||
|
const QString toolCall = QString::fromStdString(trimmed);
|
||||||
|
const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
|
||||||
|
if (toolTemplate.isEmpty()) {
|
||||||
|
qWarning() << "ERROR: No valid tool template for this model" << toolCall;
|
||||||
|
return handleFailedToolCall(trimmed, elapsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
QJsonParseError err;
|
||||||
|
const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
|
||||||
|
|
||||||
|
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
|
||||||
|
qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
|
||||||
|
return handleFailedToolCall(trimmed, elapsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
QJsonObject rootObject = toolCallDoc.object();
|
||||||
|
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
|
||||||
|
qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
|
||||||
|
return handleFailedToolCall(trimmed, elapsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
const QString tool = toolCallDoc["name"].toString();
|
||||||
|
const QJsonObject args = toolCallDoc["arguments"].toObject();
|
||||||
|
|
||||||
|
// FIXME: In the future this will try to match the tool call to a list of tools that are supported
|
||||||
|
// according to MySettings, but for now only brave search is supported
|
||||||
|
if (tool != "brave_search" || !args.contains("query")) {
|
||||||
|
qWarning() << "ERROR: Could not find the tool and correct arguments for " << toolCall;
|
||||||
|
return handleFailedToolCall(trimmed, elapsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
const QString query = args["query"].toString();
|
||||||
|
|
||||||
|
// FIXME: This has to handle errors of the tool call
|
||||||
|
emit toolCalled(tr("searching web..."));
|
||||||
|
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
|
||||||
|
Q_ASSERT(apiKey != "");
|
||||||
|
BraveSearch brave;
|
||||||
|
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/,
|
||||||
|
2000 /*msecs to timeout*/);
|
||||||
|
emit sourceExcerptsChanged(braveResponse.second);
|
||||||
|
|
||||||
|
// Erase the context of the tool call
|
||||||
m_ctx.n_past = std::max(0, m_ctx.n_past);
|
m_ctx.n_past = std::max(0, m_ctx.n_past);
|
||||||
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
|
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
|
||||||
m_promptResponseTokens = 0;
|
m_promptResponseTokens = 0;
|
||||||
m_promptTokens = 0;
|
m_promptTokens = 0;
|
||||||
m_response = std::string();
|
m_response = std::string();
|
||||||
return toolCallInternal(QString::fromStdString(trimmed), n_predict, top_k, top_p, min_p, temp,
|
|
||||||
n_batch, repeat_penalty, repeat_penalty_tokens);
|
// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
|
||||||
|
// tool calls
|
||||||
|
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
|
||||||
|
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
|
||||||
|
true /*isToolCallResponse*/);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if (trimmed != m_response) {
|
if (trimmed != m_response) {
|
||||||
m_response = trimmed;
|
m_response = trimmed;
|
||||||
@ -847,65 +904,19 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
|||||||
generateQuestions(elapsed);
|
generateQuestions(elapsed);
|
||||||
else
|
else
|
||||||
emit responseStopped(elapsed);
|
emit responseStopped(elapsed);
|
||||||
|
m_pristineLoadedState = false;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
m_pristineLoadedState = false;
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p,
|
bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed)
|
||||||
float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
|
|
||||||
{
|
{
|
||||||
QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
|
// Restore the strings that we excluded previously when detecting the tool call
|
||||||
if (toolTemplate.isEmpty()) {
|
m_response = "<tool_call>" + response + "</tool_call>";
|
||||||
// FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for
|
emit responseChanged(QString::fromStdString(m_response));
|
||||||
// us to process it. We should probably not even attempt further generation and just show an
|
emit responseStopped(elapsed);
|
||||||
// error in the chat somehow?
|
m_pristineLoadedState = false;
|
||||||
qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall;
|
return true;
|
||||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/,
|
|
||||||
MySettings::globalInstance()->modelPromptTemplate(m_modelInfo),
|
|
||||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
|
||||||
}
|
|
||||||
|
|
||||||
QJsonParseError err;
|
|
||||||
QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
|
|
||||||
|
|
||||||
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
|
|
||||||
qWarning() << "WARNING: The tool call had null or invalid json " << toolCall;
|
|
||||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
|
|
||||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
|
||||||
}
|
|
||||||
|
|
||||||
QJsonObject rootObject = toolCallDoc.object();
|
|
||||||
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
|
|
||||||
qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall;
|
|
||||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
|
|
||||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
|
||||||
}
|
|
||||||
|
|
||||||
const QString tool = toolCallDoc["name"].toString();
|
|
||||||
const QJsonObject args = toolCallDoc["arguments"].toObject();
|
|
||||||
|
|
||||||
if (tool != "brave_search" || !args.contains("query")) {
|
|
||||||
qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall;
|
|
||||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
|
|
||||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
|
||||||
}
|
|
||||||
|
|
||||||
const QString query = args["query"].toString();
|
|
||||||
|
|
||||||
emit toolCalled(tr("searching web..."));
|
|
||||||
|
|
||||||
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
|
|
||||||
Q_ASSERT(apiKey != "");
|
|
||||||
|
|
||||||
BraveSearch brave;
|
|
||||||
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/, 2000 /*msecs to timeout*/);
|
|
||||||
|
|
||||||
emit sourceExcerptsChanged(braveResponse.second);
|
|
||||||
|
|
||||||
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
|
|
||||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatLLM::setShouldBeLoaded(bool b)
|
void ChatLLM::setShouldBeLoaded(bool b)
|
||||||
|
@ -200,8 +200,7 @@ protected:
|
|||||||
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
||||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||||
int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
|
int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
|
||||||
bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
|
||||||
int32_t repeat_penalty_tokens);
|
|
||||||
bool handlePrompt(int32_t token);
|
bool handlePrompt(int32_t token);
|
||||||
bool handleResponse(int32_t token, const std::string &response);
|
bool handleResponse(int32_t token, const std::string &response);
|
||||||
bool handleNamePrompt(int32_t token);
|
bool handleNamePrompt(int32_t token);
|
||||||
|
Loading…
Reference in New Issue
Block a user