Ignore DeepSeek-R1 "think" content in name/follow-up responses (#3458)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel
2025-02-04 12:08:17 -05:00
committed by GitHub
parent d4e6a6e485
commit 8c9f26e249
7 changed files with 295 additions and 137 deletions

@@ -29,6 +29,7 @@
#include <QRegularExpression>
#include <QRegularExpressionMatch>
#include <QSet>
#include <QTextStream>
#include <QUrl>
#include <QWaitCondition>
#include <Qt>
@@ -92,6 +93,70 @@ static const std::shared_ptr<minja::Context> &jinjaEnv()
return environment;
}
class BaseResponseHandler {
public:
virtual void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) = 0;
virtual void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) = 0;
// "old-style" responses, with all of the implementation details left in
virtual void onOldResponseChunk(const QByteArray &chunk) = 0;
// notify of a "new-style" response that has been cleaned of tool calling
virtual bool onBufferResponse(const QString &response, int bufferIdx) = 0;
// notify of a "new-style" response, no tool calling applicable
virtual bool onRegularResponse() = 0;
virtual bool getStopGenerating () const = 0;
};
static auto promptModelWithTools(
LLModel *model, const LLModel::PromptCallback &promptCallback, BaseResponseHandler &respHandler,
const LLModel::PromptContext &ctx, const QByteArray &prompt, const QStringList &toolNames
) -> std::pair<QStringList, bool>
{
ToolCallParser toolCallParser(toolNames);
auto handleResponse = [&toolCallParser, &respHandler](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
toolCallParser.update(piece.data());
// Split the response into two if needed
if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
const auto parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 2);
respHandler.onSplitIntoTwo(toolCallParser.startTag(), parseBuffers.at(0), parseBuffers.at(1));
}
// Split the response into three if needed
if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkStartTag
&& toolCallParser.splitIfPossible()) {
const auto parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 3);
respHandler.onSplitIntoThree(parseBuffers.at(1), parseBuffers.at(2));
}
respHandler.onOldResponseChunk(QByteArray::fromRawData(piece.data(), piece.size()));
bool ok;
const auto parseBuffers = toolCallParser.buffers();
if (parseBuffers.size() > 1) {
ok = respHandler.onBufferResponse(parseBuffers.last(), parseBuffers.size() - 1);
} else {
ok = respHandler.onRegularResponse();
}
if (!ok)
return false;
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;
return !shouldExecuteToolCall && !respHandler.getStopGenerating();
};
model->prompt(std::string_view(prompt), promptCallback, handleResponse, ctx);
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;
return { toolCallParser.buffers(), shouldExecuteToolCall };
}
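
For orientation, this is roughly how the helper is driven (a hedged sketch modelled on the call sites later in this diff; the handler's concrete type and "promptBytes" are placeholders, not identifiers from the codebase):

    SomeResponseHandler respHandler(this);            // any BaseResponseHandler subclass (placeholder)
    auto [buffers, shouldExecuteTool] = promptModelWithTools(
        m_llModelInfo.model.get(),                    // LLModel *
        /*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
        respHandler, promptContextFromSettings(m_modelInfo),
        promptBytes,                                  // QByteArray with the templated conversation (placeholder)
        ToolCallConstants::AllTagNames                // or { ToolCallConstants::ThinkTagName } to parse only "think" tags
    );
    // buffers holds the parser's split output (more than one entry once a tag was seen);
    // shouldExecuteTool is true only for a completed tool call that is not a "think" block.
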
class LLModelStore {
public:
static LLModelStore *globalInstance();
@@ -882,6 +947,62 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL
};
}
class ChatViewResponseHandler : public BaseResponseHandler {
public:
ChatViewResponseHandler(ChatLLM *cllm, QElapsedTimer *totalTime, ChatLLM::PromptResult *result)
: m_cllm(cllm), m_totalTime(totalTime), m_result(result) {}
void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
{
if (startTag == ToolCallConstants::ThinkStartTag)
m_cllm->m_chatModel->splitThinking({ firstBuffer, secondBuffer });
else
m_cllm->m_chatModel->splitToolCall({ firstBuffer, secondBuffer });
}
void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) override
{
m_cllm->m_chatModel->endThinking({ secondBuffer, thirdBuffer }, m_totalTime->elapsed());
}
void onOldResponseChunk(const QByteArray &chunk) override
{
m_result->responseTokens++;
m_cllm->m_timer->inc();
m_result->response.append(chunk);
}
bool onBufferResponse(const QString &response, int bufferIdx) override
{
Q_UNUSED(bufferIdx)
try {
m_cllm->m_chatModel->setResponseValue(response);
} catch (const std::exception &e) {
// We have a try/catch here because the main thread might have removed the response from
// the chatmodel by erasing the conversation during the response... the main thread sets
// m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel
Q_ASSERT(m_cllm->m_stopGenerating);
return false;
}
emit m_cllm->responseChanged();
return true;
}
bool onRegularResponse() override
{
auto respStr = QString::fromUtf8(m_result->response);
return onBufferResponse(removeLeadingWhitespace(respStr), 0);
}
bool getStopGenerating() const override
{ return m_cllm->m_stopGenerating; }
private:
ChatLLM *m_cllm;
QElapsedTimer *m_totalTime;
ChatLLM::PromptResult *m_result;
};
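
To make the dispatch concrete, here is an assumed trace (illustrative, not from the diff) of how promptModelWithTools() drives this handler while a DeepSeek-R1 style response streams in:

    // on "<think>"        : parser splits into two buffers -> onSplitIntoTwo("<think>", ...) -> splitThinking(...)
    //                       then onOldResponseChunk(piece) and onBufferResponse(buffers.last(), 1)
    // on "reasoning..."   : onOldResponseChunk(piece), onBufferResponse("reasoning...", 1)      // in-progress think text
    // on "</think>Answer" : parser splits into three -> onSplitIntoThree(thinkText, "Answer") -> endThinking(...)
    //                       then onOldResponseChunk(piece) and onBufferResponse(buffers.last(), 2)
    // on further text     : onOldResponseChunk(piece), onBufferResponse(postThinkText, 2)       // shown as the visible response
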
auto ChatLLM::promptInternal(
const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
const LLModel::PromptContext &ctx,
@@ -931,64 +1052,20 @@ auto ChatLLM::promptInternal(
QElapsedTimer totalTime;
totalTime.start();
ChatViewResponseHandler respHandler(this, &totalTime, &result);
m_timer->start();
ToolCallParser toolCallParser;
auto handleResponse = [this, &result, &toolCallParser, &totalTime](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
result.responseTokens++;
m_timer->inc();
toolCallParser.update(piece.data());
// Split the response into two if needed and create chat items
if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
const auto parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 2);
if (toolCallParser.startTag() == ToolCallConstants::ThinkTag)
m_chatModel->splitThinking({parseBuffers.at(0), parseBuffers.at(1)});
else
m_chatModel->splitToolCall({parseBuffers.at(0), parseBuffers.at(1)});
}
// Split the response into three if needed and create chat items
if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkTag
&& toolCallParser.splitIfPossible()) {
const auto parseBuffers = toolCallParser.buffers();
Q_ASSERT(parseBuffers.size() == 3);
m_chatModel->endThinking({parseBuffers.at(1), parseBuffers.at(2)}, totalTime.elapsed());
}
result.response.append(piece.data(), piece.size());
auto respStr = QString::fromUtf8(result.response);
try {
const auto parseBuffers = toolCallParser.buffers();
if (parseBuffers.size() > 1)
m_chatModel->setResponseValue(parseBuffers.last());
else
m_chatModel->setResponseValue(removeLeadingWhitespace(respStr));
} catch (const std::exception &e) {
// We have a try/catch here because the main thread might have removed the response from
// the chatmodel by erasing the conversation during the response... the main thread sets
// m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel
Q_ASSERT(m_stopGenerating);
return false;
}
emit responseChanged();
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkTag;
return !shouldExecuteToolCall && !m_stopGenerating;
};
QStringList finalBuffers;
bool shouldExecuteTool;
try {
emit promptProcessing();
m_llModelInfo.model->setThreadCount(mySettings->threadCount());
m_stopGenerating = false;
m_llModelInfo.model->prompt(conversation, handlePrompt, handleResponse, ctx);
std::tie(finalBuffers, shouldExecuteTool) = promptModelWithTools(
m_llModelInfo.model.get(), handlePrompt, respHandler, ctx,
QByteArray::fromRawData(conversation.data(), conversation.size()),
ToolCallConstants::AllTagNames
);
} catch (...) {
m_timer->stop();
throw;
@@ -997,22 +1074,18 @@ auto ChatLLM::promptInternal(
m_timer->stop();
qint64 elapsed = totalTime.elapsed();
const auto parseBuffers = toolCallParser.buffers();
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkTag;
// trim trailing whitespace
auto respStr = QString::fromUtf8(result.response);
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || parseBuffers.size() > 1)) {
if (parseBuffers.size() > 1)
m_chatModel->setResponseValue(parseBuffers.last());
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || finalBuffers.size() > 1)) {
if (finalBuffers.size() > 1)
m_chatModel->setResponseValue(finalBuffers.last());
else
m_chatModel->setResponseValue(respStr.trimmed());
emit responseChanged();
}
bool doQuestions = false;
if (!m_isServer && messageItems && !shouldExecuteToolCall) {
if (!m_isServer && messageItems && !shouldExecuteTool) {
switch (mySettings->suggestionMode()) {
case SuggestionMode::On: doQuestions = true; break;
case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break;
@@ -1090,6 +1163,66 @@ void ChatLLM::reloadModel()
loadModel(m);
}
// This class discards the text within thinking tags, for use with chat names and follow-up questions.
class SimpleResponseHandler : public BaseResponseHandler {
public:
SimpleResponseHandler(ChatLLM *cllm)
: m_cllm(cllm) {}
void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
{ /* no-op */ }
void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) override
{ /* no-op */ }
void onOldResponseChunk(const QByteArray &chunk) override
{ m_response.append(chunk); }
bool onBufferResponse(const QString &response, int bufferIdx) override
{
if (bufferIdx == 1)
return true; // ignore "think" content
return onSimpleResponse(response);
}
bool onRegularResponse() override
{ return onBufferResponse(QString::fromUtf8(m_response), 0); }
bool getStopGenerating() const override
{ return m_cllm->m_stopGenerating; }
protected:
virtual bool onSimpleResponse(const QString &response) = 0;
protected:
ChatLLM *m_cllm;
QByteArray m_response;
};
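
The bufferIdx == 1 check is what implements the commit subject: once the parser splits on a "think" tag, the second buffer carries the reasoning text and never reaches onSimpleResponse(). A rough sketch of the assumed buffer layout after a full "<think>...</think>" response:

    // buffers[0] -> text before the think tag (typically empty)    bufferIdx 0: forwarded to onSimpleResponse()
    // buffers[1] -> the chain-of-thought ("think") text             bufferIdx 1: dropped
    // buffers[2] -> text after "</think>"                           bufferIdx 2: forwarded to onSimpleResponse()
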
class NameResponseHandler : public SimpleResponseHandler {
private:
// max length of chat names, in words
static constexpr qsizetype MAX_WORDS = 3;
public:
using SimpleResponseHandler::SimpleResponseHandler;
protected:
bool onSimpleResponse(const QString &response) override
{
// wrap the response in a read-only stream; the const_cast is safe because we never write through it
QTextStream stream(const_cast<QString *>(&response), QIODeviceBase::ReadOnly);
QStringList words;
while (!stream.atEnd() && words.size() < MAX_WORDS) {
QString word;
stream >> word;
words << word;
}
emit m_cllm->generatedNameChanged(words.join(u' '));
return words.size() < MAX_WORDS || stream.atEnd();
}
};
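
As a standalone illustration (compilable on its own, not part of this commit) of the QTextStream word extraction above, only the first MAX_WORDS whitespace-delimited tokens survive into the generated chat name:

    #include <QDebug>
    #include <QString>
    #include <QStringList>
    #include <QTextStream>

    int main()
    {
        QString response = QStringLiteral("  A   Space\nOdyssey Begins Here");
        QTextStream stream(&response, QIODeviceBase::ReadOnly);
        QStringList words;
        while (!stream.atEnd() && words.size() < 3) { // 3 == MAX_WORDS in the handler above
            QString word;
            stream >> word;                           // operator>> skips whitespace and reads one token
            words << word;
        }
        qDebug().noquote() << words.join(u' ');       // prints: A Space Odyssey
        return 0;
    }
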
void ChatLLM::generateName()
{
Q_ASSERT(isModelLoaded());
@@ -1106,23 +1239,15 @@ void ChatLLM::generateName()
return;
}
QByteArray response; // raw UTF-8
auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
response.append(piece.data(), piece.size());
QStringList words = QString::fromUtf8(response).simplified().split(u' ', Qt::SkipEmptyParts);
emit generatedNameChanged(words.join(u' '));
return words.size() <= 3;
};
NameResponseHandler respHandler(this);
try {
m_llModelInfo.model->prompt(
applyJinjaTemplate(forkConversation(chatNamePrompt)),
[this](auto &&...) { return !m_stopGenerating; },
handleResponse,
promptContextFromSettings(m_modelInfo)
promptModelWithTools(
m_llModelInfo.model.get(),
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
respHandler, promptContextFromSettings(m_modelInfo),
applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(),
{ ToolCallConstants::ThinkTagName }
);
} catch (const std::exception &e) {
qWarning() << "ChatLLM failed to generate name:" << e.what();
@@ -1134,13 +1259,43 @@ void ChatLLM::handleChatIdChanged(const QString &id)
m_llmThread.setObjectName(id);
}
void ChatLLM::generateQuestions(qint64 elapsed)
{
class QuestionResponseHandler : public SimpleResponseHandler {
public:
using SimpleResponseHandler::SimpleResponseHandler;
protected:
bool onSimpleResponse(const QString &response) override
{
auto responseUtf8Bytes = response.toUtf8().slice(m_offset);
auto responseUtf8 = std::string(responseUtf8Bytes.begin(), responseUtf8Bytes.end());
// extract all questions from response
ptrdiff_t lastMatchEnd = -1;
auto it = std::sregex_iterator(responseUtf8.begin(), responseUtf8.end(), s_reQuestion);
auto end = std::sregex_iterator();
for (; it != end; ++it) {
auto pos = it->position();
auto len = it->length();
lastMatchEnd = pos + len;
emit m_cllm->generatedQuestionFinished(QString::fromUtf8(&responseUtf8[pos], len));
}
// remove processed input from buffer
if (lastMatchEnd != -1)
m_offset += lastMatchEnd;
return true;
}
private:
// FIXME: This only works when the model responds in English, which is not ideal for a
// multilingual model.
// match whole question sentences
static const std::regex reQuestion(R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");
static inline const std::regex s_reQuestion { R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)" };
qsizetype m_offset = 0;
};
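
A self-contained sketch (not part of this commit) of what the question regex above extracts; as the FIXME notes, it only recognizes English interrogatives:

    #include <iostream>
    #include <regex>
    #include <string>

    int main()
    {
        static const std::regex reQuestion(R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");
        const std::string text = "Sure! What is a black hole? Also: How do stars form? The end.";
        for (auto it = std::sregex_iterator(text.begin(), text.end(), reQuestion);
             it != std::sregex_iterator(); ++it)
            std::cout << it->str() << '\n'; // prints "What is a black hole?" then "How do stars form?"
        return 0;
    }
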
void ChatLLM::generateQuestions(qint64 elapsed)
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded()) {
emit responseStopped(elapsed);
@@ -1158,39 +1313,17 @@ void ChatLLM::generateQuestions(qint64 elapsed)
emit generatingQuestions();
std::string response; // raw UTF-8
auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
// add token to buffer
response.append(piece);
// extract all questions from response
ptrdiff_t lastMatchEnd = -1;
auto it = std::sregex_iterator(response.begin(), response.end(), reQuestion);
auto end = std::sregex_iterator();
for (; it != end; ++it) {
auto pos = it->position();
auto len = it->length();
lastMatchEnd = pos + len;
emit generatedQuestionFinished(QString::fromUtf8(&response[pos], len));
}
// remove processed input from buffer
if (lastMatchEnd != -1)
response.erase(0, lastMatchEnd);
return true;
};
QuestionResponseHandler respHandler(this);
QElapsedTimer totalTime;
totalTime.start();
try {
m_llModelInfo.model->prompt(
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)),
[this](auto &&...) { return !m_stopGenerating; },
handleResponse,
promptContextFromSettings(m_modelInfo)
promptModelWithTools(
m_llModelInfo.model.get(),
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
respHandler, promptContextFromSettings(m_modelInfo),
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
{ ToolCallConstants::ThinkTagName }
);
} catch (const std::exception &e) {
qWarning() << "ChatLLM failed to generate follow-up questions:" << e.what();