Handle the forced usage of tool calls outside of the recursive prompt method.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Adam Treat 2024-08-14 09:38:53 -04:00
parent f118720717
commit 75dbf9de7d
2 changed files with 52 additions and 33 deletions
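In short, the old promptInternal both injected context and ran the tool-call-aware generation, using an isToolCallResponse flag to stop a tool response from triggering another tool call. After this commit, promptInternal only gathers LocalDocs context and does the post-generation suggestion handling, while the new private promptRecursive owns generation, tool-call detection, and the recursion guard. Below is a compilable toy sketch of that shape; ToyChatLLM and every function body are illustrative stand-ins, and only the method names, parameter roles, and guard logic come from the diff that follows.

// Toy mirror of the commit's structure: promptInternal collects context and
// reports results, promptRecursive generates and may recurse exactly once.
#include <QDebug>
#include <QList>
#include <QString>

class ToyChatLLM {
public:
    bool promptInternal(const QList<QString> &collectionList, const QString &prompt)
    {
        QString docsContext;
        if (!collectionList.isEmpty())
            docsContext = "### Context:\n(localdocs excerpts)"; // stand-in for LocalDocsSearch

        qint64 totalTime = 0;
        bool producedSourceExcerpts = false;
        bool success = promptRecursive({ docsContext }, prompt, totalTime, producedSourceExcerpts);

        // Suggestion-mode handling now happens once, outside the recursion.
        qDebug() << "done in" << totalTime << "ms, excerpts:" << producedSourceExcerpts;
        return success;
    }

private:
    bool promptRecursive(const QList<QString> &toolContexts, const QString &prompt,
        qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false)
    {
        for (const QString &context : toolContexts)
            if (!context.isEmpty())
                qDebug() << "decoding context (no response):" << context;

        // Only look for a tool call on the first pass, so a tool response can
        // never trigger another tool call (no unbounded recursion).
        bool checkToolCall = !isRecursiveCall;

        qDebug() << "generating for:" << prompt;
        totalTime += 5; // pretend this generation pass took 5 ms
        if (checkToolCall) {
            producedSourceExcerpts = true; // pretend web_search returned sources
            return promptRecursive({}, "(tool results)", totalTime, producedSourceExcerpts,
                /*isRecursiveCall*/ true);
        }
        return true;
    }
};

int main()
{
    ToyChatLLM llm;
    return llm.promptInternal({ "my-collection" }, "What is new in Qt?") ? 0 : 1;
}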


@@ -759,16 +759,13 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
 }
 
 bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
-    int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-    int32_t repeat_penalty_tokens, bool isToolCallResponse)
+    int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
+    int32_t repeat_penalty_tokens)
 {
     if (!isModelLoaded())
         return false;
 
-    // FIXME: This should be made agnostic to localdocs and rely upon the force usage usage mode
-    // and also we have to honor the ask before running mode.
+    // FIXME: The only localdocs specific thing here should be the injection of the parameters
+    // FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here.
     QList<SourceExcerpt> localDocsExcerpts;
-    if (!collectionList.isEmpty() && !isToolCallResponse) {
+    if (!collectionList.isEmpty()) {
         LocalDocsSearch localdocs;
         QJsonObject parameters;
         parameters.insert("text", prompt);
@@ -795,6 +792,27 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         docsContext = u"### Context:\n%1\n\n"_s.arg(json);
     }
 
+    qint64 totalTime = 0;
+    bool producedSourceExcerpts;
+    bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p,
+        min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts);
+
+    SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
+    if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts)))
+        generateQuestions(totalTime);
+    else
+        emit responseStopped(totalTime);
+
+    return success;
+}
+
+bool ChatLLM::promptRecursive(const QList<QString> &toolContexts, const QString &prompt,
+    const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp,
+    int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall)
+{
+    if (!isModelLoaded())
+        return false;
+
     int n_threads = MySettings::globalInstance()->threadCount();
     m_stopGenerating = false;
@@ -815,19 +833,22 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     printf("%s", qPrintable(prompt));
     fflush(stdout);
 #endif
 
-    QElapsedTimer totalTime;
-    totalTime.start();
+    QElapsedTimer elapsedTimer;
+    elapsedTimer.start();
     m_timer->start();
-    if (!docsContext.isEmpty()) {
-        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
-        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
+
+    // The list of possible additional contexts that come from previous usage of tool calls
+    for (const QString &context : toolContexts) {
+        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response
+        m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc,
                                     /*allowContextShift*/ true, m_ctx);
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }
-    // We can't handle recursive tool calls right now otherwise we always try to check if we have a
-    // tool call
-    m_checkToolCall = !isToolCallResponse;
+
+    // We can't handle recursive tool calls right now due to the possibility of the model causing
+    // infinite recursion through repeated tool calls
+    m_checkToolCall = !isRecursiveCall;
+
     m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
                                 /*allowContextShift*/ true, m_ctx);
@@ -841,7 +862,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     fflush(stdout);
 #endif
 
     m_timer->stop();
-    qint64 elapsed = totalTime.elapsed();
+    totalTime = elapsedTimer.elapsed();
     std::string trimmed = trim_whitespace(m_response);
 
     // If we found a tool call, then deal with it
@@ -852,7 +873,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
         if (toolTemplate.isEmpty()) {
             qWarning() << "ERROR: No valid tool template for this model" << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         QJsonParseError err;
@@ -860,13 +881,13 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
             qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         QJsonObject rootObject = toolCallDoc.object();
         if (!rootObject.contains("name") || !rootObject.contains("parameters")) {
             qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         const QString tool = toolCallDoc["name"].toString();
@@ -877,7 +898,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (tool != "web_search" || !args.contains("query")) {
             // FIXME: Need to surface errors to the UI
             qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }
 
         const QString query = args["query"].toString();
@@ -900,6 +921,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (!parseError.isEmpty()) {
             qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
         } else if (!sourceExcerpts.isEmpty()) {
+            producedSourceExcerpts = true;
             emit sourceExcerptsChanged(sourceExcerpts);
         }
@@ -907,23 +929,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         m_promptTokens = 0;
         m_response = std::string();
 
-        // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
+        // This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive
         // tool calls
-        return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
-            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
-            true /*isToolCallResponse*/);
+        return promptRecursive(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
+            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime,
+            producedSourceExcerpts, true /*isRecursiveCall*/);
     } else {
         if (trimmed != m_response) {
             m_response = trimmed;
             emit responseChanged(QString::fromStdString(m_response));
         }
-
-        SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
-        if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
-            generateQuestions(elapsed);
-        else
-            emit responseStopped(elapsed);
-
         m_pristineLoadedState = false;
         return true;
     }
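Zooming in on the validation hunks above: a reply is treated as a tool call only if it parses as a JSON object carrying name and parameters, and for now only web_search with a query argument is accepted; everything else funnels into handleFailedToolCall. Here is a standalone sketch of that checking order, with an illustrative payload (the commit does not show what a model actually emits). The header declaration diff follows after it.

// Mirrors the parse-and-validate ladder from the diff; returning 1 stands in
// for handleFailedToolCall. Payload contents are made up for illustration.
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonParseError>
#include <QDebug>

int main()
{
    const QByteArray toolCall = R"({"name": "web_search", "parameters": {"query": "qt 6 release date"}})";

    QJsonParseError err;
    const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall, &err);
    if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject())
        return 1; // invalid or non-object JSON

    const QJsonObject rootObject = toolCallDoc.object();
    if (!rootObject.contains("name") || !rootObject.contains("parameters"))
        return 1; // both keys are required

    const QJsonObject args = rootObject["parameters"].toObject();
    if (rootObject["name"].toString() != "web_search" || !args.contains("query"))
        return 1; // only web_search with a query is recognized so far

    qDebug() << "query:" << args["query"].toString();
    return 0;
}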


@@ -196,9 +196,10 @@ Q_SIGNALS:
     void modelInfoChanged(const ModelInfo &modelInfo);
 
 protected:
     // FIXME: This is only available because of server which sucks
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-        int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
+        int32_t repeat_penalty_tokens);
+
     bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
@@ -219,6 +220,9 @@ protected:
     quint32 m_promptResponseTokens;
 
 private:
+    bool promptRecursive(const QList<QString> &toolContexts, const QString &prompt, const QString &promptTemplate,
+        int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false);
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
     std::string m_response;
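A note on the declaration just above: totalTime and producedSourceExcerpts are reference parameters precisely because promptRecursive calls itself for the tool response, and writing through references lets the outermost promptInternal see timing and excerpt state accumulated across every pass. A minimal illustration of that out-parameter pattern (the generateOnce helper is hypothetical, not from the commit):

// The recursive call writes through the caller's references, so the top-level
// caller observes totals from every level of recursion.
#include <iostream>

static bool generateOnce(int depth, long long &totalTime, bool &producedSourceExcerpts)
{
    totalTime += 5;                      // pretend this pass took 5 ms
    if (depth == 0) {                    // first pass: a "tool call" happened
        producedSourceExcerpts = true;
        return generateOnce(depth + 1, totalTime, producedSourceExcerpts);
    }
    return true;                         // recursive pass: just answer
}

int main()
{
    long long totalTime = 0;
    bool producedSourceExcerpts = false;
    generateOnce(0, totalTime, producedSourceExcerpts);
    std::cout << totalTime << " " << producedSourceExcerpts << "\n"; // prints: 10 1
}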