Handle the forced usage of tool calls outside of the recursive prompt method.
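
The gist of the change: promptInternal() now does the one-time setup (localdocs retrieval, suggestion handling) and delegates all model interaction to a new promptRecursive() method, which is the only code path allowed to re-enter itself after a tool call and which skips the tool-call check on that second pass so the model cannot recurse indefinitely. Below is a simplified, standalone sketch of that control flow; it is not the actual ChatLLM code, and the parameter lists and I/O are stand-ins.

    // Simplified model of the new call structure; the real methods live on ChatLLM
    // and take the full sampling-parameter list, which is omitted here.
    #include <iostream>
    #include <string>
    #include <vector>

    static bool promptRecursive(const std::vector<std::string> &toolContexts, const std::string &prompt,
                                long long &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false)
    {
        // Decode any additional contexts (localdocs excerpts, earlier tool results) without generating.
        for (const std::string &context : toolContexts)
            if (!context.empty())
                std::cout << "decoding context: " << context << "\n";

        // Only look for a tool call on the first pass; a tool-call response must not trigger another one.
        const bool checkToolCall = !isRecursiveCall;
        std::cout << "prompting model (checkToolCall=" << checkToolCall << "): " << prompt << "\n";

        // Stand-in for parsing the model's reply for a tool call.
        const bool modelEmittedToolCall = checkToolCall && prompt.rfind("search:", 0) == 0;
        if (modelEmittedToolCall) {
            producedSourceExcerpts = true; // e.g. the web_search tool returned source excerpts
            totalTime += 1;                // stand-in for the elapsed-timer bookkeeping
            // Recurse exactly once, feeding the tool result back in as the next prompt.
            return promptRecursive({}, "tool result", totalTime, producedSourceExcerpts, /*isRecursiveCall*/ true);
        }
        totalTime += 1;
        return true;
    }

    static bool promptInternal(const std::vector<std::string> &collectionList, const std::string &prompt)
    {
        // Gather the localdocs context up front; forced tool usage would be decided here too.
        std::string docsContext;
        if (!collectionList.empty())
            docsContext = "### Context: ...";

        long long totalTime = 0;
        bool producedSourceExcerpts = false;
        const bool success = promptRecursive({docsContext}, prompt, totalTime, producedSourceExcerpts);

        // Suggestion-mode handling and responseStopped now happen once, outside the recursion.
        std::cout << "finished, producedSourceExcerpts=" << producedSourceExcerpts << "\n";
        return success;
    }

    int main()
    {
        return promptInternal({"my collection"}, "search: latest Qt release") ? 0 : 1;
    }

The real methods additionally thread totalTime and producedSourceExcerpts out by reference so promptInternal() can choose between generateQuestions() and responseStopped() once the recursion has finished.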

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Adam Treat 2024-08-14 09:38:53 -04:00
parent f118720717
commit 75dbf9de7d
2 changed files with 52 additions and 33 deletions

@@ -760,15 +760,12 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
 bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
     int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-    int32_t repeat_penalty_tokens, bool isToolCallResponse)
+    int32_t repeat_penalty_tokens)
 {
-    if (!isModelLoaded())
-        return false;
+    // FIXME: The only localdocs specific thing here should be the injection of the parameters
+    // FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here.
+    // FIXME: This should be made agnostic to localdocs and rely upon the force usage usage mode
+    // and also we have to honor the ask before running mode.
     QList<SourceExcerpt> localDocsExcerpts;
-    if (!collectionList.isEmpty() && !isToolCallResponse) {
+    if (!collectionList.isEmpty()) {
         LocalDocsSearch localdocs;
         QJsonObject parameters;
         parameters.insert("text", prompt);
@@ -795,6 +792,27 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         docsContext = u"### Context:\n%1\n\n"_s.arg(json);
     }

+    qint64 totalTime = 0;
+    bool producedSourceExcerpts;
+    bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p,
+        min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts);
+
+    SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
+    if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts)))
+        generateQuestions(totalTime);
+    else
+        emit responseStopped(totalTime);
+
+    return success;
+}
+
+bool ChatLLM::promptRecursive(const QList<QString> &toolContexts, const QString &prompt,
+    const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp,
+    int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall)
+{
+    if (!isModelLoaded())
+        return false;
+
     int n_threads = MySettings::globalInstance()->threadCount();

     m_stopGenerating = false;
@@ -815,19 +833,22 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     printf("%s", qPrintable(prompt));
     fflush(stdout);
 #endif

-    QElapsedTimer totalTime;
-    totalTime.start();
+    QElapsedTimer elapsedTimer;
+    elapsedTimer.start();
     m_timer->start();
-    if (!docsContext.isEmpty()) {
-        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
-        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
+
+    // The list of possible additional contexts that come from previous usage of tool calls
+    for (const QString &context : toolContexts) {
+        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response
+        m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc,
             /*allowContextShift*/ true, m_ctx);
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }

-    // We can't handle recursive tool calls right now otherwise we always try to check if we have a
-    // tool call
-    m_checkToolCall = !isToolCallResponse;
+    // We can't handle recursive tool calls right now due to the possibility of the model causing
+    // infinite recursion through repeated tool calls
+    m_checkToolCall = !isRecursiveCall;
+
     m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
         /*allowContextShift*/ true, m_ctx);
@@ -841,7 +862,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     fflush(stdout);
 #endif
     m_timer->stop();
-    qint64 elapsed = totalTime.elapsed();
+    totalTime = elapsedTimer.elapsed();
     std::string trimmed = trim_whitespace(m_response);

     // If we found a tool call, then deal with it
@@ -852,7 +873,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
         if (toolTemplate.isEmpty()) {
             qWarning() << "ERROR: No valid tool template for this model" << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }

         QJsonParseError err;
@@ -860,13 +881,13 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
             qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }

         QJsonObject rootObject = toolCallDoc.object();
         if (!rootObject.contains("name") || !rootObject.contains("parameters")) {
             qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }

         const QString tool = toolCallDoc["name"].toString();
@@ -877,7 +898,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (tool != "web_search" || !args.contains("query")) {
             // FIXME: Need to surface errors to the UI
             qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall;
-            return handleFailedToolCall(trimmed, elapsed);
+            return handleFailedToolCall(trimmed, totalTime);
         }

         const QString query = args["query"].toString();
@@ -900,6 +921,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         if (!parseError.isEmpty()) {
             qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
         } else if (!sourceExcerpts.isEmpty()) {
+            producedSourceExcerpts = true;
             emit sourceExcerptsChanged(sourceExcerpts);
         }

@@ -907,23 +929,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         m_promptTokens = 0;
         m_response = std::string();

-        // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
+        // This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive
         // tool calls
-        return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
-            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
-            true /*isToolCallResponse*/);
+        return promptRecursive(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
+            n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime,
+            producedSourceExcerpts, true /*isRecursiveCall*/);
     } else {
         if (trimmed != m_response) {
             m_response = trimmed;
             emit responseChanged(QString::fromStdString(m_response));
         }
-
-        SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
-        if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
-            generateQuestions(elapsed);
-        else
-            emit responseStopped(elapsed);
-
         m_pristineLoadedState = false;
         return true;
     }

@@ -196,9 +196,10 @@ Q_SIGNALS:
     void modelInfoChanged(const ModelInfo &modelInfo);

 protected:
+    // FIXME: This is only available because of server which sucks
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-        int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
+        int32_t repeat_penalty_tokens);
     bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
@@ -219,6 +220,9 @@ protected:
     quint32 m_promptResponseTokens;

 private:
+    bool promptRecursive(const QList<QString> &toolContexts, const QString &prompt, const QString &promptTemplate,
+        int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false);
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
     std::string m_response;