Ignore DeepSeek-R1 "think" content in name/follow-up responses (#3458)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2025-02-04 12:08:17 -05:00 committed by GitHub
parent d4e6a6e485
commit 8c9f26e249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 295 additions and 137 deletions

View File

@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Fixed
- Fix "index N is not a prompt" when using LocalDocs with reasoning ([#3451](https://github.com/nomic-ai/gpt4all/pull/3451))
- Work around rendering artifacts on Snapdragon SoCs with Windows ([#3450](https://github.com/nomic-ai/gpt4all/pull/3450))
- Prevent DeepSeek-R1 reasoning from appearing in chat names and follow-up questions ([#3458](https://github.com/nomic-ai/gpt4all/pull/3458))
## [3.8.0] - 2025-01-30

View File

@ -255,7 +255,7 @@ void Chat::responseStopped(qint64 promptResponseMs)
ToolCallParser parser;
parser.update(possibleToolcall.toUtf8());
if (parser.state() == ToolEnums::ParseState::Complete && parser.startTag() != ToolCallConstants::ThinkTag)
if (parser.state() == ToolEnums::ParseState::Complete && parser.startTag() != ToolCallConstants::ThinkStartTag)
processToolCall(parser.toolCall());
else
responseComplete();
@ -381,11 +381,8 @@ void Chat::trySwitchContextOfLoadedModel()
void Chat::generatedNameChanged(const QString &name)
{
// Only use the first three words maximum and remove newlines and extra spaces
m_generatedName = name.simplified();
QStringList words = m_generatedName.split(' ', Qt::SkipEmptyParts);
int wordCount = qMin(7, words.size());
m_name = words.mid(0, wordCount).join(' ');
m_generatedName = name;
m_name = name;
emit nameChanged();
m_needsSave = true;
}

View File

@ -29,6 +29,7 @@
#include <QRegularExpression>
#include <QRegularExpressionMatch>
#include <QSet>
#include <QTextStream>
#include <QUrl>
#include <QWaitCondition>
#include <Qt>
@ -92,6 +93,70 @@ static const std::shared_ptr<minja::Context> &jinjaEnv()
return environment;
}
// Abstract sink for a streaming model response. promptModelWithTools() drives one of
// these so different consumers (the chat view, name/follow-up generation) can react to
// tool-call and "think" buffer splits in their own way.
class BaseResponseHandler {
public:
// called when the parser first splits the response into pre-tag and in-tag buffers
virtual void onSplitIntoTwo (const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) = 0;
// called when a completed "think" section yields a third, post-think buffer
virtual void onSplitIntoThree (const QString &secondBuffer, const QString &thirdBuffer) = 0;
// "old-style" responses, with all of the implementation details left in
virtual void onOldResponseChunk(const QByteArray &chunk) = 0;
// notify of a "new-style" response that has been cleaned of tool calling
virtual bool onBufferResponse (const QString &response, int bufferIdx) = 0;
// notify of a "new-style" response, no tool calling applicable
virtual bool onRegularResponse () = 0;
// polled after each piece; return true to abort generation
virtual bool getStopGenerating () const = 0;
};
// Runs `prompt` against `model`, routing each streamed piece through `respHandler`
// while a ToolCallParser watches for the given tool-call/"think" tag names.
//
// Returns the parser's buffers (response text plus any tag content) and whether a
// completed tool call — other than a "think" section — should now be executed.
static auto promptModelWithTools(
    LLModel *model, const LLModel::PromptCallback &promptCallback, BaseResponseHandler &respHandler,
    const LLModel::PromptContext &ctx, const QByteArray &prompt, const QStringList &toolNames
) -> std::pair<QStringList, bool>
{
    ToolCallParser toolCallParser(toolNames);
    auto handleResponse = [&toolCallParser, &respHandler](LLModel::Token token, std::string_view piece) -> bool {
        Q_UNUSED(token)

        // NB: pass an explicit length — a std::string_view is not guaranteed to be
        // null-terminated, so handing piece.data() alone to QByteArray could read past
        // the end of the piece (or truncate it at an embedded NUL).
        toolCallParser.update(QByteArray::fromRawData(piece.data(), piece.size()));

        // Split the response into two if needed
        if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
            const auto parseBuffers = toolCallParser.buffers();
            Q_ASSERT(parseBuffers.size() == 2);
            respHandler.onSplitIntoTwo(toolCallParser.startTag(), parseBuffers.at(0), parseBuffers.at(1));
        }

        // Split the response into three if needed
        if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkStartTag
            && toolCallParser.splitIfPossible()) {
            const auto parseBuffers = toolCallParser.buffers();
            Q_ASSERT(parseBuffers.size() == 3);
            respHandler.onSplitIntoThree(parseBuffers.at(1), parseBuffers.at(2));
        }

        respHandler.onOldResponseChunk(QByteArray::fromRawData(piece.data(), piece.size()));

        bool ok;
        const auto parseBuffers = toolCallParser.buffers();
        if (parseBuffers.size() > 1) {
            // in-tag content: stream the last (growing) buffer to the handler
            ok = respHandler.onBufferResponse(parseBuffers.last(), parseBuffers.size() - 1);
        } else {
            ok = respHandler.onRegularResponse();
        }
        if (!ok)
            return false;

        // stop streaming as soon as a non-"think" tool call completes so it can be executed
        const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
            && toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;

        return !shouldExecuteToolCall && !respHandler.getStopGenerating();
    };

    model->prompt(std::string_view(prompt), promptCallback, handleResponse, ctx);

    const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
        && toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;

    return { toolCallParser.buffers(), shouldExecuteToolCall };
}
class LLModelStore {
public:
static LLModelStore *globalInstance();
@ -882,6 +947,62 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL
};
}
// BaseResponseHandler that feeds the chat view: it mirrors the streamed response into
// the chat model, tracks token count and timing in the PromptResult, and emits
// responseChanged so the UI updates.
class ChatViewResponseHandler : public BaseResponseHandler {
public:
ChatViewResponseHandler(ChatLLM *cllm, QElapsedTimer *totalTime, ChatLLM::PromptResult *result)
: m_cllm(cllm), m_totalTime(totalTime), m_result(result) {}
// create the appropriate chat items for a thinking section vs. a tool call
void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
{
if (startTag == ToolCallConstants::ThinkStartTag)
m_cllm->m_chatModel->splitThinking({ firstBuffer, secondBuffer });
else
m_cllm->m_chatModel->splitToolCall({ firstBuffer, secondBuffer });
}
// the "think" section closed: record its elapsed time in the chat model
void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) override
{
m_cllm->m_chatModel->endThinking({ secondBuffer, thirdBuffer }, m_totalTime->elapsed());
}
// accumulate the raw response and count the token for the UI's tokens/sec timer
void onOldResponseChunk(const QByteArray &chunk) override
{
m_result->responseTokens++;
m_cllm->m_timer->inc();
m_result->response.append(chunk);
}
// push the cleaned response text into the chat model; returns false to abort
bool onBufferResponse(const QString &response, int bufferIdx) override
{
Q_UNUSED(bufferIdx)
try {
m_cllm->m_chatModel->setResponseValue(response);
} catch (const std::exception &e) {
// We have a try/catch here because the main thread might have removed the response from
// the chatmodel by erasing the conversation during the response... the main thread sets
// m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel
Q_ASSERT(m_cllm->m_stopGenerating);
return false;
}
emit m_cllm->responseChanged();
return true;
}
// no tool tags seen: show the whole accumulated response (sans leading whitespace)
bool onRegularResponse() override
{
auto respStr = QString::fromUtf8(m_result->response);
return onBufferResponse(removeLeadingWhitespace(respStr), 0);
}
bool getStopGenerating() const override
{ return m_cllm->m_stopGenerating; }
private:
ChatLLM *m_cllm;          // not owned
QElapsedTimer *m_totalTime; // not owned; measures time spent thinking
ChatLLM::PromptResult *m_result; // not owned; accumulates response + token count
};
auto ChatLLM::promptInternal(
const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
const LLModel::PromptContext &ctx,
@ -931,64 +1052,20 @@ auto ChatLLM::promptInternal(
QElapsedTimer totalTime;
totalTime.start();
ChatViewResponseHandler respHandler(this, &totalTime, &result);
m_timer->start();
ToolCallParser toolCallParser;
auto handleResponse = [this, &result, &toolCallParser, &totalTime](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
result.responseTokens++;
m_timer->inc();
toolCallParser.update(piece.data());
// Split the response into two if needed and create chat items
if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
const auto parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 2);
if (toolCallParser.startTag() == ToolCallConstants::ThinkTag)
m_chatModel->splitThinking({parseBuffers.at(0), parseBuffers.at(1)});
else
m_chatModel->splitToolCall({parseBuffers.at(0), parseBuffers.at(1)});
}
// Split the response into three if needed and create chat items
if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkTag
&& toolCallParser.splitIfPossible()) {
const auto parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 3);
m_chatModel->endThinking({parseBuffers.at(1), parseBuffers.at(2)}, totalTime.elapsed());
}
result.response.append(piece.data(), piece.size());
auto respStr = QString::fromUtf8(result.response);
try {
const auto parseBuffers = toolCallParser.buffers();
if (parseBuffers.size() > 1)
m_chatModel->setResponseValue(parseBuffers.last());
else
m_chatModel->setResponseValue(removeLeadingWhitespace(respStr));
} catch (const std::exception &e) {
// We have a try/catch here because the main thread might have removed the response from
// the chatmodel by erasing the conversation during the response... the main thread sets
// m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel
Q_ASSERT(m_stopGenerating);
return false;
}
emit responseChanged();
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkTag;
return !shouldExecuteToolCall && !m_stopGenerating;
};
QStringList finalBuffers;
bool shouldExecuteTool;
try {
emit promptProcessing();
m_llModelInfo.model->setThreadCount(mySettings->threadCount());
m_stopGenerating = false;
m_llModelInfo.model->prompt(conversation, handlePrompt, handleResponse, ctx);
std::tie(finalBuffers, shouldExecuteTool) = promptModelWithTools(
m_llModelInfo.model.get(), handlePrompt, respHandler, ctx,
QByteArray::fromRawData(conversation.data(), conversation.size()),
ToolCallConstants::AllTagNames
);
} catch (...) {
m_timer->stop();
throw;
@ -997,22 +1074,18 @@ auto ChatLLM::promptInternal(
m_timer->stop();
qint64 elapsed = totalTime.elapsed();
const auto parseBuffers = toolCallParser.buffers();
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkTag;
// trim trailing whitespace
auto respStr = QString::fromUtf8(result.response);
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || parseBuffers.size() > 1)) {
if (parseBuffers.size() > 1)
m_chatModel->setResponseValue(parseBuffers.last());
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || finalBuffers.size() > 1)) {
if (finalBuffers.size() > 1)
m_chatModel->setResponseValue(finalBuffers.last());
else
m_chatModel->setResponseValue(respStr.trimmed());
emit responseChanged();
}
bool doQuestions = false;
if (!m_isServer && messageItems && !shouldExecuteToolCall) {
if (!m_isServer && messageItems && !shouldExecuteTool) {
switch (mySettings->suggestionMode()) {
case SuggestionMode::On: doQuestions = true; break;
case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break;
@ -1090,6 +1163,66 @@ void ChatLLM::reloadModel()
loadModel(m);
}
// This class discards the text within thinking tags, for use with chat names and follow-up questions.
class SimpleResponseHandler : public BaseResponseHandler {
public:
SimpleResponseHandler(ChatLLM *cllm)
: m_cllm(cllm) {}
// no chat items are created for these auxiliary generations
void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
{ /* no-op */ }
void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) override
{ /* no-op */ }
// accumulate the raw response bytes for onRegularResponse()
void onOldResponseChunk(const QByteArray &chunk) override
{ m_response.append(chunk); }
bool onBufferResponse(const QString &response, int bufferIdx) override
{
// buffer index 1 is the content inside the "think" tag
if (bufferIdx == 1)
return true; // ignore "think" content
return onSimpleResponse(response);
}
bool onRegularResponse() override
{ return onBufferResponse(QString::fromUtf8(m_response), 0); }
bool getStopGenerating() const override
{ return m_cllm->m_stopGenerating; }
protected:
// subclasses receive only the non-"think" response text; return false to abort
virtual bool onSimpleResponse(const QString &response) = 0;
protected:
ChatLLM *m_cllm; // not owned
QByteArray m_response; // raw UTF-8 accumulated so far
};
class NameResponseHandler : public SimpleResponseHandler {
private:
// max length of chat names, in words
static constexpr qsizetype MAX_WORDS = 3;
public:
using SimpleResponseHandler::SimpleResponseHandler;
protected:
bool onSimpleResponse(const QString &response) override
{
QTextStream stream(const_cast<QString *>(&response), QIODeviceBase::ReadOnly);
QStringList words;
while (!stream.atEnd() && words.size() < MAX_WORDS) {
QString word;
stream >> word;
words << word;
}
emit m_cllm->generatedNameChanged(words.join(u' '));
return words.size() < MAX_WORDS || stream.atEnd();
}
};
void ChatLLM::generateName()
{
Q_ASSERT(isModelLoaded());
@ -1106,23 +1239,15 @@ void ChatLLM::generateName()
return;
}
QByteArray response; // raw UTF-8
auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
response.append(piece.data(), piece.size());
QStringList words = QString::fromUtf8(response).simplified().split(u' ', Qt::SkipEmptyParts);
emit generatedNameChanged(words.join(u' '));
return words.size() <= 3;
};
NameResponseHandler respHandler(this);
try {
m_llModelInfo.model->prompt(
applyJinjaTemplate(forkConversation(chatNamePrompt)),
[this](auto &&...) { return !m_stopGenerating; },
handleResponse,
promptContextFromSettings(m_modelInfo)
promptModelWithTools(
m_llModelInfo.model.get(),
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
respHandler, promptContextFromSettings(m_modelInfo),
applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(),
{ ToolCallConstants::ThinkTagName }
);
} catch (const std::exception &e) {
qWarning() << "ChatLLM failed to generate name:" << e.what();
@ -1134,13 +1259,43 @@ void ChatLLM::handleChatIdChanged(const QString &id)
m_llmThread.setObjectName(id);
}
void ChatLLM::generateQuestions(qint64 elapsed)
{
class QuestionResponseHandler : public SimpleResponseHandler {
public:
using SimpleResponseHandler::SimpleResponseHandler;
protected:
bool onSimpleResponse(const QString &response) override
{
auto responseUtf8Bytes = response.toUtf8().slice(m_offset);
auto responseUtf8 = std::string(responseUtf8Bytes.begin(), responseUtf8Bytes.end());
// extract all questions from response
ptrdiff_t lastMatchEnd = -1;
auto it = std::sregex_iterator(responseUtf8.begin(), responseUtf8.end(), s_reQuestion);
auto end = std::sregex_iterator();
for (; it != end; ++it) {
auto pos = it->position();
auto len = it->length();
lastMatchEnd = pos + len;
emit m_cllm->generatedQuestionFinished(QString::fromUtf8(&responseUtf8[pos], len));
}
// remove processed input from buffer
if (lastMatchEnd != -1)
m_offset += lastMatchEnd;
return true;
}
private:
// FIXME: This only works with response by the model in english which is not ideal for a multi-language
// model.
// match whole question sentences
static const std::regex reQuestion(R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");
static inline const std::regex s_reQuestion { R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)" };
qsizetype m_offset = 0;
};
void ChatLLM::generateQuestions(qint64 elapsed)
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded()) {
emit responseStopped(elapsed);
@ -1158,39 +1313,17 @@ void ChatLLM::generateQuestions(qint64 elapsed)
emit generatingQuestions();
std::string response; // raw UTF-8
auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
// add token to buffer
response.append(piece);
// extract all questions from response
ptrdiff_t lastMatchEnd = -1;
auto it = std::sregex_iterator(response.begin(), response.end(), reQuestion);
auto end = std::sregex_iterator();
for (; it != end; ++it) {
auto pos = it->position();
auto len = it->length();
lastMatchEnd = pos + len;
emit generatedQuestionFinished(QString::fromUtf8(&response[pos], len));
}
// remove processed input from buffer
if (lastMatchEnd != -1)
response.erase(0, lastMatchEnd);
return true;
};
QuestionResponseHandler respHandler(this);
QElapsedTimer totalTime;
totalTime.start();
try {
m_llModelInfo.model->prompt(
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)),
[this](auto &&...) { return !m_stopGenerating; },
handleResponse,
promptContextFromSettings(m_modelInfo)
promptModelWithTools(
m_llModelInfo.model.get(),
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
respHandler, promptContextFromSettings(m_modelInfo),
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
{ ToolCallConstants::ThinkTagName }
);
} catch (const std::exception &e) {
qWarning() << "ChatLLM failed to generate follow-up questions:" << e.what();

View File

@ -30,6 +30,7 @@
using namespace Qt::Literals::StringLiterals;
class ChatViewResponseHandler;
class QDataStream;
// NOTE: values serialized to disk, do not change or reuse
@ -285,6 +286,8 @@ private:
bool m_isServer;
bool m_forceMetal;
bool m_reloadingToChangeVariant;
friend class ChatViewResponseHandler;
friend class SimpleResponseHandler;
};
#endif // CHATLLM_H

View File

@ -269,7 +269,7 @@ private:
std::optional<QString> m_chatTemplate;
mutable std::optional<QString> m_modelChatTemplate;
QString m_systemMessage;
QString m_chatNamePrompt = "Describe the above conversation in seven words or less.";
QString m_chatNamePrompt = "Describe the above conversation. Your entire response must be three words or less.";
QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts.";
friend class MySettings;
friend class ModelList;

View File

@ -1,17 +1,29 @@
#include "toolcallparser.h"
#include <QDebug>
#include <QtGlobal>
#include <QtLogging>
#include "tool.h"
#include <QSet>
#include <QtGlobal>
#include <stdexcept>
#include <cstddef>
ToolCallParser::ToolCallParser()
: ToolCallParser(ToolCallConstants::AllTagNames)
{}
ToolCallParser::ToolCallParser(const QStringList &tagNames)
{
m_possibleStartTags << ToolCallConstants::CodeInterpreterTag.toUtf8()
<< ToolCallConstants::ThinkTag.toUtf8();
m_possibleEndTags << ToolCallConstants::CodeInterpreterEndTag.toUtf8()
<< ToolCallConstants::ThinkEndTag.toUtf8();
QSet<QChar> firstChars;
for (auto &name : tagNames) {
if (name.isEmpty())
throw std::invalid_argument("ToolCallParser(): tag names must not be empty");
if (firstChars.contains(name.at(0)))
throw std::invalid_argument("ToolCallParser(): tag names must not share any prefix");
firstChars << name.at(0);
m_possibleStartTags << makeStartTag(name).toUtf8();
m_possibleEndTags << makeEndTag (name).toUtf8();
}
reset();
}
@ -80,7 +92,7 @@ void ToolCallParser::update(const QByteArray &update)
{
currentBuffer().append(update);
for (size_t i = currentBuffer().size() - update.size(); i < currentBuffer().size(); ++i) {
for (qsizetype i = currentBuffer().size() - update.size(); i < currentBuffer().size(); ++i) {
const char c = currentBuffer()[i];
const bool foundMatch = isExpected(c);
if (!foundMatch) {

View File

@ -1,30 +1,22 @@
#ifndef TOOLCALLPARSER_H
#define TOOLCALLPARSER_H
#include "tool.h"
#include <QByteArray>
#include <QList>
#include <QString>
#include <QStringList>
namespace ToolCallConstants
{
const QString CodeInterpreterFunction = R"(javascript_interpret)";
const QString CodeInterpreterTag = R"(<)" + CodeInterpreterFunction + R"(>)";
const QString CodeInterpreterEndTag = R"(</)" + CodeInterpreterFunction + R"(>)";
const QString CodeInterpreterPrefix = CodeInterpreterTag + "\n```javascript\n";
const QString CodeInterpreterSuffix = "```\n" + CodeInterpreterEndTag;
namespace ToolEnums { enum class ParseState; }
using namespace Qt::Literals::StringLiterals;
// NB: the parsing code assumes the first char of the various tags differ
const QString ThinkTag = QStringLiteral("<think>");
const QString ThinkEndTag = QStringLiteral("</think>");
}
class ToolCallParser
{
public:
ToolCallParser();
ToolCallParser(const QStringList &tagNames);
void reset();
void update(const QByteArray &update);
QString toolCall() const { return QString::fromUtf8(m_toolCall); }
@ -37,6 +29,9 @@ public:
QStringList buffers() const;
int numberOfBuffers() const { return m_buffers.size(); }
static QString makeStartTag(const QString &name) { return u"<%1>"_s .arg(name); }
static QString makeEndTag (const QString &name) { return u"</%1>"_s.arg(name); }
private:
QByteArray &currentBuffer();
void resetSearchState();
@ -58,4 +53,21 @@ private:
int m_endIndex;
};
namespace ToolCallConstants
{
// NB: the parsing code assumes the first char of the various tags differ
inline const QString CodeInterpreterFunction = u"javascript_interpret"_s;
inline const QString CodeInterpreterStartTag = ToolCallParser::makeStartTag(CodeInterpreterFunction);
inline const QString CodeInterpreterEndTag = ToolCallParser::makeEndTag (CodeInterpreterFunction);
inline const QString CodeInterpreterPrefix = u"%1\n```javascript\n"_s.arg(CodeInterpreterStartTag);
inline const QString CodeInterpreterSuffix = u"```\n%1"_s .arg(CodeInterpreterEndTag );
inline const QString ThinkTagName = u"think"_s;
inline const QString ThinkStartTag = ToolCallParser::makeStartTag(ThinkTagName);
inline const QString ThinkEndTag = ToolCallParser::makeEndTag (ThinkTagName);
inline const QStringList AllTagNames { CodeInterpreterFunction, ThinkTagName };
}
#endif // TOOLCALLPARSER_H