mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-05-06 07:27:15 +00:00
Don't block the gui thread for tool calls (#3435)
Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
parent
adafa17c37
commit
22b8278ef1
@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
||||
- Fix `codesign --verify` failure on macOS ([#3413](https://github.com/nomic-ai/gpt4all/pull/3413))
|
||||
- Code Interpreter: Fix console.log not accepting a single string after v3.7.0 ([#3426](https://github.com/nomic-ai/gpt4all/pull/3426))
|
||||
- Fix Phi 3.1 Mini 128K Instruct template (by [@ThiloteE](https://github.com/ThiloteE) in [#3412](https://github.com/nomic-ai/gpt4all/pull/3412))
|
||||
- Don't block the GUI thread for tool calls ([#3435](https://github.com/nomic-ai/gpt4all/pull/3435))
|
||||
|
||||
## [3.7.0] - 2025-01-21
|
||||
|
||||
|
@ -110,6 +110,7 @@ GridLayout {
|
||||
case Chat.PromptProcessing: return qsTr("processing ...")
|
||||
case Chat.ResponseGeneration: return qsTr("generating response ...");
|
||||
case Chat.GeneratingQuestions: return qsTr("generating questions ...");
|
||||
case Chat.ToolCallGeneration: return qsTr("generating toolcall ...");
|
||||
default: return ""; // handle unexpected values
|
||||
}
|
||||
}
|
||||
|
@ -181,6 +181,11 @@ QVariant Chat::popPrompt(int index)
|
||||
|
||||
void Chat::stopGenerating()
|
||||
{
|
||||
// In future if we have more than one tool we'll have to keep track of which tools are possibly
|
||||
// running, but for now we only have one
|
||||
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
|
||||
Q_ASSERT(toolInstance);
|
||||
toolInstance->interrupt();
|
||||
m_llmodel->stopGenerating();
|
||||
}
|
||||
|
||||
@ -242,56 +247,71 @@ void Chat::responseStopped(qint64 promptResponseMs)
|
||||
|
||||
const QString possibleToolcall = m_chatModel->possibleToolcall();
|
||||
|
||||
ToolCallParser parser;
|
||||
parser.update(possibleToolcall);
|
||||
|
||||
if (parser.state() == ToolEnums::ParseState::Complete) {
|
||||
const QString toolCall = parser.toolCall();
|
||||
|
||||
// Regex to remove the formatting around the code
|
||||
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
|
||||
QString code = toolCall;
|
||||
code.remove(regex);
|
||||
code = code.trimmed();
|
||||
|
||||
// Right now the code interpreter is the only available tool
|
||||
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
|
||||
Q_ASSERT(toolInstance);
|
||||
|
||||
// The param is the code
|
||||
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
|
||||
const QString result = toolInstance->run({param}, 10000 /*msecs to timeout*/);
|
||||
const ToolEnums::Error error = toolInstance->error();
|
||||
const QString errorString = toolInstance->errorString();
|
||||
|
||||
// Update the current response with meta information about toolcall and re-parent
|
||||
m_chatModel->updateToolCall({
|
||||
ToolCallConstants::CodeInterpreterFunction,
|
||||
{ param },
|
||||
result,
|
||||
error,
|
||||
errorString
|
||||
});
|
||||
|
||||
++m_consecutiveToolCalls;
|
||||
|
||||
// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
|
||||
if (m_consecutiveToolCalls < 3 || error == ToolEnums::Error::NoError) {
|
||||
resetResponseState();
|
||||
emit promptRequested(m_collections); // triggers a new response
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_generatedName.isEmpty())
|
||||
emit generateNameRequested();
|
||||
|
||||
m_consecutiveToolCalls = 0;
|
||||
Network::globalInstance()->trackChatEvent("response_complete", {
|
||||
Network::globalInstance()->trackChatEvent("response_stopped", {
|
||||
{"first", m_firstResponse},
|
||||
{"message_count", chatModel()->count()},
|
||||
{"$duration", promptResponseMs / 1000.},
|
||||
});
|
||||
|
||||
ToolCallParser parser;
|
||||
parser.update(possibleToolcall);
|
||||
if (parser.state() == ToolEnums::ParseState::Complete)
|
||||
processToolCall(parser.toolCall());
|
||||
else
|
||||
responseComplete();
|
||||
}
|
||||
|
||||
void Chat::processToolCall(const QString &toolCall)
|
||||
{
|
||||
m_responseState = Chat::ToolCallGeneration;
|
||||
emit responseStateChanged();
|
||||
// Regex to remove the formatting around the code
|
||||
static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$");
|
||||
QString code = toolCall;
|
||||
code.remove(regex);
|
||||
code = code.trimmed();
|
||||
|
||||
// Right now the code interpreter is the only available tool
|
||||
Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction);
|
||||
Q_ASSERT(toolInstance);
|
||||
connect(toolInstance, &Tool::runComplete, this, &Chat::toolCallComplete, Qt::SingleShotConnection);
|
||||
|
||||
// The param is the code
|
||||
const ToolParam param = { "code", ToolEnums::ParamType::String, code };
|
||||
m_responseInProgress = true;
|
||||
emit responseInProgressChanged();
|
||||
toolInstance->run({param});
|
||||
}
|
||||
|
||||
void Chat::toolCallComplete(const ToolCallInfo &info)
|
||||
{
|
||||
// Update the current response with meta information about toolcall and re-parent
|
||||
m_chatModel->updateToolCall(info);
|
||||
|
||||
++m_consecutiveToolCalls;
|
||||
|
||||
m_responseInProgress = false;
|
||||
emit responseInProgressChanged();
|
||||
|
||||
// We limit the number of consecutive toolcalls otherwise we get into a potentially endless loop
|
||||
if (m_consecutiveToolCalls < 3 || info.error == ToolEnums::Error::NoError) {
|
||||
resetResponseState();
|
||||
emit promptRequested(m_collections); // triggers a new response
|
||||
return;
|
||||
}
|
||||
|
||||
responseComplete();
|
||||
}
|
||||
|
||||
void Chat::responseComplete()
|
||||
{
|
||||
if (m_generatedName.isEmpty())
|
||||
emit generateNameRequested();
|
||||
|
||||
m_responseState = Chat::ResponseStopped;
|
||||
emit responseStateChanged();
|
||||
|
||||
m_consecutiveToolCalls = 0;
|
||||
m_firstResponse = false;
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,8 @@ public:
|
||||
LocalDocsProcessing,
|
||||
PromptProcessing,
|
||||
GeneratingQuestions,
|
||||
ResponseGeneration
|
||||
ResponseGeneration,
|
||||
ToolCallGeneration
|
||||
};
|
||||
Q_ENUM(ResponseState)
|
||||
|
||||
@ -166,6 +167,9 @@ private Q_SLOTS:
|
||||
void promptProcessing();
|
||||
void generatingQuestions();
|
||||
void responseStopped(qint64 promptResponseMs);
|
||||
void processToolCall(const QString &toolCall);
|
||||
void toolCallComplete(const ToolCallInfo &info);
|
||||
void responseComplete();
|
||||
void generatedNameChanged(const QString &name);
|
||||
void generatedQuestionFinished(const QString &question);
|
||||
void handleModelLoadingError(const QString &error);
|
||||
|
@ -7,8 +7,15 @@
|
||||
|
||||
using namespace Qt::Literals::StringLiterals;
|
||||
|
||||
// Creates the worker and routes this tool's request() signal to the worker's request() slot
// over a queued connection.
// NOTE(review): no code visible here deletes m_worker -- confirm the intended ownership.
CodeInterpreter::CodeInterpreter()
    : Tool()
    , m_error(ToolEnums::Error::NoError)
    , m_worker(new CodeInterpreterWorker)
{
    connect(this, &CodeInterpreter::request, m_worker, &CodeInterpreterWorker::request, Qt::QueuedConnection);
}
|
||||
|
||||
QString CodeInterpreter::run(const QList<ToolParam> ¶ms, qint64 timeout)
|
||||
void CodeInterpreter::run(const QList<ToolParam> ¶ms)
|
||||
{
|
||||
m_error = ToolEnums::Error::NoError;
|
||||
m_errorString = QString();
|
||||
@ -18,27 +25,24 @@ QString CodeInterpreter::run(const QList<ToolParam> ¶ms, qint64 timeout)
|
||||
&& params.first().type == ToolEnums::ParamType::String);
|
||||
|
||||
const QString code = params.first().value.toString();
|
||||
|
||||
QThread workerThread;
|
||||
CodeInterpreterWorker worker;
|
||||
worker.moveToThread(&workerThread);
|
||||
connect(&worker, &CodeInterpreterWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
|
||||
connect(&workerThread, &QThread::started, [&worker, code]() {
|
||||
worker.request(code);
|
||||
connect(m_worker, &CodeInterpreterWorker::finished, [this, params] {
|
||||
m_error = m_worker->error();
|
||||
m_errorString = m_worker->errorString();
|
||||
emit runComplete({
|
||||
ToolCallConstants::CodeInterpreterFunction,
|
||||
params,
|
||||
m_worker->response(),
|
||||
m_error,
|
||||
m_errorString
|
||||
});
|
||||
});
|
||||
workerThread.start();
|
||||
bool timedOut = !workerThread.wait(timeout);
|
||||
if (timedOut) {
|
||||
worker.interrupt(timeout); // thread safe
|
||||
m_error = ToolEnums::Error::TimeoutError;
|
||||
}
|
||||
workerThread.quit();
|
||||
workerThread.wait();
|
||||
if (!timedOut) {
|
||||
m_error = worker.error();
|
||||
m_errorString = worker.errorString();
|
||||
}
|
||||
return worker.response();
|
||||
|
||||
emit request(code);
|
||||
}
|
||||
|
||||
bool CodeInterpreter::interrupt()
|
||||
{
|
||||
return m_worker->interrupt();
|
||||
}
|
||||
|
||||
QList<ToolParamInfo> CodeInterpreter::parameters() const
|
||||
@ -89,17 +93,15 @@ QString CodeInterpreter::exampleReply() const
|
||||
|
||||
CodeInterpreterWorker::CodeInterpreterWorker()
|
||||
: QObject(nullptr)
|
||||
, m_engine(new QJSEngine(this))
|
||||
{
|
||||
}
|
||||
moveToThread(&m_thread);
|
||||
|
||||
void CodeInterpreterWorker::request(const QString &code)
|
||||
{
|
||||
JavaScriptConsoleCapture consoleCapture;
|
||||
QJSValue consoleInternalObject = m_engine.newQObject(&consoleCapture);
|
||||
m_engine.globalObject().setProperty("console_internal", consoleInternalObject);
|
||||
QJSValue consoleInternalObject = m_engine->newQObject(&m_consoleCapture);
|
||||
m_engine->globalObject().setProperty("console_internal", consoleInternalObject);
|
||||
|
||||
// preprocess console.log args in JS since Q_INVOKE doesn't support varargs
|
||||
auto consoleObject = m_engine.evaluate(uR"(
|
||||
auto consoleObject = m_engine->evaluate(uR"(
|
||||
class Console {
|
||||
log(...args) {
|
||||
if (args.length == 0)
|
||||
@ -116,15 +118,28 @@ void CodeInterpreterWorker::request(const QString &code)
|
||||
|
||||
new Console();
|
||||
)"_s);
|
||||
m_engine.globalObject().setProperty("console", consoleObject);
|
||||
m_engine->globalObject().setProperty("console", consoleObject);
|
||||
m_thread.start();
|
||||
}
|
||||
|
||||
const QJSValue result = m_engine.evaluate(code);
|
||||
void CodeInterpreterWorker::reset()
|
||||
{
|
||||
m_response.clear();
|
||||
m_error = ToolEnums::Error::NoError;
|
||||
m_errorString.clear();
|
||||
m_consoleCapture.output.clear();
|
||||
m_engine->setInterrupted(false);
|
||||
}
|
||||
|
||||
void CodeInterpreterWorker::request(const QString &code)
|
||||
{
|
||||
reset();
|
||||
const QJSValue result = m_engine->evaluate(code);
|
||||
QString resultString;
|
||||
|
||||
if (m_engine.isInterrupted()) {
|
||||
resultString = QString("Error: code execution was timed out as it exceeded %1 ms. Code must be written to ensure execution does not timeout.").arg(m_timeout);
|
||||
} else if (result.isError()) {
|
||||
if (m_engine->isInterrupted()) {
|
||||
resultString = QString("Error: code execution was interrupted or timed out.");
|
||||
} else if (result.isError()) {
|
||||
// NOTE: We purposely do not set the m_error or m_errorString for the code interpreter since
|
||||
// we *want* the model to see the response has an error so it can hopefully correct itself. The
|
||||
// error member variables are intended for tools that have error conditions that cannot be corrected.
|
||||
@ -145,9 +160,16 @@ void CodeInterpreterWorker::request(const QString &code)
|
||||
}
|
||||
|
||||
if (resultString.isEmpty())
|
||||
resultString = consoleCapture.output;
|
||||
else if (!consoleCapture.output.isEmpty())
|
||||
resultString += "\n" + consoleCapture.output;
|
||||
resultString = m_consoleCapture.output;
|
||||
else if (!m_consoleCapture.output.isEmpty())
|
||||
resultString += "\n" + m_consoleCapture.output;
|
||||
m_response = resultString;
|
||||
emit finished();
|
||||
}
|
||||
|
||||
bool CodeInterpreterWorker::interrupt()
|
||||
{
|
||||
m_error = ToolEnums::Error::TimeoutError;
|
||||
m_engine->setInterrupted(true);
|
||||
return true;
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <QObject>
|
||||
#include <QString>
|
||||
#include <QtGlobal>
|
||||
#include <QThread>
|
||||
|
||||
class JavaScriptConsoleCapture : public QObject
|
||||
{
|
||||
@ -39,32 +40,37 @@ public:
|
||||
CodeInterpreterWorker();
|
||||
virtual ~CodeInterpreterWorker() {}
|
||||
|
||||
void reset();
|
||||
QString response() const { return m_response; }
|
||||
|
||||
void request(const QString &code);
|
||||
void interrupt(qint64 timeout) { m_timeout = timeout; m_engine.setInterrupted(true); }
|
||||
ToolEnums::Error error() const { return m_error; }
|
||||
QString errorString() const { return m_errorString; }
|
||||
bool interrupt();
|
||||
|
||||
public Q_SLOTS:
|
||||
void request(const QString &code);
|
||||
|
||||
Q_SIGNALS:
|
||||
void finished();
|
||||
|
||||
private:
|
||||
qint64 m_timeout = 0;
|
||||
QJSEngine m_engine;
|
||||
QString m_response;
|
||||
ToolEnums::Error m_error = ToolEnums::Error::NoError;
|
||||
QString m_errorString;
|
||||
QThread m_thread;
|
||||
JavaScriptConsoleCapture m_consoleCapture;
|
||||
QJSEngine *m_engine = nullptr;
|
||||
};
|
||||
|
||||
class CodeInterpreter : public Tool
|
||||
{
|
||||
Q_OBJECT
|
||||
public:
|
||||
explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {}
|
||||
explicit CodeInterpreter();
|
||||
virtual ~CodeInterpreter() {}
|
||||
|
||||
QString run(const QList<ToolParam> ¶ms, qint64 timeout = 2000) override;
|
||||
void run(const QList<ToolParam> ¶ms) override;
|
||||
bool interrupt() override;
|
||||
|
||||
ToolEnums::Error error() const override { return m_error; }
|
||||
QString errorString() const override { return m_errorString; }
|
||||
|
||||
@ -77,9 +83,13 @@ public:
|
||||
QString exampleCall() const override;
|
||||
QString exampleReply() const override;
|
||||
|
||||
Q_SIGNALS:
|
||||
void request(const QString &code);
|
||||
|
||||
private:
|
||||
ToolEnums::Error m_error = ToolEnums::Error::NoError;
|
||||
QString m_errorString;
|
||||
CodeInterpreterWorker *m_worker;
|
||||
};
|
||||
|
||||
#endif // CODEINTERPRETER_H
|
||||
|
@ -87,7 +87,8 @@ public:
|
||||
Tool() : QObject(nullptr) {}
|
||||
virtual ~Tool() {}
|
||||
|
||||
virtual QString run(const QList<ToolParam> ¶ms, qint64 timeout = 2000) = 0;
|
||||
virtual void run(const QList<ToolParam> ¶ms) = 0;
|
||||
virtual bool interrupt() = 0;
|
||||
|
||||
// Tools should set these if they encounter errors. For instance, a tool depending upon the network
|
||||
// might set these error variables if the network is not available.
|
||||
@ -122,6 +123,9 @@ public:
|
||||
bool operator==(const Tool &other) const { return function() == other.function(); }
|
||||
|
||||
jinja2::Value jinjaValue() const;
|
||||
|
||||
Q_SIGNALS:
|
||||
void runComplete(const ToolCallInfo &info);
|
||||
};
|
||||
|
||||
#endif // TOOL_H
|
||||
|
Loading…
Reference in New Issue
Block a user