Fix index used by LocalDocs when tool calling/thinking is active (#3451)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2025-02-03 11:22:46 -05:00 committed by GitHub
parent 6bfa014594
commit 9131f4c432
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 18 deletions

View File

@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
## [Unreleased]
### Fixed
- Fix "index N is not a prompt" when using LocalDocs with reasoning ([#3451](https://github.com/nomic-ai/gpt4all/pull/3451))
## [3.8.0] - 2025-01-30 ## [3.8.0] - 2025-01-30
### Added ### Added
@@ -283,6 +288,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694)) - Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
- Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701)) - Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))
[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.8.0...HEAD
[3.8.0]: https://github.com/nomic-ai/gpt4all/compare/v3.7.0...v3.8.0 [3.8.0]: https://github.com/nomic-ai/gpt4all/compare/v3.7.0...v3.8.0
[3.7.0]: https://github.com/nomic-ai/gpt4all/compare/v3.6.1...v3.7.0 [3.7.0]: https://github.com/nomic-ai/gpt4all/compare/v3.6.1...v3.7.0
[3.6.1]: https://github.com/nomic-ai/gpt4all/compare/v3.6.0...v3.6.1 [3.6.1]: https://github.com/nomic-ai/gpt4all/compare/v3.6.0...v3.6.1

View File

@@ -730,7 +730,8 @@ std::vector<MessageItem> ChatLLM::forkConversation(const QString &prompt) const
conversation.reserve(items.size() + 1); conversation.reserve(items.size() + 1);
conversation.assign(items.begin(), items.end()); conversation.assign(items.begin(), items.end());
} }
conversation.emplace_back(MessageItem::Type::Prompt, prompt.toUtf8()); qsizetype nextIndex = conversation.empty() ? 0 : conversation.back().index().value() + 1;
conversation.emplace_back(nextIndex, MessageItem::Type::Prompt, prompt.toUtf8());
return conversation; return conversation;
} }
@@ -801,7 +802,7 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) cons
json::array_t messages; json::array_t messages;
messages.reserve(useSystem + items.size()); messages.reserve(useSystem + items.size());
if (useSystem) { if (useSystem) {
systemItem = std::make_unique<MessageItem>(MessageItem::Type::System, systemMessage.toUtf8()); systemItem = std::make_unique<MessageItem>(MessageItem::system_tag, systemMessage.toUtf8());
messages.emplace_back(makeMap(*systemItem)); messages.emplace_back(makeMap(*systemItem));
} }
for (auto &item : items) for (auto &item : items)
@@ -855,14 +856,14 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL
// Find the prompt that represents the query. Server chats are flexible and may not have one. // Find the prompt that represents the query. Server chats are flexible and may not have one.
auto items = getChat(); auto items = getChat();
if (auto peer = m_chatModel->getPeer(items, items.end() - 1)) // peer of response if (auto peer = m_chatModel->getPeer(items, items.end() - 1)) // peer of response
query = { *peer - items.begin(), (*peer)->content() }; query = { (*peer)->index().value(), (*peer)->content() };
} }
if (query) { if (query) {
auto &[promptIndex, queryStr] = *query; auto &[promptIndex, queryStr] = *query;
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks
m_chatModel->updateSources(promptIndex + startOffset, databaseResults); m_chatModel->updateSources(promptIndex, databaseResults);
emit databaseResultsChanged(databaseResults); emit databaseResultsChanged(databaseResults);
} }
} }

View File

@@ -94,11 +94,28 @@ class MessageItem
public: public:
enum class Type { System, Prompt, Response, ToolResponse }; enum class Type { System, Prompt, Response, ToolResponse };
MessageItem(Type type, QString content) struct system_tag_t { explicit system_tag_t() = default; };
: m_type(type), m_content(std::move(content)) {} static inline constexpr system_tag_t system_tag = system_tag_t{};
MessageItem(Type type, QString content, const QList<ResultInfo> &sources, const QList<PromptAttachment> &promptAttachments) MessageItem(qsizetype index, Type type, QString content)
: m_type(type), m_content(std::move(content)), m_sources(sources), m_promptAttachments(promptAttachments) {} : m_index(index), m_type(type), m_content(std::move(content))
{
Q_ASSERT(type != Type::System); // use system_tag constructor
}
// Construct a system message with no index, since they are never stored in the chat
MessageItem(system_tag_t, QString content)
: m_type(Type::System), m_content(std::move(content)) {}
MessageItem(qsizetype index, Type type, QString content, const QList<ResultInfo> &sources, const QList<PromptAttachment> &promptAttachments)
: m_index(index)
, m_type(type)
, m_content(std::move(content))
, m_sources(sources)
, m_promptAttachments(promptAttachments) {}
// index of the parent ChatItem (system, prompt, response) in its container
std::optional<qsizetype> index() const { return m_index; }
Type type() const { return m_type; } Type type() const { return m_type; }
const QString &content() const { return m_content; } const QString &content() const { return m_content; }
@@ -126,6 +143,7 @@ public:
} }
private: private:
std::optional<qsizetype> m_index;
Type m_type; Type m_type;
QString m_content; QString m_content;
QList<ResultInfo> m_sources; QList<ResultInfo> m_sources;
@@ -399,7 +417,7 @@ public:
Q_UNREACHABLE(); Q_UNREACHABLE();
} }
MessageItem asMessageItem() const MessageItem asMessageItem(qsizetype index) const
{ {
MessageItem::Type msgType; MessageItem::Type msgType;
switch (auto typ = type()) { switch (auto typ = type()) {
@@ -413,7 +431,7 @@ public:
case Think: case Think:
throw std::invalid_argument(fmt::format("cannot convert ChatItem type {} to message item", int(typ))); throw std::invalid_argument(fmt::format("cannot convert ChatItem type {} to message item", int(typ)));
} }
return { msgType, flattenedContent(), sources, promptAttachments }; return { index, msgType, flattenedContent(), sources, promptAttachments };
} }
static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources); static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources);
@@ -537,6 +555,7 @@ private:
return std::nullopt; return std::nullopt;
} }
// FIXME(jared): this should really be done at the parent level, not the sub-item level
static std::optional<qsizetype> getPeerInternal(const MessageItem *arr, qsizetype size, qsizetype index) static std::optional<qsizetype> getPeerInternal(const MessageItem *arr, qsizetype size, qsizetype index)
{ {
qsizetype peer; qsizetype peer;
@@ -1114,10 +1133,12 @@ public:
// A flattened version of the chat item tree used by the backend and jinja // A flattened version of the chat item tree used by the backend and jinja
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
std::vector<MessageItem> chatItems; std::vector<MessageItem> chatItems;
for (const ChatItem *item : m_chatItems) { for (qsizetype i : views::iota(0, m_chatItems.size())) {
chatItems.reserve(chatItems.size() + item->subItems.size() + 1); auto *parent = m_chatItems.at(i);
ranges::copy(item->subItems | views::transform(&ChatItem::asMessageItem), std::back_inserter(chatItems)); chatItems.reserve(chatItems.size() + parent->subItems.size() + 1);
chatItems.push_back(item->asMessageItem()); ranges::copy(parent->subItems | views::transform([&](auto *s) { return s->asMessageItem(i); }),
std::back_inserter(chatItems));
chatItems.push_back(parent->asMessageItem(i));
} }
return chatItems; return chatItems;
} }