From 007a7af1c82f17f0c457a435fdfac40dfbb2391f Mon Sep 17 00:00:00 2001 From: AT Date: Thu, 30 Jan 2025 16:11:05 -0500 Subject: [PATCH] Display DeepSeek-R1 thinking like Reasoner (#3440) Signed-off-by: Adam Treat Signed-off-by: Jared Van Bortel Co-authored-by: Jared Van Bortel --- gpt4all-chat/CHANGELOG.md | 1 + gpt4all-chat/qml/ChatCollapsibleItem.qml | 14 ++- gpt4all-chat/qml/ChatItemView.qml | 11 ++ gpt4all-chat/src/chat.cpp | 2 +- gpt4all-chat/src/chatllm.cpp | 53 ++++++--- gpt4all-chat/src/chatmodel.cpp | 15 +++ gpt4all-chat/src/chatmodel.h | 109 +++++++++++++++++- gpt4all-chat/src/tool.h | 1 + gpt4all-chat/src/toolcallparser.cpp | 139 +++++++++++++++++------ gpt4all-chat/src/toolcallparser.h | 29 +++-- 10 files changed, 306 insertions(+), 68 deletions(-) diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md index 31fc996a..e888e2fa 100644 --- a/gpt4all-chat/CHANGELOG.md +++ b/gpt4all-chat/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added - Support DeepSeek-R1 Qwen models ([#3431](https://github.com/nomic-ai/gpt4all/pull/3431)) +- Support for think tags in the GUI ([#3440](https://github.com/nomic-ai/gpt4all/pull/3440)) ### Changed - Use minja instead of Jinja2Cpp for significantly improved template compatibility ([#3433](https://github.com/nomic-ai/gpt4all/pull/3433)) diff --git a/gpt4all-chat/qml/ChatCollapsibleItem.qml b/gpt4all-chat/qml/ChatCollapsibleItem.qml index cdfa2c40..459b8da1 100644 --- a/gpt4all-chat/qml/ChatCollapsibleItem.qml +++ b/gpt4all-chat/qml/ChatCollapsibleItem.qml @@ -13,6 +13,8 @@ ColumnLayout { property alias textContent: innerTextItem.textContent property bool isCurrent: false property bool isError: false + property bool isThinking: false + property int thinkingTime: 0 Layout.topMargin: 10 Layout.bottomMargin: 10 @@ -26,16 +28,20 @@ ColumnLayout { anchors.bottom: parent.bottom Item { - width: myTextArea.width - height: myTextArea.height + Layout.preferredWidth: myTextArea.implicitWidth + Layout.preferredHeight: myTextArea.implicitHeight TextArea { id: myTextArea text: { if (isError) return qsTr("Analysis encountered error"); if (isCurrent) - return qsTr("Analyzing"); - return qsTr("Analyzed"); + return isThinking ? qsTr("Thinking") : qsTr("Analyzing"); + return isThinking + ? qsTr("Thought for %1 %2") + .arg(Math.ceil(thinkingTime / 1000.0)) + .arg(Math.ceil(thinkingTime / 1000.0) === 1 ? qsTr("second") : qsTr("seconds")) + : qsTr("Analyzed"); } padding: 0 font.pixelSize: theme.fontSizeLarger diff --git a/gpt4all-chat/qml/ChatItemView.qml b/gpt4all-chat/qml/ChatItemView.qml index 457670f2..90e09bcb 100644 --- a/gpt4all-chat/qml/ChatItemView.qml +++ b/gpt4all-chat/qml/ChatItemView.qml @@ -189,6 +189,17 @@ GridLayout { isError: modelData.isToolCallError } } + DelegateChoice { + roleValue: "Think: "; + ChatCollapsibleItem { + Layout.fillWidth: true + textContent: modelData.content + isCurrent: modelData.isCurrentResponse + isError: false + isThinking: true + thinkingTime: modelData.thinkingTime + } + } } delegate: chooser diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index 54cb037e..fe91a1c1 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -255,7 +255,7 @@ void Chat::responseStopped(qint64 promptResponseMs) ToolCallParser parser; parser.update(possibleToolcall); - if (parser.state() == ToolEnums::ParseState::Complete) + if (parser.state() == ToolEnums::ParseState::Complete && parser.startTag() != ToolCallConstants::ThinkTag) processToolCall(parser.toolCall()); else responseComplete(); diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index faded163..625df71c 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -928,8 +928,12 @@ auto ChatLLM::promptInternal( return !m_stopGenerating; }; + QElapsedTimer totalTime; + totalTime.start(); + m_timer->start(); + ToolCallParser toolCallParser; - auto handleResponse = [this, &result, &toolCallParser](LLModel::Token token, std::string_view piece) -> bool { + auto handleResponse = [this, &result, &toolCallParser, &totalTime](LLModel::Token token, std::string_view piece) -> bool { Q_UNUSED(token) result.responseTokens++; m_timer->inc(); @@ -938,18 +942,31 @@ auto ChatLLM::promptInternal( // handle this like below where we have a QByteArray toolCallParser.update(QString::fromStdString(piece.data())); - // Create a toolcall and split the response if needed - if (!toolCallParser.hasSplit() && toolCallParser.state() == ToolEnums::ParseState::Partial) { - const QPair pair = toolCallParser.split(); - m_chatModel->splitToolCall(pair); + // Split the response into two if needed and create chat items + if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) { + const QVector &parseBuffers = toolCallParser.buffers(); + Q_ASSERT(parseBuffers.size() == 2); + if (toolCallParser.startTag() == ToolCallConstants::ThinkTag) + m_chatModel->splitThinking({parseBuffers.at(0), parseBuffers.at(1)}); + else + m_chatModel->splitToolCall({parseBuffers.at(0), parseBuffers.at(1)}); + } + + // Split the response into three if needed and create chat items + if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkTag + && toolCallParser.splitIfPossible()) { + const QVector &parseBuffers = toolCallParser.buffers(); + Q_ASSERT(parseBuffers.size() == 3); + m_chatModel->endThinking({parseBuffers.at(1), parseBuffers.at(2)}, totalTime.elapsed()); } result.response.append(piece.data(), piece.size()); auto respStr = QString::fromUtf8(result.response); try { - if (toolCallParser.hasSplit()) - m_chatModel->setResponseValue(toolCallParser.buffer()); + const QVector &parseBuffers = toolCallParser.buffers(); + if (parseBuffers.size() > 1) + m_chatModel->setResponseValue(parseBuffers.last()); else m_chatModel->setResponseValue(removeLeadingWhitespace(respStr)); } catch (const std::exception &e) { @@ -962,13 +979,11 @@ auto ChatLLM::promptInternal( emit responseChanged(); - const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete; - return !foundToolCall && !m_stopGenerating; - }; + const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete + && toolCallParser.startTag() != ToolCallConstants::ThinkTag; - QElapsedTimer totalTime; - totalTime.start(); - m_timer->start(); + return !shouldExecuteToolCall && !m_stopGenerating; + }; try { emit promptProcessing(); @@ -983,20 +998,22 @@ auto ChatLLM::promptInternal( m_timer->stop(); qint64 elapsed = totalTime.elapsed(); - const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete; + const QVector &parseBuffers = toolCallParser.buffers(); + const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete + && toolCallParser.startTag() != ToolCallConstants::ThinkTag; // trim trailing whitespace auto respStr = QString::fromUtf8(result.response); - if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || foundToolCall)) { - if (toolCallParser.hasSplit()) - m_chatModel->setResponseValue(toolCallParser.buffer()); + if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || parseBuffers.size() > 1)) { + if (parseBuffers.size() > 1) + m_chatModel->setResponseValue(parseBuffers.last()); else m_chatModel->setResponseValue(respStr.trimmed()); emit responseChanged(); } bool doQuestions = false; - if (!m_isServer && messageItems && !foundToolCall) { + if (!m_isServer && messageItems && !shouldExecuteToolCall) { switch (mySettings->suggestionMode()) { case SuggestionMode::On: doQuestions = true; break; case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break; diff --git a/gpt4all-chat/src/chatmodel.cpp b/gpt4all-chat/src/chatmodel.cpp index 15862740..f18bd1e1 100644 --- a/gpt4all-chat/src/chatmodel.cpp +++ b/gpt4all-chat/src/chatmodel.cpp @@ -41,6 +41,12 @@ void ChatItem::serializeText(QDataStream &stream, int version) stream << value; } +void ChatItem::serializeThink(QDataStream &stream, int version) +{ + stream << value; + stream << thinkingTime; +} + void ChatItem::serializeSubItems(QDataStream &stream, int version) { stream << name; @@ -50,6 +56,7 @@ void ChatItem::serializeSubItems(QDataStream &stream, int version) case ToolCall: { serializeToolCall(stream, version); break; } case ToolResponse: { serializeToolResponse(stream, version); break; } case Text: { serializeText(stream, version); break; } + case Think: { serializeThink(stream, version); break; } case System: case Prompt: throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ))); @@ -162,6 +169,13 @@ bool ChatItem::deserializeResponse(QDataStream &stream, int version) return true; } +bool ChatItem::deserializeThink(QDataStream &stream, int version) +{ + stream >> value; + stream >> thinkingTime; + return true; +} + bool ChatItem::deserializeSubItems(QDataStream &stream, int version) { stream >> name; @@ -177,6 +191,7 @@ bool ChatItem::deserializeSubItems(QDataStream &stream, int version) case ToolCall: { deserializeToolCall(stream, version); break; } case ToolResponse: { deserializeToolResponse(stream, version); break; } case Text: { deserializeText(stream, version); break; } + case Think: { deserializeThink(stream, version); break; } case System: case Prompt: throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ))); diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index da898e96..fadbbee3 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -159,8 +159,11 @@ class ChatItem : public QObject Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState ) Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) + // thinking + Q_PROPERTY(int thinkingTime MEMBER thinkingTime NOTIFY thinkingTimeChanged) + public: - enum class Type { System, Prompt, Response, Text, ToolCall, ToolResponse }; + enum class Type { System, Prompt, Response, Text, ToolCall, ToolResponse, Think }; // tags for constructing ChatItems struct prompt_tag_t { explicit prompt_tag_t () = default; }; @@ -169,12 +172,14 @@ public: struct text_tag_t { explicit text_tag_t () = default; }; struct tool_call_tag_t { explicit tool_call_tag_t () = default; }; struct tool_response_tag_t { explicit tool_response_tag_t() = default; }; + struct think_tag_t { explicit think_tag_t () = default; }; static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t {}; static inline constexpr response_tag_t response_tag = response_tag_t {}; static inline constexpr system_tag_t system_tag = system_tag_t {}; static inline constexpr text_tag_t text_tag = text_tag_t {}; static inline constexpr tool_call_tag_t tool_call_tag = tool_call_tag_t {}; static inline constexpr tool_response_tag_t tool_response_tag = tool_response_tag_t {}; + static inline constexpr think_tag_t think_tag = think_tag_t {}; public: ChatItem(QObject *parent) @@ -220,6 +225,10 @@ public: : ChatItem(parent) { this->name = u"ToolResponse: "_s; this->value = value; } + ChatItem(QObject *parent, think_tag_t, const QString &value) + : ChatItem(parent) + { this->name = u"Think: "_s; this->value = value; } + Type type() const { if (name == u"System: "_s) @@ -234,6 +243,8 @@ public: return Type::ToolCall; if (name == u"ToolResponse: "_s) return Type::ToolResponse; + if (name == u"Think: "_s) + return Type::Think; throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name)); } @@ -265,9 +276,11 @@ public: return beforeToolCall; } - // For tool calls we only return content if it is the code interpreter + if (type() == Type::Think) + return thinkContent(value); + if (type() == Type::ToolCall) - return codeInterpreterContent(value); + return toolCallContent(value); // We don't show any of content from the tool response in the GUI if (type() == Type::ToolResponse) @@ -276,7 +289,18 @@ public: return value; } - QString codeInterpreterContent(const QString &value) const + QString thinkContent(const QString &value) const + { + ToolCallParser parser; + parser.update(value); + + // Extract the content + QString content = parser.toolCall(); + content = content.trimmed(); + return content; + } + + QString toolCallContent(const QString &value) const { ToolCallParser parser; parser.update(value); @@ -357,6 +381,12 @@ public: return toolCallInfo.error != ToolEnums::Error::NoError; } + void setThinkingTime(int t) + { + thinkingTime = t; + emit thinkingTimeChanged(); + } + // NB: Assumes response is not current. static ChatItem *fromMessageInput(QObject *parent, const MessageInput &message) { @@ -380,6 +410,7 @@ public: case ToolResponse: msgType = MessageItem::Type::ToolResponse; break; case Text: case ToolCall: + case Think: throw std::invalid_argument(fmt::format("cannot convert ChatItem type {} to message item", int(typ))); } return { msgType, flattenedContent(), sources, promptAttachments }; @@ -391,6 +422,7 @@ public: void serializeToolCall(QDataStream &stream, int version); void serializeToolResponse(QDataStream &stream, int version); void serializeText(QDataStream &stream, int version); + void serializeThink(QDataStream &stream, int version); void serializeSubItems(QDataStream &stream, int version); // recursive void serialize(QDataStream &stream, int version); @@ -399,6 +431,7 @@ public: bool deserializeToolCall(QDataStream &stream, int version); bool deserializeToolResponse(QDataStream &stream, int version); bool deserializeText(QDataStream &stream, int version); + bool deserializeThink(QDataStream &stream, int version); bool deserializeSubItems(QDataStream &stream, int version); // recursive bool deserialize(QDataStream &stream, int version); @@ -406,6 +439,7 @@ Q_SIGNALS: void contentChanged(); void isTooCallErrorChanged(); void isCurrentResponseChanged(); + void thinkingTimeChanged(); public: @@ -429,6 +463,9 @@ public: bool stopped = false; bool thumbsUpState = false; bool thumbsDownState = false; + + // thinking time in ms + int thinkingTime = 0; }; class ChatModel : public QAbstractListModel @@ -879,6 +916,70 @@ public: if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); } + Q_INVOKABLE void splitThinking(const QPair &split) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only set thinking on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + ChatItem *currentResponse = m_chatItems.back(); + Q_ASSERT(currentResponse->isCurrentResponse); + + // Create a new response container for any text and the thinking + ChatItem *newResponse = new ChatItem(this, ChatItem::response_tag); + + // Add preceding text if any + if (!split.first.isEmpty()) { + ChatItem *textItem = new ChatItem(this, ChatItem::text_tag, split.first); + newResponse->subItems.push_back(textItem); + } + + // Add the thinking item + Q_ASSERT(!split.second.isEmpty()); + ChatItem *thinkingItem = new ChatItem(this, ChatItem::think_tag, split.second); + thinkingItem->isCurrentResponse = true; + newResponse->subItems.push_back(thinkingItem); + + // Add new response and reset our value + currentResponse->subItems.push_back(newResponse); + currentResponse->value = QString(); + } + + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + + Q_INVOKABLE void endThinking(const QPair &split, int thinkingTime) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only end thinking on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + ChatItem *currentResponse = m_chatItems.back(); + Q_ASSERT(currentResponse->isCurrentResponse); + + ChatItem *subResponse = currentResponse->subItems.back(); + Q_ASSERT(subResponse->type() == ChatItem::Type::Response); + Q_ASSERT(subResponse->isCurrentResponse); + subResponse->setCurrentResponse(false); + + ChatItem *thinkingItem = subResponse->subItems.back(); + Q_ASSERT(thinkingItem->type() == ChatItem::Type::Think); + thinkingItem->setCurrentResponse(false); + thinkingItem->setValue(split.first); + thinkingItem->setThinkingTime(thinkingTime); + + currentResponse->setValue(split.second); + } + + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + Q_INVOKABLE void splitToolCall(const QPair &split) { qsizetype index; diff --git a/gpt4all-chat/src/tool.h b/gpt4all-chat/src/tool.h index 209cf470..0af645f5 100644 --- a/gpt4all-chat/src/tool.h +++ b/gpt4all-chat/src/tool.h @@ -28,6 +28,7 @@ namespace ToolEnums enum class ParseState { None, + InTagChoice, InStart, Partial, Complete, diff --git a/gpt4all-chat/src/toolcallparser.cpp b/gpt4all-chat/src/toolcallparser.cpp index 7649c21d..de2ac0e5 100644 --- a/gpt4all-chat/src/toolcallparser.cpp +++ b/gpt4all-chat/src/toolcallparser.cpp @@ -6,11 +6,12 @@ #include -static const QString ToolCallStart = ToolCallConstants::CodeInterpreterTag; -static const QString ToolCallEnd = ToolCallConstants::CodeInterpreterEndTag; - ToolCallParser::ToolCallParser() { + m_possibleStartTags << ToolCallConstants::CodeInterpreterTag + << ToolCallConstants::ThinkTag; + m_possibleEndTags << ToolCallConstants::CodeInterpreterEndTag + << ToolCallConstants::ThinkEndTag; reset(); } @@ -20,18 +21,56 @@ void ToolCallParser::reset() resetSearchState(); // These are global states maintained between update calls - m_buffer.clear(); - m_hasSplit = false; + m_buffers.clear(); + m_buffers.append(QString()); } void ToolCallParser::resetSearchState() { - m_expected = ToolCallStart.at(0); + m_expected = {'<'}; m_expectedIndex = 0; m_state = ToolEnums::ParseState::None; + m_toolCall.clear(); + m_startTagBuffer.clear(); m_endTagBuffer.clear(); + + m_currentTagIndex = -1; m_startIndex = -1; + m_endIndex = -1; +} + +bool ToolCallParser::isExpected(QChar c) const +{ + return m_expected.isEmpty() || m_expected.contains(c); +} + +void ToolCallParser::setExpected(const QStringList &tags) +{ + m_expected.clear(); + for (const QString &tag : tags) { + Q_ASSERT(tag.size() > m_expectedIndex); + m_expected << tag.at(m_expectedIndex); + } +} + +QString ToolCallParser::startTag() const +{ + if (m_currentTagIndex < 0) + return QString(); + return m_possibleStartTags.at(m_currentTagIndex); +} + +QString ToolCallParser::endTag() const +{ + if (m_currentTagIndex < 0) + return QString(); + return m_possibleEndTags.at(m_currentTagIndex); +} + +QString &ToolCallParser::currentBuffer() +{ + return m_buffers.last(); } // This method is called with an arbitrary string and a current state. This method should take the @@ -39,17 +78,11 @@ void ToolCallParser::resetSearchState() // the new state. void ToolCallParser::update(const QString &update) { - Q_ASSERT(m_state != ToolEnums::ParseState::Complete); - if (m_state == ToolEnums::ParseState::Complete) { - qWarning() << "ERROR: ToolCallParser::update already found a complete toolcall!"; - return; - } + currentBuffer().append(update); - m_buffer.append(update); - - for (size_t i = m_buffer.size() - update.size(); i < m_buffer.size(); ++i) { - const QChar c = m_buffer[i]; - const bool foundMatch = m_expected.isNull() || c == m_expected; + for (size_t i = currentBuffer().size() - update.size(); i < currentBuffer().size(); ++i) { + const QChar c = currentBuffer()[i]; + const bool foundMatch = isExpected(c); if (!foundMatch) { resetSearchState(); continue; @@ -59,34 +92,58 @@ void ToolCallParser::update(const QString &update) case ToolEnums::ParseState::None: { m_expectedIndex = 1; - m_expected = ToolCallStart.at(1); - m_state = ToolEnums::ParseState::InStart; + setExpected(m_possibleStartTags); + m_state = ToolEnums::ParseState::InTagChoice; m_startIndex = i; break; } + case ToolEnums::ParseState::InTagChoice: + { + for (int i = 0; i < m_possibleStartTags.size(); ++i) { + const QString tag = m_possibleStartTags.at(i); + if (c == tag.at(1)) m_currentTagIndex = i; + } + if (m_currentTagIndex >= 0) { + m_expectedIndex = 2; + setExpected({m_possibleStartTags.at(m_currentTagIndex)}); + m_state = ToolEnums::ParseState::InStart; + } else + resetSearchState(); + break; + } case ToolEnums::ParseState::InStart: { - if (m_expectedIndex == ToolCallStart.size() - 1) { + m_startTagBuffer.append(c); + + const QString startTag = this->startTag(); + Q_ASSERT(!startTag.isEmpty()); + if (m_expectedIndex == startTag.size() - 1) { m_expectedIndex = 0; - m_expected = QChar(); + setExpected({}); m_state = ToolEnums::ParseState::Partial; } else { ++m_expectedIndex; - m_expected = ToolCallStart.at(m_expectedIndex); + Q_ASSERT(m_currentTagIndex >= 0); + setExpected({startTag}); } break; } case ToolEnums::ParseState::Partial: { + Q_ASSERT(m_currentTagIndex >= 0); + const QString endTag = this->endTag(); + Q_ASSERT(!endTag.isEmpty()); m_toolCall.append(c); m_endTagBuffer.append(c); - if (m_endTagBuffer.size() > ToolCallEnd.size()) + if (m_endTagBuffer.size() > endTag.size()) m_endTagBuffer.remove(0, 1); - if (m_endTagBuffer == ToolCallEnd) { - m_toolCall.chop(ToolCallEnd.size()); + if (m_endTagBuffer == endTag) { + m_endIndex = i + 1; + m_toolCall.chop(endTag.size()); m_state = ToolEnums::ParseState::Complete; m_endTagBuffer.clear(); } + break; } case ToolEnums::ParseState::Complete: { @@ -97,15 +154,31 @@ void ToolCallParser::update(const QString &update) } } -QPair ToolCallParser::split() +bool ToolCallParser::splitIfPossible() { - Q_ASSERT(m_state == ToolEnums::ParseState::Partial - || m_state == ToolEnums::ParseState::Complete); + // The first split happens when we're in a partial state + if (m_buffers.size() < 2 && m_state == ToolEnums::ParseState::Partial) { + Q_ASSERT(m_startIndex >= 0); + const QString beforeToolCall = currentBuffer().left(m_startIndex); + const QString toolCall = currentBuffer().mid(m_startIndex); + m_buffers = { beforeToolCall, toolCall }; + return true; + } - Q_ASSERT(m_startIndex >= 0); - m_hasSplit = true; - const QString beforeToolCall = m_buffer.left(m_startIndex); - m_buffer = m_buffer.mid(m_startIndex); - m_startIndex = 0; - return { beforeToolCall, m_buffer }; + // The second split happens when we're in the complete state + if (m_buffers.size() < 3 && m_state == ToolEnums::ParseState::Complete) { + Q_ASSERT(m_endIndex >= 0); + const QString beforeToolCall = m_buffers.first(); + const QString toolCall = currentBuffer().left(m_endIndex); + const QString afterToolCall = currentBuffer().mid(m_endIndex); + m_buffers = { beforeToolCall, toolCall, afterToolCall }; + return true; + } + + return false; +} + +const QVector &ToolCallParser::buffers() const +{ + return m_buffers; } diff --git a/gpt4all-chat/src/toolcallparser.h b/gpt4all-chat/src/toolcallparser.h index 855cb6b7..16ebdfe3 100644 --- a/gpt4all-chat/src/toolcallparser.h +++ b/gpt4all-chat/src/toolcallparser.h @@ -14,6 +14,10 @@ namespace ToolCallConstants const QString CodeInterpreterEndTag = R"()"; const QString CodeInterpreterPrefix = CodeInterpreterTag + "\n```javascript\n"; const QString CodeInterpreterSuffix = "```\n" + CodeInterpreterEndTag; + + // NB: the parsing code assumes the first char of the various tags differ + const QString ThinkTag = QStringLiteral(""); + const QString ThinkEndTag = QStringLiteral(""); } class ToolCallParser @@ -22,26 +26,35 @@ public: ToolCallParser(); void reset(); void update(const QString &update); - QString buffer() const { return m_buffer; } QString toolCall() const { return m_toolCall; } int startIndex() const { return m_startIndex; } ToolEnums::ParseState state() const { return m_state; } + QString startTag() const; + QString endTag() const; - // Splits - QPair split(); - bool hasSplit() const { return m_hasSplit; } + bool splitIfPossible(); + const QVector &buffers() const; + int numberOfBuffers() const { return m_buffers.size(); } private: + QString ¤tBuffer(); void resetSearchState(); + bool isExpected(QChar c) const; + void setExpected(const QStringList &tags); - QChar m_expected; + QStringList m_possibleStartTags; + QStringList m_possibleEndTags; + QString m_startTagBuffer; + QString m_endTagBuffer; + int m_currentTagIndex; + + QVector m_expected; int m_expectedIndex; ToolEnums::ParseState m_state; - QString m_buffer; + QVector m_buffers; QString m_toolCall; - QString m_endTagBuffer; int m_startIndex; - bool m_hasSplit; + int m_endIndex; }; #endif // TOOLCALLPARSER_H