mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-08-15 06:33:31 +00:00
Handle the forced usage of tool calls outside of the recursive prompt method.
Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
parent
f118720717
commit
75dbf9de7d
@ -759,16 +759,13 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
|
||||
}
|
||||
|
||||
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens, bool isToolCallResponse)
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens)
|
||||
{
|
||||
if (!isModelLoaded())
|
||||
return false;
|
||||
|
||||
// FIXME: This should be made agnostic to localdocs and rely upon the force usage usage mode
|
||||
// and also we have to honor the ask before running mode.
|
||||
// FIXME: The only localdocs specific thing here should be the injection of the parameters
|
||||
// FIXME: Get the list of tools ... if force usage is set, then we *try* and force usage here.
|
||||
QList<SourceExcerpt> localDocsExcerpts;
|
||||
if (!collectionList.isEmpty() && !isToolCallResponse) {
|
||||
if (!collectionList.isEmpty()) {
|
||||
LocalDocsSearch localdocs;
|
||||
QJsonObject parameters;
|
||||
parameters.insert("text", prompt);
|
||||
@ -795,6 +792,27 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
docsContext = u"### Context:\n%1\n\n"_s.arg(json);
|
||||
}
|
||||
|
||||
qint64 totalTime = 0;
|
||||
bool producedSourceExcerpts;
|
||||
bool success = promptRecursive({ docsContext }, prompt, promptTemplate, n_predict, top_k, top_p,
|
||||
min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime, producedSourceExcerpts);
|
||||
|
||||
SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
|
||||
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || producedSourceExcerpts)))
|
||||
generateQuestions(totalTime);
|
||||
else
|
||||
emit responseStopped(totalTime);
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
bool ChatLLM::promptRecursive(const QList<QString> &toolContexts, const QString &prompt,
|
||||
const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp,
|
||||
int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall)
|
||||
{
|
||||
if (!isModelLoaded())
|
||||
return false;
|
||||
|
||||
int n_threads = MySettings::globalInstance()->threadCount();
|
||||
|
||||
m_stopGenerating = false;
|
||||
@ -815,19 +833,22 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
printf("%s", qPrintable(prompt));
|
||||
fflush(stdout);
|
||||
#endif
|
||||
QElapsedTimer totalTime;
|
||||
totalTime.start();
|
||||
|
||||
QElapsedTimer elapsedTimer;
|
||||
elapsedTimer.start();
|
||||
m_timer->start();
|
||||
if (!docsContext.isEmpty()) {
|
||||
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
|
||||
m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
|
||||
|
||||
// The list of possible additional contexts that come from previous usage of tool calls
|
||||
for (const QString &context : toolContexts) {
|
||||
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode context without a response
|
||||
m_llModelInfo.model->prompt(context.toStdString(), "%1", promptFunc, responseFunc,
|
||||
/*allowContextShift*/ true, m_ctx);
|
||||
m_ctx.n_predict = old_n_predict; // now we are ready for a response
|
||||
}
|
||||
|
||||
// We can't handle recursive tool calls right now otherwise we always try to check if we have a
|
||||
// tool call
|
||||
m_checkToolCall = !isToolCallResponse;
|
||||
// We can't handle recursive tool calls right now due to the possibility of the model causing
|
||||
// infinite recursion through repeated tool calls
|
||||
m_checkToolCall = !isRecursiveCall;
|
||||
|
||||
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
|
||||
/*allowContextShift*/ true, m_ctx);
|
||||
@ -841,7 +862,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
fflush(stdout);
|
||||
#endif
|
||||
m_timer->stop();
|
||||
qint64 elapsed = totalTime.elapsed();
|
||||
totalTime = elapsedTimer.elapsed();
|
||||
std::string trimmed = trim_whitespace(m_response);
|
||||
|
||||
// If we found a tool call, then deal with it
|
||||
@ -852,7 +873,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
|
||||
if (toolTemplate.isEmpty()) {
|
||||
qWarning() << "ERROR: No valid tool template for this model" << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
return handleFailedToolCall(trimmed, totalTime);
|
||||
}
|
||||
|
||||
QJsonParseError err;
|
||||
@ -860,13 +881,13 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
|
||||
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
|
||||
qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
return handleFailedToolCall(trimmed, totalTime);
|
||||
}
|
||||
|
||||
QJsonObject rootObject = toolCallDoc.object();
|
||||
if (!rootObject.contains("name") || !rootObject.contains("parameters")) {
|
||||
qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
return handleFailedToolCall(trimmed, totalTime);
|
||||
}
|
||||
|
||||
const QString tool = toolCallDoc["name"].toString();
|
||||
@ -877,7 +898,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
if (tool != "web_search" || !args.contains("query")) {
|
||||
// FIXME: Need to surface errors to the UI
|
||||
qWarning() << "ERROR: Could not find the tool and correct parameters for " << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
return handleFailedToolCall(trimmed, totalTime);
|
||||
}
|
||||
|
||||
const QString query = args["query"].toString();
|
||||
@ -900,6 +921,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
if (!parseError.isEmpty()) {
|
||||
qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
|
||||
} else if (!sourceExcerpts.isEmpty()) {
|
||||
producedSourceExcerpts = true;
|
||||
emit sourceExcerptsChanged(sourceExcerpts);
|
||||
}
|
||||
|
||||
@ -907,23 +929,16 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
m_promptTokens = 0;
|
||||
m_response = std::string();
|
||||
|
||||
// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
|
||||
// This is a recursive call but isRecursiveCall is checked above to arrest infinite recursive
|
||||
// tool calls
|
||||
return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
|
||||
true /*isToolCallResponse*/);
|
||||
|
||||
return promptRecursive(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, totalTime,
|
||||
producedSourceExcerpts, true /*isRecursiveCall*/);
|
||||
} else {
|
||||
if (trimmed != m_response) {
|
||||
m_response = trimmed;
|
||||
emit responseChanged(QString::fromStdString(m_response));
|
||||
}
|
||||
|
||||
SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
|
||||
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
|
||||
generateQuestions(elapsed);
|
||||
else
|
||||
emit responseStopped(elapsed);
|
||||
m_pristineLoadedState = false;
|
||||
return true;
|
||||
}
|
||||
|
@ -196,9 +196,10 @@ Q_SIGNALS:
|
||||
void modelInfoChanged(const ModelInfo &modelInfo);
|
||||
|
||||
protected:
|
||||
// FIXME: This is only available because of server which sucks
|
||||
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
|
||||
int32_t repeat_penalty_tokens);
|
||||
bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
|
||||
bool handlePrompt(int32_t token);
|
||||
bool handleResponse(int32_t token, const std::string &response);
|
||||
@ -219,6 +220,9 @@ protected:
|
||||
quint32 m_promptResponseTokens;
|
||||
|
||||
private:
|
||||
bool promptRecursive(const QList<QString> &toolContexts, const QString &prompt, const QString &promptTemplate,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens, qint64 &totalTime, bool &producedSourceExcerpts, bool isRecursiveCall = false);
|
||||
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
|
||||
|
||||
std::string m_response;
|
||||
|
Loading…
Reference in New Issue
Block a user