From ebda9146e7ecf53ddf757926f708bc03d8a1aa6c Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Fri, 11 Oct 2024 16:11:01 -0400 Subject: [PATCH] localdocs: fix regressions caused by docx change (#3079) Signed-off-by: Jared Van Bortel --- gpt4all-chat/CHANGELOG.md | 1 + gpt4all-chat/src/database.cpp | 228 +++++++++++++++++++++------------- gpt4all-chat/src/database.h | 9 ++ 3 files changed, 155 insertions(+), 83 deletions(-) diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md index c8dda5f3..61a38c34 100644 --- a/gpt4all-chat/CHANGELOG.md +++ b/gpt4all-chat/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Fixed - Fix models.json cache location ([#3052](https://github.com/nomic-ai/gpt4all/pull/3052)) +- Fix LocalDocs regressions caused by docx change ([#3079](https://github.com/nomic-ai/gpt4all/pull/3079)) ## [3.4.0] - 2024-10-08 diff --git a/gpt4all-chat/src/database.cpp b/gpt4all-chat/src/database.cpp index 02261cb4..0f271410 100644 --- a/gpt4all-chat/src/database.cpp +++ b/gpt4all-chat/src/database.cpp @@ -32,6 +32,7 @@ #include using namespace Qt::Literals::StringLiterals; +namespace ranges = std::ranges; namespace us = unum::usearch; //#define DEBUG @@ -175,6 +176,14 @@ static const QString INSERT_CHUNK_FTS_SQL = uR"( values(?, ?, ?, ?, ?, ?, ?); )"_s; +static const QString SELECT_CHUNKED_DOCUMENTS_SQL[] = { + uR"( + select distinct document_id from chunks; + )"_s, uR"( + select distinct document_id from chunks_fts; + )"_s, +}; + static const QString DELETE_CHUNKS_SQL[] = { uR"( delete from embeddings @@ -230,53 +239,6 @@ static const QString SELECT_CHUNKS_FTS_SQL = uR"( order by score limit %1; )"_s; -static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file, - const QString &title, const QString &author, const QString &subject, const QString &keywords, - int page, int from, int to, int words, int *chunk_id) -{ - if (!q.prepare(INSERT_CHUNK_SQL)) - return false; - q.addBindValue(document_id); - q.addBindValue(chunk_text); - q.addBindValue(file); - q.addBindValue(title); - q.addBindValue(author); - q.addBindValue(subject); - q.addBindValue(keywords); - q.addBindValue(page); - q.addBindValue(from); - q.addBindValue(to); - q.addBindValue(words); - if (!q.exec() || !q.next()) - return false; - *chunk_id = q.value(0).toInt(); - - if (!q.prepare(INSERT_CHUNK_FTS_SQL)) - return false; - q.addBindValue(document_id); - q.addBindValue(chunk_text); - q.addBindValue(file); - q.addBindValue(title); - q.addBindValue(author); - q.addBindValue(subject); - q.addBindValue(keywords); - if (!q.exec()) - return false; - return true; -} - -static bool removeChunksByDocumentId(QSqlQuery &q, int document_id) -{ - for (const auto &cmd: DELETE_CHUNKS_SQL) { - if (!q.prepare(cmd)) - return false; - q.addBindValue(document_id); - if (!q.exec()) - return false; - } - return true; -} - #define NAMED_PAIR(name, typea, a, typeb, b) \ struct name { typea a; typeb b; }; \ static bool operator==(const name &x, const name &y) { return x.a == y.a && x.b == y.b; } \ @@ -634,18 +596,6 @@ static bool selectAllFolderPaths(QSqlQuery &q, QList *folder_paths) return true; } -static bool sqlRemoveDocsByFolderPath(QSqlQuery &q, const QString &path) -{ - for (const auto &cmd: FOLDER_REMOVE_ALL_DOCS_SQL) { - if (!q.prepare(cmd)) - return false; - q.addBindValue(path); - if (!q.exec()) - return false; - } - return true; -} - static const QString INSERT_COLLECTION_ITEM_SQL = uR"( insert into collection_items(collection_id, folder_id) values(?, ?) @@ -889,6 +839,79 @@ void Database::rollback() Q_ASSERT(ok); } +bool Database::refreshDocumentIdCache(QSqlQuery &q) +{ + m_documentIdCache.clear(); + for (const auto &cmd: SELECT_CHUNKED_DOCUMENTS_SQL) { + if (!q.exec(cmd)) + return false; + while (q.next()) + m_documentIdCache << q.value(0).toInt(); + } + return true; +} + +bool Database::addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, + int page, int from, int to, int words, int *chunk_id) +{ + if (!q.prepare(INSERT_CHUNK_SQL)) + return false; + q.addBindValue(document_id); + q.addBindValue(chunk_text); + q.addBindValue(file); + q.addBindValue(title); + q.addBindValue(author); + q.addBindValue(subject); + q.addBindValue(keywords); + q.addBindValue(page); + q.addBindValue(from); + q.addBindValue(to); + q.addBindValue(words); + if (!q.exec() || !q.next()) + return false; + *chunk_id = q.value(0).toInt(); + + if (!q.prepare(INSERT_CHUNK_FTS_SQL)) + return false; + q.addBindValue(document_id); + q.addBindValue(chunk_text); + q.addBindValue(file); + q.addBindValue(title); + q.addBindValue(author); + q.addBindValue(subject); + q.addBindValue(keywords); + if (!q.exec()) + return false; + m_documentIdCache << document_id; + return true; +} + +bool Database::removeChunksByDocumentId(QSqlQuery &q, int document_id) +{ + for (const auto &cmd: DELETE_CHUNKS_SQL) { + if (!q.prepare(cmd)) + return false; + q.addBindValue(document_id); + if (!q.exec()) + return false; + } + m_documentIdCache.remove(document_id); + return true; +} + +bool Database::sqlRemoveDocsByFolderPath(QSqlQuery &q, const QString &path) +{ + for (const auto &cmd: FOLDER_REMOVE_ALL_DOCS_SQL) { + if (!q.prepare(cmd)) + return false; + q.addBindValue(path); + if (!q.exec()) + return false; + } + return refreshDocumentIdCache(q); +} + bool Database::hasContent() { return m_db.tables().contains("chunks", Qt::CaseInsensitive); @@ -1246,9 +1269,13 @@ public: protected: std::optional advance() override { + if (getError()) + return std::nullopt; while (!m_stream.atEnd()) { QString word; m_stream >> word; + if (getError()) + return std::nullopt; if (!word.isEmpty()) return word; } @@ -1257,9 +1284,11 @@ protected: std::optional getError() const override { - if (!m_file.error()) - return std::nullopt; - return m_file.binarySeen() ? ChunkStreamer::Status::BINARY_SEEN : ChunkStreamer::Status::ERROR; + if (m_file.binarySeen()) + return ChunkStreamer::Status::BINARY_SEEN; + if (m_file.error()) + return ChunkStreamer::Status::ERROR; + return std::nullopt; } BinaryDetectingFile m_file; @@ -1300,12 +1329,24 @@ void ChunkStreamer::setDocument(const DocumentInfo &doc, int documentId, const Q m_page = 0; // make sure the document doesn't already have any chunks - QSqlQuery q(m_database->m_db); - if (!removeChunksByDocumentId(q, documentId)) - handleDocumentError("ERROR: Cannot remove chunks of document", documentId, doc.file.canonicalPath(), q.lastError()); + if (m_database->m_documentIdCache.contains(documentId)) { + QSqlQuery q(m_database->m_db); + if (!m_database->removeChunksByDocumentId(q, documentId)) + handleDocumentError("ERROR: Cannot remove chunks of document", documentId, doc.file.canonicalPath(), q.lastError()); + } } } +std::optional ChunkStreamer::currentDocKey() const +{ + return m_docKey; +} + +void ChunkStreamer::reset() +{ + m_docKey.reset(); +} + ChunkStreamer::Status ChunkStreamer::step() { // TODO: implement line_from/line_to @@ -1318,8 +1359,10 @@ ChunkStreamer::Status ChunkStreamer::step() Status retval; for (;;) { - if (auto error = m_reader->getError()) + if (auto error = m_reader->getError()) { + m_docKey.reset(); // done processing return *error; + } if (m_database->scanQueueInterrupted()) { retval = Status::INTERRUPTED; break; @@ -1340,7 +1383,7 @@ ChunkStreamer::Status ChunkStreamer::step() } } - if (!word || m_chunk.length() >= maxChunkSize + 1) { // +1 for leading space + if (!word || m_chunk.length() >= maxChunkSize + 1) { // +1 for trailing space if (!m_chunk.isEmpty()) { int nThisChunkWords = 0; auto chunk = m_chunk; // copy @@ -1348,35 +1391,44 @@ ChunkStreamer::Status ChunkStreamer::step() // handle overlength chunks if (m_chunk.length() > maxChunkSize + 1) { // find the final space - qsizetype lastSpace = chunk.lastIndexOf(u' ', -2); + qsizetype chunkEnd = chunk.lastIndexOf(u' ', -2); - if (lastSpace < 0) { + qsizetype spaceSize; + if (chunkEnd >= 0) { // slice off the last word + spaceSize = 1; Q_ASSERT(m_nChunkWords >= 1); - lastSpace = maxChunkSize; + // one word left nThisChunkWords = m_nChunkWords - 1; m_nChunkWords = 1; } else { // slice the overlong word + spaceSize = 0; + chunkEnd = maxChunkSize; + // partial word left, don't count it nThisChunkWords = m_nChunkWords; m_nChunkWords = 0; } - // save the extra part - m_chunk = chunk.sliced(lastSpace + 1); - // slice - chunk.truncate(lastSpace + 1); - Q_ASSERT(chunk.length() <= maxChunkSize + 1); + // save the second part, excluding space if any + m_chunk = chunk.sliced(chunkEnd + spaceSize); + // consume the first part + chunk.truncate(chunkEnd); } else { nThisChunkWords = m_nChunkWords; m_nChunkWords = 0; + // there is no second part + m_chunk.clear(); + // consume the whole chunk, excluding space + chunk.chop(1); } + Q_ASSERT(chunk.length() <= maxChunkSize); QSqlQuery q(m_database->m_db); int chunkId = 0; - if (!addChunk(q, + if (!m_database->addChunk(q, m_documentId, - chunk.chopped(1), // strip trailing space - m_reader->doc().file.canonicalFilePath(), + chunk, + m_reader->doc().file.fileName(), // basename m_title, m_author, m_subject, @@ -1399,12 +1451,11 @@ ChunkStreamer::Status ChunkStreamer::step() toEmbed.chunk = chunk; m_database->appendChunk(toEmbed); ++nChunks; - - m_chunk.clear(); } if (!word) { retval = Status::DOC_COMPLETE; + m_docKey.reset(); // done processing break; } } @@ -1532,8 +1583,14 @@ DocumentInfo Database::dequeueDocument() void Database::removeFolderFromDocumentQueue(int folder_id) { - if (auto it = m_docsToScan.find(folder_id); it != m_docsToScan.end()) - m_docsToScan.erase(it); + if (auto queueIt = m_docsToScan.find(folder_id); queueIt != m_docsToScan.end()) { + if (auto key = m_chunkStreamer.currentDocKey()) { + if (ranges::any_of(queueIt->second, [&key](auto &d) { return d.key() == key; })) + m_chunkStreamer.reset(); // done with this document + } + // remove folder from queue + m_docsToScan.erase(queueIt); + } } void Database::enqueueDocumentInternal(DocumentInfo &&info, bool prepend) @@ -1758,7 +1815,12 @@ void Database::start() m_databaseValid = false; } else { cleanDB(); - addCurrentFolders(); + QSqlQuery q(m_db); + if (!refreshDocumentIdCache(q)) { + m_databaseValid = false; + } else { + addCurrentFolders(); + } } if (!m_databaseValid) diff --git a/gpt4all-chat/src/database.h b/gpt4all-chat/src/database.h index 1a93e3b1..0e90c260 100644 --- a/gpt4all-chat/src/database.h +++ b/gpt4all-chat/src/database.h @@ -163,6 +163,8 @@ public: void setDocument(const DocumentInfo &doc, int documentId, const QString &embeddingModel, const QString &title, const QString &author, const QString &subject, const QString &keywords); + std::optional currentDocKey() const; + void reset(); Status step(); @@ -224,6 +226,12 @@ private: void commit(); void rollback(); + bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, + int page, int from, int to, int words, int *chunk_id); + bool refreshDocumentIdCache(QSqlQuery &q); + bool removeChunksByDocumentId(QSqlQuery &q, int document_id); + bool sqlRemoveDocsByFolderPath(QSqlQuery &q, const QString &path); bool hasContent(); // not found -> 0, , exists and has content -> 1, error -> -1 int openDatabase(const QString &modelPath, bool create = true, int ver = LOCALDOCS_VERSION); @@ -293,6 +301,7 @@ private: QHash m_collectionMap; // used only for tracking indexing/embedding progress std::atomic m_databaseValid; ChunkStreamer m_chunkStreamer; + QSet m_documentIdCache; // cached list of documents with chunks for fast lookup friend class ChunkStreamer; };