1
0
mirror of https://github.com/nomic-ai/gpt4all.git synced 2025-05-06 07:27:15 +00:00

Don't block the gui thread for tool calls (#3435)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
AT 2025-01-29 18:33:08 -05:00 committed by GitHub
parent adafa17c37
commit 22b8278ef1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 153 additions and 91 deletions

View File

@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix `codesign --verify` failure on macOS ([#3413](https://github.com/nomic-ai/gpt4all/pull/3413))
- Code Interpreter: Fix console.log not accepting a single string after v3.7.0 ([#3426](https://github.com/nomic-ai/gpt4all/pull/3426))
- Fix Phi 3.1 Mini 128K Instruct template (by [@ThiloteE](https://github.com/ThiloteE) in [#3412](https://github.com/nomic-ai/gpt4all/pull/3412))
- Don't block the gui thread for reasoning ([#3435](https://github.com/nomic-ai/gpt4all/pull/3435))
## [3.7.0] - 2025-01-21

View File

@ -110,6 +110,7 @@ GridLayout {
case Chat.PromptProcessing: return qsTr("processing ...")
case Chat.ResponseGeneration: return qsTr("generating response ...");
case Chat.GeneratingQuestions: return qsTr("generating questions ...");
case Chat.ToolCallGeneration: return qsTr("generating toolcall ...");
default: return ""; // handle unexpected values
}
}

View File

@ -181,6 +181,11 @@ QVariant Chat::popPrompt(int index)
// Stops whatever this chat is currently doing: interrupts the code interpreter tool and then
// asks the model to stop generating.
void Chat::stopGenerating()
{
    // In future if we have more than one tool we'll have to keep track of which tools are possibly
    // running, but for now we only have one
    Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
    Q_ASSERT(toolInstance);
    // Q_ASSERT does nothing when QT_NO_DEBUG is defined, so guard the dereference explicitly; a
    // missing tool must not keep the model itself from being stopped.
    if (toolInstance)
        toolInstance->interrupt();
    m_llmodel->stopGenerating();
}
@ -242,56 +247,71 @@ void Chat::responseStopped(qint64 promptResponseMs)
const QString possibleToolcall = m_chatModel->possibleToolcall();
ToolCallParser parser;
parser.update(possibleToolcall);
if (parser.state() == ToolEnums::ParseState::Complete) {
const QString toolCall = parser.toolCall();
// Regex to remove the formatting around the code
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
QString code = toolCall;
code.remove(regex);
code = code.trimmed();
// Right now the code interpreter is the only available tool
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);
// The param is the code
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
const QString result = toolInstance->run({param}, 10000 /*msecs to timeout*/);
const ToolEnums::Error error = toolInstance->error();
const QString errorString = toolInstance->errorString();
// Update the current response with meta information about toolcall and re-parent
m_chatModel->updateToolCall({
ToolCallConstants::CodeInterpreterFunction,
{ param },
result,
error,
errorString
});
++m_consecutiveToolCalls;
// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
if (m_consecutiveToolCalls < 3 || error == ToolEnums::Error::NoError) {
resetResponseState();
emit promptRequested(m_collections); // triggers a new response
return;
}
}
if (m_generatedName.isEmpty())
emit generateNameRequested();
m_consecutiveToolCalls = 0;
Network::globalInstance()->trackChatEvent("response_complete", {
Network::globalInstance()->trackChatEvent("response_stopped", {
{"first", m_firstResponse},
{"message_count", chatModel()->count()},
{"$duration", promptResponseMs / 1000.},
});
ToolCallParser parser;
parser.update(possibleToolcall);
if (parser.state() == ToolEnums::ParseState::Complete)
processToolCall(parser.toolCall());
else
responseComplete();
}
void Chat::processToolCall(const QString &toolCall)
{
m_responseState = Chat::ToolCallGeneration;
emit responseStateChanged();
// Regex to remove the formatting around the code
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
QString code = toolCall;
code.remove(regex);
code = code.trimmed();
// Right now the code interpreter is the only available tool
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
Q_ASSERT(toolInstance);
connect(toolInstance, &Tool::runComplete, this, &Chat::toolCallComplete, Qt::SingleShotConnection);
// The param is the code
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
m_responseInProgress = true;
emit responseInProgressChanged();
toolInstance->run({param});
}
void Chat::toolCallComplete(const ToolCallInfo &info)
{
// Update the current response with meta information about toolcall and re-parent
m_chatModel->updateToolCall(info);
++m_consecutiveToolCalls;
m_responseInProgress = false;
emit responseInProgressChanged();
// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
if (m_consecutiveToolCalls < 3 || info.error == ToolEnums::Error::NoError) {
resetResponseState();
emit promptRequested(m_collections); // triggers a new response
return;
}
responseComplete();
}
// Finishes the current response: requests a chat name if none has been generated yet, moves
// the chat to the ResponseStopped state and resets the per-response bookkeeping.
void Chat::responseComplete()
{
    const bool needsName = m_generatedName.isEmpty();
    if (needsName) {
        emit generateNameRequested();
    }

    m_responseState = Chat::ResponseStopped;
    emit responseStateChanged();

    // Bookkeeping is reset only after the signals above have been emitted
    m_firstResponse = false;
    m_consecutiveToolCalls = 0;
}

View File

@ -55,7 +55,8 @@ public:
LocalDocsProcessing,
PromptProcessing,
GeneratingQuestions,
ResponseGeneration
ResponseGeneration,
ToolCallGeneration
};
Q_ENUM(ResponseState)
@ -166,6 +167,9 @@ private Q_SLOTS:
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void processToolCall(const QString &toolCall);
void toolCallComplete(const ToolCallInfo &info);
void responseComplete();
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
void handleModelLoadingError(const QString &error);

View File

@ -7,8 +7,15 @@
using namespace Qt::Literals::StringLiterals;
// Creates the tool together with its worker object and forwards the request() signal to the
// worker's request() slot through a queued connection.
CodeInterpreter::CodeInterpreter()
    : Tool()
    , m_error(ToolEnums::Error::NoError)
    , m_worker(new CodeInterpreterWorker)
{
    connect(this, &CodeInterpreter::request,
            m_worker, &CodeInterpreterWorker::request,
            Qt::QueuedConnection);
}
QString CodeInterpreter::run(const QList<ToolParam> &params, qint64 timeout)
void CodeInterpreter::run(const QList<ToolParam> &params)
{
m_error = ToolEnums::Error::NoError;
m_errorString = QString();
@ -18,27 +25,24 @@ QString CodeInterpreter::run(const QList<ToolParam> &params, qint64 timeout)
&& params.first().type == ToolEnums::ParamType::String);
const QString code = params.first().value.toString();
QThread workerThread;
CodeInterpreterWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &CodeInterpreterWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(&workerThread, &QThread::started, [&worker, code]() {
worker.request(code);
connect(m_worker, &CodeInterpreterWorker::finished, [this, params] {
m_error = m_worker->error();
m_errorString = m_worker->errorString();
emit runComplete({
ToolCallConstants::CodeInterpreterFunction,
params,
m_worker->response(),
m_error,
m_errorString
});
});
workerThread.start();
bool timedOut = !workerThread.wait(timeout);
if (timedOut) {
worker.interrupt(timeout); // thread safe
m_error = ToolEnums::Error::TimeoutError;
}
workerThread.quit();
workerThread.wait();
if (!timedOut) {
m_error = worker.error();
m_errorString = worker.errorString();
}
return worker.response();
emit request(code);
}
// Forwards the interrupt request to the worker and reports the worker's answer.
bool CodeInterpreter::interrupt()
{
    const bool interrupted = m_worker->interrupt();
    return interrupted;
}
QList<ToolParamInfo> CodeInterpreter::parameters() const
@ -89,17 +93,15 @@ QString CodeInterpreter::exampleReply() const
CodeInterpreterWorker::CodeInterpreterWorker()
: QObject(nullptr)
, m_engine(new QJSEngine(this))
{
}
moveToThread(&m_thread);
void CodeInterpreterWorker::request(const QString &code)
{
JavaScriptConsoleCapture consoleCapture;
QJSValue consoleInternalObject = m_engine.newQObject(&consoleCapture);
m_engine.globalObject().setProperty("console_internal", consoleInternalObject);
QJSValue consoleInternalObject = m_engine->newQObject(&m_consoleCapture);
m_engine->globalObject().setProperty("console_internal", consoleInternalObject);
// preprocess console.log args in JS since Q_INVOKE doesn't support varargs
auto consoleObject = m_engine.evaluate(uR"(
auto consoleObject = m_engine->evaluate(uR"(
class Console {
log(...args) {
if (args.length == 0)
@ -116,15 +118,28 @@ void CodeInterpreterWorker::request(const QString &code)
new Console();
)"_s);
m_engine.globalObject().setProperty("console", consoleObject);
m_engine->globalObject().setProperty("console", consoleObject);
m_thread.start();
}
const QJSValue result = m_engine.evaluate(code);
// Puts the worker back into a clean state: no error, no response text, no captured console
// output, and the engine's interrupted flag cleared.
void CodeInterpreterWorker::reset()
{
    // Error state
    m_error = ToolEnums::Error::NoError;
    m_errorString.clear();

    // Output gathered so far
    m_response.clear();
    m_consoleCapture.output.clear();

    // Let the engine evaluate code again after an earlier interrupt
    m_engine->setInterrupted(false);
}
void CodeInterpreterWorker::request(const QString &code)
{
reset();
const QJSValue result = m_engine->evaluate(code);
QString resultString;
if (m_engine.isInterrupted()) {
resultString = QString("Error: code execution was timed out as it exceeded %1 ms. Code must be written to ensure execution does not timeout.").arg(m_timeout);
} else if (result.isError()) {
if (m_engine->isInterrupted()) {
resultString = QString("Error: code execution was interrupted or timed out.");
} else if (result.isError()) {
// NOTE: We purposely do not set the m_error or m_errorString for the code interpreter since
// we *want* the model to see the response has an error so it can hopefully correct itself. The
// error member variables are intended for tools that have error conditions that cannot be corrected.
@ -145,9 +160,16 @@ void CodeInterpreterWorker::request(const QString &code)
}
if (resultString.isEmpty())
resultString = consoleCapture.output;
else if (!consoleCapture.output.isEmpty())
resultString += "\n" + consoleCapture.output;
resultString = m_consoleCapture.output;
else if (!m_consoleCapture.output.isEmpty())
resultString += "\n" + m_consoleCapture.output;
m_response = resultString;
emit finished();
}
// Requests that the JavaScript engine stop evaluating: records a TimeoutError and sets the
// engine's interrupted flag. Always returns true.
// NOTE(review): TimeoutError is recorded for every interrupt, whatever its cause - confirm
// that this is the error callers should see when the stop was requested by the user.
// NOTE(review): presumably called from a thread other than the one running request(); m_error
// is written here without any visible synchronization - confirm this is safe.
bool CodeInterpreterWorker::interrupt()
{
m_error = ToolEnums::Error::TimeoutError;
m_engine->setInterrupted(true);
return true;
}

View File

@ -8,6 +8,7 @@
#include <QObject>
#include <QString>
#include <QtGlobal>
#include <QThread>
class JavaScriptConsoleCapture : public QObject
{
@ -39,32 +40,37 @@ public:
CodeInterpreterWorker();
virtual ~CodeInterpreterWorker() {}
void reset();
QString response() const { return m_response; }
void request(const QString &code);
void interrupt(qint64 timeout) { m_timeout = timeout; m_engine.setInterrupted(true); }
ToolEnums::Error error() const { return m_error; }
QString errorString() const { return m_errorString; }
bool interrupt();
public Q_SLOTS:
void request(const QString &code);
Q_SIGNALS:
void finished();
private:
qint64 m_timeout = 0;
QJSEngine m_engine;
QString m_response;
ToolEnums::Error m_error = ToolEnums::Error::NoError;
QString m_errorString;
QThread m_thread;
JavaScriptConsoleCapture m_consoleCapture;
QJSEngine *m_engine = nullptr;
};
class CodeInterpreter : public Tool
{
Q_OBJECT
public:
explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {}
explicit CodeInterpreter();
virtual ~CodeInterpreter() {}
QString run(const QList<ToolParam> &params, qint64 timeout = 2000) override;
void run(const QList<ToolParam> &params) override;
bool interrupt() override;
ToolEnums::Error error() const override { return m_error; }
QString errorString() const override { return m_errorString; }
@ -77,9 +83,13 @@ public:
QString exampleCall() const override;
QString exampleReply() const override;
Q_SIGNALS:
void request(const QString &code);
private:
ToolEnums::Error m_error = ToolEnums::Error::NoError;
QString m_errorString;
CodeInterpreterWorker *m_worker;
};
#endif // CODEINTERPRETER_H

View File

@ -87,7 +87,8 @@ public:
Tool() : QObject(nullptr) {}
virtual ~Tool() {}
virtual QString run(const QList<ToolParam> &params, qint64 timeout = 2000) = 0;
virtual void run(const QList<ToolParam> &params) = 0;
virtual bool interrupt() = 0;
// Tools should set these if they encounter errors. For instance, a tool depending upon the network
// might set these error variables if the network is not available.
@ -122,6 +123,9 @@ public:
bool operator==(const Tool &other) const { return function() == other.function(); }
jinja2::Value jinjaValue() const;
Q_SIGNALS:
void runComplete(const ToolCallInfo &info);
};
#endif // TOOL_H