Ignore DeepSeek-R1 "think" content in name/follow-up responses (#3458)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
@@ -29,6 +29,7 @@
 #include <QRegularExpression>
 #include <QRegularExpressionMatch>
 #include <QSet>
+#include <QTextStream>
 #include <QUrl>
 #include <QWaitCondition>
 #include <Qt>
@@ -92,6 +93,70 @@ static const std::shared_ptr<minja::Context> &jinjaEnv()
     return environment;
 }
 
+class BaseResponseHandler {
+public:
+    virtual void onSplitIntoTwo  (const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) = 0;
+    virtual void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) = 0;
+    // "old-style" responses, with all of the implementation details left in
+    virtual void onOldResponseChunk(const QByteArray &chunk) = 0;
+    // notify of a "new-style" response that has been cleaned of tool calling
+    virtual bool onBufferResponse  (const QString &response, int bufferIdx) = 0;
+    // notify of a "new-style" response, no tool calling applicable
+    virtual bool onRegularResponse () = 0;
+    virtual bool getStopGenerating () const = 0;
+};
+
+static auto promptModelWithTools(
+    LLModel *model, const LLModel::PromptCallback &promptCallback, BaseResponseHandler &respHandler,
+    const LLModel::PromptContext &ctx, const QByteArray &prompt, const QStringList &toolNames
+) -> std::pair<QStringList, bool>
+{
+    ToolCallParser toolCallParser(toolNames);
+    auto handleResponse = [&toolCallParser, &respHandler](LLModel::Token token, std::string_view piece) -> bool {
+        Q_UNUSED(token)
+
+        toolCallParser.update(piece.data());
+
+        // Split the response into two if needed
+        if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
+            const auto parseBuffers = toolCallParser.buffers();
+            Q_ASSERT(parseBuffers.size() == 2);
+            respHandler.onSplitIntoTwo(toolCallParser.startTag(), parseBuffers.at(0), parseBuffers.at(1));
+        }
+
+        // Split the response into three if needed
+        if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkStartTag
+            && toolCallParser.splitIfPossible()) {
+            const auto parseBuffers = toolCallParser.buffers();
+            Q_ASSERT(parseBuffers.size() == 3);
+            respHandler.onSplitIntoThree(parseBuffers.at(1), parseBuffers.at(2));
+        }
+
+        respHandler.onOldResponseChunk(QByteArray::fromRawData(piece.data(), piece.size()));
+
+        bool ok;
+        const auto parseBuffers = toolCallParser.buffers();
+        if (parseBuffers.size() > 1) {
+            ok = respHandler.onBufferResponse(parseBuffers.last(), parseBuffers.size() - 1);
+        } else {
+            ok = respHandler.onRegularResponse();
+        }
+        if (!ok)
+            return false;
+
+        const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
+            && toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;
+
+        return !shouldExecuteToolCall && !respHandler.getStopGenerating();
+    };
+    model->prompt(std::string_view(prompt), promptCallback, handleResponse, ctx);
+
+    const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
+        && toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;
+
+    return { toolCallParser.buffers(), shouldExecuteToolCall };
+}
+
 class LLModelStore {
 public:
     static LLModelStore *globalInstance();
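Note: the buffer logic above is easiest to see in isolation. ToolCallParser starts with one buffer; when a recognized start tag completes, the text splits into two buffers (pre-tag text and tag content), and for the think tag a completed end tag splits it into three, with the last buffer holding the user-visible response. Below is a minimal standalone sketch of that behavior; MiniParser and the literal "<think>"/"</think>" strings are illustrative stand-ins for ToolCallParser and ToolCallConstants, not the real API:

    // MiniParser: toy stand-in for ToolCallParser, assuming literal
    // "<think>"/"</think>" tags. Buffer 0 is text before the start tag,
    // buffer 1 the think content, buffer 2 the user-visible text after it.
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    struct MiniParser {
        std::vector<std::string> buffers { "" };
        void update(const std::string &piece) {
            buffers.back() += piece;
            if (buffers.size() == 1) split("<think>");  // 1 -> 2 buffers
            if (buffers.size() == 2) split("</think>"); // 2 -> 3 buffers
        }
    private:
        void split(const std::string &tag) {
            auto pos = buffers.back().find(tag);
            if (pos == std::string::npos)
                return; // tag not complete yet; keep accumulating
            std::string rest = buffers.back().substr(pos + tag.size());
            buffers.back().erase(pos); // the tag itself lands in neither buffer
            buffers.push_back(rest);
        }
    };

    int main() {
        MiniParser p;
        for (const char *piece : { "<th", "ink>pondering...", "</think>", "Hello!" })
            p.update(piece); // tags may arrive split across pieces
        for (std::size_t i = 0; i < p.buffers.size(); ++i)
            std::cout << "buffer " << i << ": '" << p.buffers[i] << "'\n";
        // prints: buffer 0: '' / buffer 1: 'pondering...' / buffer 2: 'Hello!'
    }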
@@ -882,6 +947,62 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL
     };
 }
 
+class ChatViewResponseHandler : public BaseResponseHandler {
+public:
+    ChatViewResponseHandler(ChatLLM *cllm, QElapsedTimer *totalTime, ChatLLM::PromptResult *result)
+        : m_cllm(cllm), m_totalTime(totalTime), m_result(result) {}
+
+    void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
+    {
+        if (startTag == ToolCallConstants::ThinkStartTag)
+            m_cllm->m_chatModel->splitThinking({ firstBuffer, secondBuffer });
+        else
+            m_cllm->m_chatModel->splitToolCall({ firstBuffer, secondBuffer });
+    }
+
+    void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) override
+    {
+        m_cllm->m_chatModel->endThinking({ secondBuffer, thirdBuffer }, m_totalTime->elapsed());
+    }
+
+    void onOldResponseChunk(const QByteArray &chunk) override
+    {
+        m_result->responseTokens++;
+        m_cllm->m_timer->inc();
+        m_result->response.append(chunk);
+    }
+
+    bool onBufferResponse(const QString &response, int bufferIdx) override
+    {
+        Q_UNUSED(bufferIdx)
+        try {
+            m_cllm->m_chatModel->setResponseValue(response);
+        } catch (const std::exception &e) {
+            // We have a try/catch here because the main thread might have removed the response from
+            // the chatmodel by erasing the conversation during the response... the main thread sets
+            // m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel
+            Q_ASSERT(m_cllm->m_stopGenerating);
+            return false;
+        }
+        emit m_cllm->responseChanged();
+        return true;
+    }
+
+    bool onRegularResponse() override
+    {
+        auto respStr = QString::fromUtf8(m_result->response);
+        return onBufferResponse(removeLeadingWhitespace(respStr), 0);
+    }
+
+    bool getStopGenerating() const override
+    { return m_cllm->m_stopGenerating; }
+
+private:
+    ChatLLM *m_cllm;
+    QElapsedTimer *m_totalTime;
+    ChatLLM::PromptResult *m_result;
+};
+
 auto ChatLLM::promptInternal(
     const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
     const LLModel::PromptContext &ctx,
@@ -931,64 +1052,20 @@ auto ChatLLM::promptInternal(
 
     QElapsedTimer totalTime;
     totalTime.start();
+    ChatViewResponseHandler respHandler(this, &totalTime, &result);
 
     m_timer->start();
 
-    ToolCallParser toolCallParser;
-    auto handleResponse = [this, &result, &toolCallParser, &totalTime](LLModel::Token token, std::string_view piece) -> bool {
-        Q_UNUSED(token)
-        result.responseTokens++;
-        m_timer->inc();
-
-        toolCallParser.update(piece.data());
-
-        // Split the response into two if needed and create chat items
-        if (toolCallParser.numberOfBuffers() < 2 && toolCallParser.splitIfPossible()) {
-            const auto parseBuffers = toolCallParser.buffers();
-            Q_ASSERT(parseBuffers.size() == 2);
-            if (toolCallParser.startTag() == ToolCallConstants::ThinkTag)
-                m_chatModel->splitThinking({parseBuffers.at(0), parseBuffers.at(1)});
-            else
-                m_chatModel->splitToolCall({parseBuffers.at(0), parseBuffers.at(1)});
-        }
-
-        // Split the response into three if needed and create chat items
-        if (toolCallParser.numberOfBuffers() < 3 && toolCallParser.startTag() == ToolCallConstants::ThinkTag
-            && toolCallParser.splitIfPossible()) {
-            const auto parseBuffers = toolCallParser.buffers();
-            Q_ASSERT(parseBuffers.size() == 3);
-            m_chatModel->endThinking({parseBuffers.at(1), parseBuffers.at(2)}, totalTime.elapsed());
-        }
-
-        result.response.append(piece.data(), piece.size());
-        auto respStr = QString::fromUtf8(result.response);
-
-        try {
-            const auto parseBuffers = toolCallParser.buffers();
-            if (parseBuffers.size() > 1)
-                m_chatModel->setResponseValue(parseBuffers.last());
-            else
-                m_chatModel->setResponseValue(removeLeadingWhitespace(respStr));
-        } catch (const std::exception &e) {
-            // We have a try/catch here because the main thread might have removed the response from
-            // the chatmodel by erasing the conversation during the response... the main thread sets
-            // m_stopGenerating before doing so, but it doesn't wait after that to reset the chatmodel
-            Q_ASSERT(m_stopGenerating);
-            return false;
-        }
-
-        emit responseChanged();
-
-        const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
-            && toolCallParser.startTag() != ToolCallConstants::ThinkTag;
-
-        return !shouldExecuteToolCall && !m_stopGenerating;
-    };
-
+    QStringList finalBuffers;
+    bool shouldExecuteTool;
     try {
         emit promptProcessing();
         m_llModelInfo.model->setThreadCount(mySettings->threadCount());
         m_stopGenerating = false;
-        m_llModelInfo.model->prompt(conversation, handlePrompt, handleResponse, ctx);
+        std::tie(finalBuffers, shouldExecuteTool) = promptModelWithTools(
+            m_llModelInfo.model.get(), handlePrompt, respHandler, ctx,
+            QByteArray::fromRawData(conversation.data(), conversation.size()),
+            ToolCallConstants::AllTagNames
+        );
    } catch (...) {
        m_timer->stop();
        throw;
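Note on the call above: QByteArray::fromRawData wraps the conversation bytes without copying; the resulting QByteArray is a view that is only valid while the original buffer stays alive and unchanged, which holds here because promptModelWithTools only reads the prompt for the duration of the call. A small sketch of those semantics (the conversation string is invented):

    #include <QByteArray>
    #include <string>
    #include <string_view>

    int main()
    {
        std::string conversation = "### Human: hi\n### Assistant:";
        // No copy, no ownership: 'view' aliases conversation's bytes.
        QByteArray view = QByteArray::fromRawData(conversation.data(),
                                                  qsizetype(conversation.size()));
        // Converting onward to std::string_view, as promptModelWithTools does,
        // is also copy-free.
        std::string_view sv(view.constData(), size_t(view.size()));
        return sv == conversation ? 0 : 1;
    }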
@@ -997,22 +1074,18 @@ auto ChatLLM::promptInternal(
     m_timer->stop();
     qint64 elapsed = totalTime.elapsed();
 
-    const auto parseBuffers = toolCallParser.buffers();
-    const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
-        && toolCallParser.startTag() != ToolCallConstants::ThinkTag;
-
     // trim trailing whitespace
     auto respStr = QString::fromUtf8(result.response);
-    if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || parseBuffers.size() > 1)) {
-        if (parseBuffers.size() > 1)
-            m_chatModel->setResponseValue(parseBuffers.last());
+    if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || finalBuffers.size() > 1)) {
+        if (finalBuffers.size() > 1)
+            m_chatModel->setResponseValue(finalBuffers.last());
         else
             m_chatModel->setResponseValue(respStr.trimmed());
         emit responseChanged();
     }
 
     bool doQuestions = false;
-    if (!m_isServer && messageItems && !shouldExecuteToolCall) {
+    if (!m_isServer && messageItems && !shouldExecuteTool) {
         switch (mySettings->suggestionMode()) {
         case SuggestionMode::On:            doQuestions = true;          break;
         case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break;
@@ -1090,6 +1163,66 @@ void ChatLLM::reloadModel()
     loadModel(m);
 }
 
+// This class discards the text within thinking tags, for use with chat names and follow-up questions.
+class SimpleResponseHandler : public BaseResponseHandler {
+public:
+    SimpleResponseHandler(ChatLLM *cllm)
+        : m_cllm(cllm) {}
+
+    void onSplitIntoTwo(const QString &startTag, const QString &firstBuffer, const QString &secondBuffer) override
+    { /* no-op */ }
+
+    void onSplitIntoThree(const QString &secondBuffer, const QString &thirdBuffer) override
+    { /* no-op */ }
+
+    void onOldResponseChunk(const QByteArray &chunk) override
+    { m_response.append(chunk); }
+
+    bool onBufferResponse(const QString &response, int bufferIdx) override
+    {
+        if (bufferIdx == 1)
+            return true; // ignore "think" content
+        return onSimpleResponse(response);
+    }
+
+    bool onRegularResponse() override
+    { return onBufferResponse(QString::fromUtf8(m_response), 0); }
+
+    bool getStopGenerating() const override
+    { return m_cllm->m_stopGenerating; }
+
+protected:
+    virtual bool onSimpleResponse(const QString &response) = 0;
+
+protected:
+    ChatLLM *m_cllm;
+    QByteArray m_response;
+};
+
+class NameResponseHandler : public SimpleResponseHandler {
+private:
+    // max length of chat names, in words
+    static constexpr qsizetype MAX_WORDS = 3;
+
+public:
+    using SimpleResponseHandler::SimpleResponseHandler;
+
+protected:
+    bool onSimpleResponse(const QString &response) override
+    {
+        QTextStream stream(const_cast<QString *>(&response), QIODeviceBase::ReadOnly);
+        QStringList words;
+        while (!stream.atEnd() && words.size() < MAX_WORDS) {
+            QString word;
+            stream >> word;
+            words << word;
+        }
+
+        emit m_cllm->generatedNameChanged(words.join(u' '));
+        return words.size() < MAX_WORDS || stream.atEnd();
+    }
+};
+
 void ChatLLM::generateName()
 {
     Q_ASSERT(isModelLoaded());
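Note: NameResponseHandler caps chat names at MAX_WORDS and stops generation as soon as the limit is passed, instead of the old check of words.size() <= 3 over the whole re-split response. A standalone sketch of just the word-capping logic; the sample strings are invented, and the free function takes its QString by value, which avoids the const_cast the member function needs:

    #include <QDebug>
    #include <QString>
    #include <QStringList>
    #include <QTextStream>

    static constexpr qsizetype MAX_WORDS = 3;

    // Returns true to keep generating, false once MAX_WORDS have been read
    // and input still remains (i.e. the model has run past the limit).
    static bool takeName(QString response)
    {
        QTextStream stream(&response, QIODeviceBase::ReadOnly);
        QStringList words;
        while (!stream.atEnd() && words.size() < MAX_WORDS) {
            QString word;
            stream >> word;
            words << word;
        }
        qDebug() << "name so far:" << words.join(u' ');
        return words.size() < MAX_WORDS || stream.atEnd();
    }

    int main()
    {
        takeName("Quick");                // true: fewer than 3 words so far
        takeName("Quick Chat About");     // true: 3 words, but input exhausted
        // false: limit passed, a fourth word is pending
        return takeName("Quick Chat About DeepSeek") ? 1 : 0;
    }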
@@ -1106,23 +1239,15 @@ void ChatLLM::generateName()
         return;
     }
 
-    QByteArray response; // raw UTF-8
-
-    auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
-        Q_UNUSED(token)
-
-        response.append(piece.data(), piece.size());
-        QStringList words = QString::fromUtf8(response).simplified().split(u' ', Qt::SkipEmptyParts);
-        emit generatedNameChanged(words.join(u' '));
-        return words.size() <= 3;
-    };
+    NameResponseHandler respHandler(this);
 
     try {
-        m_llModelInfo.model->prompt(
-            applyJinjaTemplate(forkConversation(chatNamePrompt)),
-            [this](auto &&...) { return !m_stopGenerating; },
-            handleResponse,
-            promptContextFromSettings(m_modelInfo)
+        promptModelWithTools(
+            m_llModelInfo.model.get(),
+            /*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
+            respHandler, promptContextFromSettings(m_modelInfo),
+            applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(),
+            { ToolCallConstants::ThinkTagName }
        );
    } catch (const std::exception &e) {
        qWarning() << "ChatLLM failed to generate name:" << e.what();
@@ -1134,13 +1259,43 @@ void ChatLLM::handleChatIdChanged(const QString &id)
     m_llmThread.setObjectName(id);
 }
 
-void ChatLLM::generateQuestions(qint64 elapsed)
-{
+class QuestionResponseHandler : public SimpleResponseHandler {
+public:
+    using SimpleResponseHandler::SimpleResponseHandler;
+
+protected:
+    bool onSimpleResponse(const QString &response) override
+    {
+        auto responseUtf8Bytes = response.toUtf8().slice(m_offset);
+        auto responseUtf8 = std::string(responseUtf8Bytes.begin(), responseUtf8Bytes.end());
+        // extract all questions from response
+        ptrdiff_t lastMatchEnd = -1;
+        auto it = std::sregex_iterator(responseUtf8.begin(), responseUtf8.end(), s_reQuestion);
+        auto end = std::sregex_iterator();
+        for (; it != end; ++it) {
+            auto pos = it->position();
+            auto len = it->length();
+            lastMatchEnd = pos + len;
+            emit m_cllm->generatedQuestionFinished(QString::fromUtf8(&responseUtf8[pos], len));
+        }
+
+        // remove processed input from buffer
+        if (lastMatchEnd != -1)
+            m_offset += lastMatchEnd;
+        return true;
+    }
+
+private:
     // FIXME: This only works with response by the model in english which is not ideal for a multi-language
     // model.
     // match whole question sentences
-    static const std::regex reQuestion(R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");
+    static inline const std::regex s_reQuestion { R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)" };
 
+    qsizetype m_offset = 0;
+};
+
+void ChatLLM::generateQuestions(qint64 elapsed)
+{
     Q_ASSERT(isModelLoaded());
     if (!isModelLoaded()) {
         emit responseStopped(elapsed);
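Note: QuestionResponseHandler keeps the regex scan incremental. m_offset records how much of the accumulated UTF-8 buffer has already been emitted as questions, replacing the old approach of erasing the front of a local std::string. A standalone sketch of the same extraction loop with offset tracking (the sample pieces are invented):

    #include <cstddef>
    #include <iostream>
    #include <regex>
    #include <string>

    // Same pattern as s_reQuestion above: whole English question sentences.
    static const std::regex reQuestion(
        R"(\b(?:What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");

    int main()
    {
        std::string buffer;        // accumulated response, UTF-8
        std::size_t offset = 0;    // plays the role of m_offset
        for (const char *piece : { "What is GPT4All? How does the ", "think tag work?" }) {
            buffer += piece;
            std::string pending = buffer.substr(offset);
            std::ptrdiff_t lastMatchEnd = -1;
            for (auto it = std::sregex_iterator(pending.begin(), pending.end(), reQuestion);
                 it != std::sregex_iterator(); ++it) {
                lastMatchEnd = it->position() + it->length();
                std::cout << "question: " << it->str() << '\n';
            }
            if (lastMatchEnd != -1)
                offset += std::size_t(lastMatchEnd); // don't re-emit processed input
        }
        // prints each question exactly once, even though the first piece
        // already contained a complete match when the second piece arrived
    }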
@@ -1158,39 +1313,17 @@ void ChatLLM::generateQuestions(qint64 elapsed)
 
     emit generatingQuestions();
 
-    std::string response; // raw UTF-8
-
-    auto handleResponse = [this, &response](LLModel::Token token, std::string_view piece) -> bool {
-        Q_UNUSED(token)
-
-        // add token to buffer
-        response.append(piece);
-
-        // extract all questions from response
-        ptrdiff_t lastMatchEnd = -1;
-        auto it = std::sregex_iterator(response.begin(), response.end(), reQuestion);
-        auto end = std::sregex_iterator();
-        for (; it != end; ++it) {
-            auto pos = it->position();
-            auto len = it->length();
-            lastMatchEnd = pos + len;
-            emit generatedQuestionFinished(QString::fromUtf8(&response[pos], len));
-        }
-
-        // remove processed input from buffer
-        if (lastMatchEnd != -1)
-            response.erase(0, lastMatchEnd);
-        return true;
-    };
+    QuestionResponseHandler respHandler(this);
 
     QElapsedTimer totalTime;
     totalTime.start();
     try {
-        m_llModelInfo.model->prompt(
-            applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)),
-            [this](auto &&...) { return !m_stopGenerating; },
-            handleResponse,
-            promptContextFromSettings(m_modelInfo)
+        promptModelWithTools(
+            m_llModelInfo.model.get(),
+            /*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
+            respHandler, promptContextFromSettings(m_modelInfo),
+            applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
+            { ToolCallConstants::ThinkTagName }
        );
    } catch (const std::exception &e) {
        qWarning() << "ChatLLM failed to generate follow-up questions:" << e.what();