Display DeepSeek-R1 thinking like Reasoner (#3440)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT 2025-01-30 16:11:05 -05:00 committed by GitHub
parent f914ee56c9
commit 007a7af1c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 306 additions and 68 deletions

View File

@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Support DeepSeek-R1 Qwen models ([#3431](https://github.com/nomic-ai/gpt4all/pull/3431))
- Support for think tags in the GUI ([#3440](https://github.com/nomic-ai/gpt4all/pull/3440))
### Changed
- Use minja instead of Jinja2Cpp for significantly improved template compatibility ([#3433](https://github.com/nomic-ai/gpt4all/pull/3433))

View File

@ -13,6 +13,8 @@ ColumnLayout {
property alias textContent: innerTextItem.textContent
property bool isCurrent: false
property bool isError: false
property bool isThinking: false
property int thinkingTime: 0
Layout.topMargin: 10
Layout.bottomMargin: 10
@ -26,16 +28,20 @@ ColumnLayout {
anchors.bottom: parent.bottom
Item {
width: myTextArea.width
height: myTextArea.height
Layout.preferredWidth: myTextArea.implicitWidth
Layout.preferredHeight: myTextArea.implicitHeight
TextArea {
id: myTextArea
text: {
if (isError)
return qsTr("Analysis encountered error");
if (isCurrent)
return qsTr("Analyzing");
return qsTr("Analyzed");
return isThinking ? qsTr("Thinking") : qsTr("Analyzing");
return isThinking
? qsTr("Thought for %1 %2")
.arg(Math.ceil(thinkingTime / 1000.0))
.arg(Math.ceil(thinkingTime / 1000.0) === 1 ? qsTr("second") : qsTr("seconds"))
: qsTr("Analyzed");
}
padding: 0
font.pixelSize: theme.fontSizeLarger

View File

@ -189,6 +189,17 @@ GridLayout {
isError: modelData.isToolCallError
}
}
// Delegate for "thinking" items. The role value below is the same string that the
// C++ ChatItem think_tag constructor assigns to the item's name ("Think: ").
DelegateChoice {
roleValue: "Think: ";
ChatCollapsibleItem {
Layout.fillWidth: true
textContent: modelData.content
isCurrent: modelData.isCurrentResponse
// A thinking item is never rendered in the error state.
isError: false
// Marks this collapsible item as model reasoning rather than a tool call.
isThinking: true
// Bound to the item's thinkingTime property (stored in milliseconds on the C++ side).
thinkingTime: modelData.thinkingTime
}
}
}
delegate: chooser

View File

@ -255,7 +255,7 @@ void Chat::responseStopped(qint64 promptResponseMs)
ToolCallParser parser;
parser.update(possibleToolcall);
if (parser.state() == ToolEnums::ParseState::Complete)
if (parser.state() == ToolEnums::ParseState::Complete && parser.startTag() != ToolCallConstants::ThinkTag)
processToolCall(parser.toolCall());
else
responseComplete();

View File

@ -928,8 +928,12 @@ auto ChatLLM::promptInternal(
return !m_stopGenerating;
};
QElapsedTimer totalTime;
totalTime.start();
m_timer->start();
ToolCallParser toolCallParser;
auto handleResponse = [this, &result, &toolCallParser](LLModel::Token token, std::string_view piece) -> bool {
auto handleResponse = [this, &result, &toolCallParser, &totalTime](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
result.responseTokens++;
m_timer->inc();
@ -938,18 +942,31 @@ auto ChatLLM::promptInternal(
// handle this like below where we have a QByteArray
toolCallParser.update(QString::fromStdString(piece.data()));
// Create a toolcall and split the response if needed
if (!toolCallParser.hasSplit() && toolCallParser.state() == ToolEnums::ParseState::Partial) {
const QPair<QString, QString> pair = toolCallParser.split();
m_chatModel->splitToolCall(pair);
// Split the response into two if needed and create chat items
if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
const QVector<QString> &parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 2);
if (toolCallParser.startTag() == ToolCallConstants::ThinkTag)
m_chatModel->splitThinking({parseBuffers.at(0), parseBuffers.at(1)});
else
m_chatModel->splitToolCall({parseBuffers.at(0), parseBuffers.at(1)});
}
// Split the response into three if needed and create chat items
if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkTag
&& toolCallParser.splitIfPossible()) {
const QVector<QString> &parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 3);
m_chatModel->endThinking({parseBuffers.at(1), parseBuffers.at(2)}, totalTime.elapsed());
}
result.response.append(piece.data(), piece.size());
auto respStr = QString::fromUtf8(result.response);
try {
if (toolCallParser.hasSplit())
m_chatModel->setResponseValue(toolCallParser.buffer());
const QVector<QString> &parseBuffers = toolCallParser.buffers();
if (parseBuffers.size() > 1)
m_chatModel->setResponseValue(parseBuffers.last());
else
m_chatModel->setResponseValue(removeLeadingWhitespace(respStr));
} catch (const std::exception &e) {
@ -962,13 +979,11 @@ auto ChatLLM::promptInternal(
emit responseChanged();
const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete;
return !foundToolCall && !m_stopGenerating;
};
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkTag;
QElapsedTimer totalTime;
totalTime.start();
m_timer->start();
return !shouldExecuteToolCall && !m_stopGenerating;
};
try {
emit promptProcessing();
@ -983,20 +998,22 @@ auto ChatLLM::promptInternal(
m_timer->stop();
qint64 elapsed = totalTime.elapsed();
const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete;
const QVector<QString> &parseBuffers = toolCallParser.buffers();
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkTag;
// trim trailing whitespace
auto respStr = QString::fromUtf8(result.response);
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || foundToolCall)) {
if (toolCallParser.hasSplit())
m_chatModel->setResponseValue(toolCallParser.buffer());
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || parseBuffers.size() > 1)) {
if (parseBuffers.size() > 1)
m_chatModel->setResponseValue(parseBuffers.last());
else
m_chatModel->setResponseValue(respStr.trimmed());
emit responseChanged();
}
bool doQuestions = false;
if (!m_isServer && messageItems && !foundToolCall) {
if (!m_isServer && messageItems && !shouldExecuteToolCall) {
switch (mySettings->suggestionMode()) {
case SuggestionMode::On: doQuestions = true; break;
case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break;

View File

@ -41,6 +41,12 @@ void ChatItem::serializeText(QDataStream &stream, int version)
stream << value;
}
// Writes the payload of a Think sub-item: the item's text followed by its thinking time.
// The field order here must stay in sync with deserializeThink().
void ChatItem::serializeThink(QDataStream &stream, int version)
{
    Q_UNUSED(version)
    stream << value << thinkingTime;
}
void ChatItem::serializeSubItems(QDataStream &stream, int version)
{
stream << name;
@ -50,6 +56,7 @@ void ChatItem::serializeSubItems(QDataStream &stream, int version)
case ToolCall: { serializeToolCall(stream, version); break; }
case ToolResponse: { serializeToolResponse(stream, version); break; }
case Text: { serializeText(stream, version); break; }
case Think: { serializeThink(stream, version); break; }
case System:
case Prompt:
throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ)));
@ -162,6 +169,13 @@ bool ChatItem::deserializeResponse(QDataStream &stream, int version)
return true;
}
// Reads the payload of a Think sub-item in the order serializeThink() wrote it:
// the item's text first, then its thinking time. Always reports success.
bool ChatItem::deserializeThink(QDataStream &stream, int version)
{
    Q_UNUSED(version)
    stream >> value >> thinkingTime;
    return true;
}
bool ChatItem::deserializeSubItems(QDataStream &stream, int version)
{
stream >> name;
@ -177,6 +191,7 @@ bool ChatItem::deserializeSubItems(QDataStream &stream, int version)
case ToolCall: { deserializeToolCall(stream, version); break; }
case ToolResponse: { deserializeToolResponse(stream, version); break; }
case Text: { deserializeText(stream, version); break; }
case Think: { deserializeThink(stream, version); break; }
case System:
case Prompt:
throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ)));

View File

@ -159,8 +159,11 @@ class ChatItem : public QObject
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState )
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
// thinking
Q_PROPERTY(int thinkingTime MEMBER thinkingTime NOTIFY thinkingTimeChanged)
public:
enum class Type { System, Prompt, Response, Text, ToolCall, ToolResponse };
enum class Type { System, Prompt, Response, Text, ToolCall, ToolResponse, Think };
// tags for constructing ChatItems
struct prompt_tag_t { explicit prompt_tag_t () = default; };
@ -169,12 +172,14 @@ public:
struct text_tag_t { explicit text_tag_t () = default; };
struct tool_call_tag_t { explicit tool_call_tag_t () = default; };
struct tool_response_tag_t { explicit tool_response_tag_t() = default; };
struct think_tag_t { explicit think_tag_t () = default; };
static inline constexpr prompt_tag_t prompt_tag = prompt_tag_t {};
static inline constexpr response_tag_t response_tag = response_tag_t {};
static inline constexpr system_tag_t system_tag = system_tag_t {};
static inline constexpr text_tag_t text_tag = text_tag_t {};
static inline constexpr tool_call_tag_t tool_call_tag = tool_call_tag_t {};
static inline constexpr tool_response_tag_t tool_response_tag = tool_response_tag_t {};
static inline constexpr think_tag_t think_tag = think_tag_t {};
public:
ChatItem(QObject *parent)
@ -220,6 +225,10 @@ public:
: ChatItem(parent)
{ this->name = u"ToolResponse: "_s; this->value = value; }
// Constructs a "thinking" item: labels it "Think: " (the label type() maps to
// Type::Think) and stores the given text as its value.
ChatItem(QObject *parent, think_tag_t, const QString &value)
    : ChatItem(parent)
{
    this->name = u"Think: "_s;
    this->value = value;
}
Type type() const
{
if (name == u"System: "_s)
@ -234,6 +243,8 @@ public:
return Type::ToolCall;
if (name == u"ToolResponse: "_s)
return Type::ToolResponse;
if (name == u"Think: "_s)
return Type::Think;
throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name));
}
@ -265,9 +276,11 @@ public:
return beforeToolCall;
}
// For tool calls we only return content if it is the code interpreter
if (type() == Type::Think)
return thinkContent(value);
if (type() == Type::ToolCall)
return codeInterpreterContent(value);
return toolCallContent(value);
// We don't show any of content from the tool response in the GUI
if (type() == Type::ToolResponse)
@ -276,7 +289,18 @@ public:
return value;
}
QString codeInterpreterContent(const QString &value) const
// Runs `value` through a fresh ToolCallParser and returns the text the parser
// collected as the tag body, with leading and trailing whitespace removed.
QString thinkContent(const QString &value) const
{
    ToolCallParser tagParser;
    tagParser.update(value);
    return tagParser.toolCall().trimmed();
}
QString toolCallContent(const QString &value) const
{
ToolCallParser parser;
parser.update(value);
@ -357,6 +381,12 @@ public:
return toolCallInfo.error != ToolEnums::Error::NoError;
}
// Sets the thinking time (in milliseconds) and notifies listeners.
// thinkingTimeChanged() is emitted only when the stored value actually changes,
// so property bindings are not re-evaluated for a no-op assignment.
void setThinkingTime(int t)
{
    if (thinkingTime == t)
        return;
    thinkingTime = t;
    emit thinkingTimeChanged();
}
// NB: Assumes response is not current.
static ChatItem *fromMessageInput(QObject *parent, const MessageInput &message)
{
@ -380,6 +410,7 @@ public:
case ToolResponse: msgType = MessageItem::Type::ToolResponse; break;
case Text:
case ToolCall:
case Think:
throw std::invalid_argument(fmt::format("cannot convert ChatItem type {} to message item", int(typ)));
}
return { msgType, flattenedContent(), sources, promptAttachments };
@ -391,6 +422,7 @@ public:
void serializeToolCall(QDataStream &stream, int version);
void serializeToolResponse(QDataStream &stream, int version);
void serializeText(QDataStream &stream, int version);
void serializeThink(QDataStream &stream, int version);
void serializeSubItems(QDataStream &stream, int version); // recursive
void serialize(QDataStream &stream, int version);
@ -399,6 +431,7 @@ public:
bool deserializeToolCall(QDataStream &stream, int version);
bool deserializeToolResponse(QDataStream &stream, int version);
bool deserializeText(QDataStream &stream, int version);
bool deserializeThink(QDataStream &stream, int version);
bool deserializeSubItems(QDataStream &stream, int version); // recursive
bool deserialize(QDataStream &stream, int version);
@ -406,6 +439,7 @@ Q_SIGNALS:
void contentChanged();
void isTooCallErrorChanged();
void isCurrentResponseChanged();
void thinkingTimeChanged();
public:
@ -429,6 +463,9 @@ public:
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
// thinking time in ms
int thinkingTime = 0;
};
class ChatModel : public QAbstractListModel
@ -879,6 +916,70 @@ public:
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
}
// Converts the in-progress response into a container holding an optional leading
// text item followed by a "thinking" item.
//
// split.first  - text produced before the think tag (may be empty)
// split.second - the thinking text gathered so far (expected to be non-empty)
//
// Throws std::logic_error if the chat is empty or does not end with a response.
Q_INVOKABLE void splitThinking(const QPair<QString, QString> &split)
{
    qsizetype index;
    {
        QMutexLocker locker(&m_mutex);

        const bool endsWithResponse = !m_chatItems.isEmpty()
            && m_chatItems.cend()[-1]->type() == ChatItem::Type::Response;
        if (!endsWithResponse)
            throw std::logic_error("can only set thinking on a chat that ends with a response");

        index = m_chatItems.count() - 1;
        ChatItem *activeResponse = m_chatItems.back();
        Q_ASSERT(activeResponse->isCurrentResponse);

        // Nested response that groups the leading text (if any) with the thinking item
        ChatItem *container = new ChatItem(this, ChatItem::response_tag);

        const QString &leadingText = split.first;
        if (!leadingText.isEmpty())
            container->subItems.push_back(new ChatItem(this, ChatItem::text_tag, leadingText));

        const QString &thinkingText = split.second;
        Q_ASSERT(!thinkingText.isEmpty());
        ChatItem *thinkItem = new ChatItem(this, ChatItem::think_tag, thinkingText);
        thinkItem->isCurrentResponse = true;
        container->subItems.push_back(thinkItem);

        // Attach the container to the active response and blank the response's own text
        activeResponse->subItems.push_back(container);
        activeResponse->value = QString();
    }
    emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole});
}
// Finalizes the "thinking" item of the in-progress response.
//
// split.first  - the final thinking text, stored on the thinking item
// split.second - the text following the thinking section, stored as the value of
//                the enclosing response
// thinkingTime - elapsed time to record on the thinking item (ChatItem stores it in ms)
//
// Throws std::logic_error if the chat does not end with a response, or if that
// response lacks the nested response/thinking items that this function updates.
Q_INVOKABLE void endThinking(const QPair<QString, QString> &split, int thinkingTime)
{
    qsizetype index;
    {
        QMutexLocker locker(&m_mutex);
        if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response)
            throw std::logic_error("can only end thinking on a chat that ends with a response");
        index = m_chatItems.count() - 1;
        ChatItem *currentResponse = m_chatItems.back();
        Q_ASSERT(currentResponse->isCurrentResponse);

        // back() on an empty container is undefined behaviour, so check explicitly
        // (and before mutating anything) instead of relying on debug-only assertions.
        if (currentResponse->subItems.empty())
            throw std::logic_error("can only end thinking on a response that has sub items");
        ChatItem *subResponse = currentResponse->subItems.back();
        Q_ASSERT(subResponse->type() == ChatItem::Type::Response);
        Q_ASSERT(subResponse->isCurrentResponse);
        if (subResponse->subItems.empty())
            throw std::logic_error("can only end thinking on a response that contains a thinking item");
        ChatItem *thinkingItem = subResponse->subItems.back();
        Q_ASSERT(thinkingItem->type() == ChatItem::Type::Think);

        // The nested response and its thinking item are no longer being generated
        subResponse->setCurrentResponse(false);
        thinkingItem->setCurrentResponse(false);
        thinkingItem->setValue(split.first);
        thinkingItem->setThinkingTime(thinkingTime);

        // Whatever followed the thinking section becomes the response's own text
        currentResponse->setValue(split.second);
    }
    emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole});
}
Q_INVOKABLE void splitToolCall(const QPair<QString, QString> &split)
{
qsizetype index;

View File

@ -28,6 +28,7 @@ namespace ToolEnums
enum class ParseState {
None,
InTagChoice,
InStart,
Partial,
Complete,

View File

@ -6,11 +6,12 @@
#include <cstddef>
static const QString ToolCallStart = ToolCallConstants::CodeInterpreterTag;
static const QString ToolCallEnd = ToolCallConstants::CodeInterpreterEndTag;
// Registers the recognized tags and puts the parser into its initial state.
// Start and end tags are stored in parallel lists: the entry at a given index of
// m_possibleEndTags closes the start tag at the same index of m_possibleStartTags.
ToolCallParser::ToolCallParser()
{
    m_possibleStartTags.append(ToolCallConstants::CodeInterpreterTag);
    m_possibleStartTags.append(ToolCallConstants::ThinkTag);

    m_possibleEndTags.append(ToolCallConstants::CodeInterpreterEndTag);
    m_possibleEndTags.append(ToolCallConstants::ThinkEndTag);

    reset();
}
@ -20,18 +21,56 @@ void ToolCallParser::reset()
resetSearchState();
// These are global states maintained between update calls
m_buffer.clear();
m_hasSplit = false;
m_buffers.clear();
m_buffers.append(QString());
}
void ToolCallParser::resetSearchState()
{
m_expected = ToolCallStart.at(0);
m_expected = {'<'};
m_expectedIndex = 0;
m_state = ToolEnums::ParseState::None;
m_toolCall.clear();
m_startTagBuffer.clear();
m_endTagBuffer.clear();
m_currentTagIndex = -1;
m_startIndex = -1;
m_endIndex = -1;
}
// Returns true if `c` is acceptable as the next character. An empty expectation
// set places no restriction, so every character is accepted in that case.
bool ToolCallParser::isExpected(QChar c) const
{
    if (m_expected.isEmpty())
        return true;
    return m_expected.contains(c);
}
void ToolCallParser::setExpected(const QStringList &tags)
{
m_expected.clear();
for (const QString &tag : tags) {
Q_ASSERT(tag.size() > m_expectedIndex);
m_expected << tag.at(m_expectedIndex);
}
}
// Returns the start tag selected by the parser so far, or a null QString when no
// tag has been selected (m_currentTagIndex is negative).
QString ToolCallParser::startTag() const
{
    const bool tagSelected = m_currentTagIndex >= 0;
    return tagSelected ? m_possibleStartTags.at(m_currentTagIndex) : QString();
}
// Returns the end tag paired with the currently selected start tag, or a null
// QString when no tag has been selected (m_currentTagIndex is negative).
QString ToolCallParser::endTag() const
{
    const bool tagSelected = m_currentTagIndex >= 0;
    return tagSelected ? m_possibleEndTags.at(m_currentTagIndex) : QString();
}
// Returns a mutable reference to the most recent buffer, i.e. the last element of
// m_buffers. This is the buffer that update() appends incoming text to.
QString &ToolCallParser::currentBuffer()
{
    return m_buffers.back();
}
// This method is called with an arbitrary string and a current state. This method should take the
@ -39,17 +78,11 @@ void ToolCallParser::resetSearchState()
// the new state.
void ToolCallParser::update(const QString &update)
{
Q_ASSERT(m_state != ToolEnums::ParseState::Complete);
if (m_state == ToolEnums::ParseState::Complete) {
qWarning() << "ERROR: ToolCallParser::update already found a complete toolcall!";
return;
}
currentBuffer().append(update);
m_buffer.append(update);
for (size_t i = m_buffer.size() - update.size(); i < m_buffer.size(); ++i) {
const QChar c = m_buffer[i];
const bool foundMatch = m_expected.isNull() || c == m_expected;
for (size_t i = currentBuffer().size() - update.size(); i < currentBuffer().size(); ++i) {
const QChar c = currentBuffer()[i];
const bool foundMatch = isExpected(c);
if (!foundMatch) {
resetSearchState();
continue;
@ -59,34 +92,58 @@ void ToolCallParser::update(const QString &update)
case ToolEnums::ParseState::None:
{
m_expectedIndex = 1;
m_expected = ToolCallStart.at(1);
m_state = ToolEnums::ParseState::InStart;
setExpected(m_possibleStartTags);
m_state = ToolEnums::ParseState::InTagChoice;
m_startIndex = i;
break;
}
case ToolEnums::ParseState::InTagChoice:
{
for (int i = 0; i < m_possibleStartTags.size(); ++i) {
const QString tag = m_possibleStartTags.at(i);
if (c == tag.at(1)) m_currentTagIndex = i;
}
if (m_currentTagIndex >= 0) {
m_expectedIndex = 2;
setExpected({m_possibleStartTags.at(m_currentTagIndex)});
m_state = ToolEnums::ParseState::InStart;
} else
resetSearchState();
break;
}
case ToolEnums::ParseState::InStart:
{
if (m_expectedIndex == ToolCallStart.size() - 1) {
m_startTagBuffer.append(c);
const QString startTag = this->startTag();
Q_ASSERT(!startTag.isEmpty());
if (m_expectedIndex == startTag.size() - 1) {
m_expectedIndex = 0;
m_expected = QChar();
setExpected({});
m_state = ToolEnums::ParseState::Partial;
} else {
++m_expectedIndex;
m_expected = ToolCallStart.at(m_expectedIndex);
Q_ASSERT(m_currentTagIndex >= 0);
setExpected({startTag});
}
break;
}
case ToolEnums::ParseState::Partial:
{
Q_ASSERT(m_currentTagIndex >= 0);
const QString endTag = this->endTag();
Q_ASSERT(!endTag.isEmpty());
m_toolCall.append(c);
m_endTagBuffer.append(c);
if (m_endTagBuffer.size() > ToolCallEnd.size())
if (m_endTagBuffer.size() > endTag.size())
m_endTagBuffer.remove(0, 1);
if (m_endTagBuffer == ToolCallEnd) {
m_toolCall.chop(ToolCallEnd.size());
if (m_endTagBuffer == endTag) {
m_endIndex = i + 1;
m_toolCall.chop(endTag.size());
m_state = ToolEnums::ParseState::Complete;
m_endTagBuffer.clear();
}
break;
}
case ToolEnums::ParseState::Complete:
{
@ -97,15 +154,31 @@ void ToolCallParser::update(const QString &update)
}
}
QPair<QString, QString> ToolCallParser::split()
bool ToolCallParser::splitIfPossible()
{
Q_ASSERT(m_state == ToolEnums::ParseState::Partial
|| m_state == ToolEnums::ParseState::Complete);
// The first split happens when we're in a partial state
if (m_buffers.size() < 2 && m_state == ToolEnums::ParseState::Partial) {
Q_ASSERT(m_startIndex >= 0);
const QString beforeToolCall = currentBuffer().left(m_startIndex);
const QString toolCall = currentBuffer().mid(m_startIndex);
m_buffers = { beforeToolCall, toolCall };
return true;
}
Q_ASSERT(m_startIndex >= 0);
m_hasSplit = true;
const QString beforeToolCall = m_buffer.left(m_startIndex);
m_buffer = m_buffer.mid(m_startIndex);
m_startIndex = 0;
return { beforeToolCall, m_buffer };
// The second split happens when we're in the complete state
if (m_buffers.size() < 3 && m_state == ToolEnums::ParseState::Complete) {
Q_ASSERT(m_endIndex >= 0);
const QString beforeToolCall = m_buffers.first();
const QString toolCall = currentBuffer().left(m_endIndex);
const QString afterToolCall = currentBuffer().mid(m_endIndex);
m_buffers = { beforeToolCall, toolCall, afterToolCall };
return true;
}
return false;
}
// Returns a read-only reference to the parser's buffers. reset() seeds the list
// with a single empty string; splitIfPossible() replaces it with
// { before, toolCall } on the first split and { before, toolCall, after } on the second.
const QVector<QString> &ToolCallParser::buffers() const
{
return m_buffers;
}

View File

@ -14,6 +14,10 @@ namespace ToolCallConstants
const QString CodeInterpreterEndTag = R"(</)" + CodeInterpreterFunction + R"(>)";
const QString CodeInterpreterPrefix = CodeInterpreterTag + "\n```javascript\n";
const QString CodeInterpreterSuffix = "```\n" + CodeInterpreterEndTag;
// NB: the parsing code assumes that the character following '<' differs between the various start tags
const QString ThinkTag = QStringLiteral("<think>");
const QString ThinkEndTag = QStringLiteral("</think>");
}
class ToolCallParser
@ -22,26 +26,35 @@ public:
ToolCallParser();
void reset();
void update(const QString &update);
QString buffer() const { return m_buffer; }
QString toolCall() const { return m_toolCall; }
int startIndex() const { return m_startIndex; }
ToolEnums::ParseState state() const { return m_state; }
QString startTag() const;
QString endTag() const;
// Splits
QPair<QString, QString> split();
bool hasSplit() const { return m_hasSplit; }
bool splitIfPossible();
const QVector<QString> &buffers() const;
int numberOfBuffers() const { return m_buffers.size(); }
private:
QString &currentBuffer();
void resetSearchState();
bool isExpected(QChar c) const;
void setExpected(const QStringList &tags);
QChar m_expected;
QStringList m_possibleStartTags;
QStringList m_possibleEndTags;
QString m_startTagBuffer;
QString m_endTagBuffer;
int m_currentTagIndex;
QVector<QChar> m_expected;
int m_expectedIndex;
ToolEnums::ParseState m_state;
QString m_buffer;
QVector<QString> m_buffers;
QString m_toolCall;
QString m_endTagBuffer;
int m_startIndex;
bool m_hasSplit;
int m_endIndex;
};
#endif // TOOLCALLPARSER_H