From 10d2375bf379485fb9adce087fd62c7720bdc73a Mon Sep 17 00:00:00 2001
From: AT
Date: Thu, 26 Sep 2024 11:58:48 -0400
Subject: [PATCH] Hybrid search (#2969)

Signed-off-by: Adam Treat
---
 gpt4all-chat/CHANGELOG.md     |   6 +
 gpt4all-chat/src/database.cpp | 299 +++++++++++++++++++++++++++++++---
 gpt4all-chat/src/database.h   |  21 ++-
 3 files changed, 303 insertions(+), 23 deletions(-)

diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md
index 403130b6..6edd4ce8 100644
--- a/gpt4all-chat/CHANGELOG.md
+++ b/gpt4all-chat/CHANGELOG.md
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 
+## [Unreleased]
+
+### Added
+- Add bm25 hybrid search to localdocs ([#2969](https://github.com/nomic-ai/gpt4all/pull/2969))
+
 ## [3.3.0] - 2024-09-20
 
 ### Added
@@ -119,6 +124,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
 - Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))
 
+[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.3.0...HEAD
 [3.3.0]: https://github.com/nomic-ai/gpt4all/compare/v3.2.1...v3.3.0
 [3.2.1]: https://github.com/nomic-ai/gpt4all/compare/v3.2.0...v3.2.1
 [3.2.0]: https://github.com/nomic-ai/gpt4all/compare/v3.1.1...v3.2.0
diff --git a/gpt4all-chat/src/database.cpp b/gpt4all-chat/src/database.cpp
index 02ab5364..7d029735 100644
--- a/gpt4all-chat/src/database.cpp
+++ b/gpt4all-chat/src/database.cpp
@@ -103,6 +103,20 @@ static const QString INIT_DB_SQL[] = {
             tokens      integer default 0 not null,
             foreign key(document_id) references documents(id)
         );
+    )"_s, uR"(
+        create virtual table chunks_fts using fts5(
+            id unindexed,
+            document_id unindexed,
+            chunk_text,
+            file,
+            title,
+            author,
+            subject,
+            keywords,
+            content='chunks',
+            content_rowid='id',
+            tokenize='porter'
+        );
     )"_s, uR"(
         create table collections(
             id              integer primary key,
@@ -151,7 +165,13 @@ static const QString INSERT_CHUNK_SQL = uR"(
         file, title, author, subject, keywords, page, line_from, line_to, words)
         values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
         returning id;
-    )"_s;
+)"_s;
+
+static const QString INSERT_CHUNK_FTS_SQL = uR"(
+    insert into chunks_fts(document_id, chunk_text,
+        file, title, author, subject, keywords)
+        values(?, ?, ?, ?, ?, ?, ?);
+)"_s;
 
 static const QString DELETE_CHUNKS_SQL[] = {
     uR"(
@@ -161,12 +181,14 @@ static const QString DELETE_CHUNKS_SQL[] = {
         );
     )"_s, uR"(
         delete from chunks where document_id = ?;
+    )"_s, uR"(
+        delete from chunks_fts where document_id = ?;
     )"_s,
 };
 
 static const QString SELECT_CHUNKS_BY_DOCUMENT_SQL = uR"(
     select id from chunks WHERE document_id = ?;
-    )"_s;
+)"_s;
 
 static const QString SELECT_CHUNKS_SQL = uR"(
     select c.id, d.document_time, d.document_path, c.chunk_text, c.file, c.title, c.author, c.page, c.line_from, c.line_to, co.name
@@ -190,14 +212,21 @@ static const QString SELECT_UNCOMPLETED_CHUNKS_SQL = uR"(
             from embeddings e
             where e.chunk_id = c.id and e.model = co.embedding_model
         );
-    )"_s;
+)"_s;
 
 static const QString SELECT_COUNT_CHUNKS_SQL = uR"(
     select count(c.id)
     from chunks c
     join documents d on d.id = c.document_id
     where d.folder_id = ?;
-    )"_s;
+)"_s;
+
+static const QString SELECT_CHUNKS_FTS_SQL = uR"(
+    select id, bm25(chunks_fts) as score
+    from chunks_fts
+    where chunks_fts match ?
+    order by score limit %1;
+)"_s;
 
 static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file,
                      const QString &title, const QString &author, const QString &subject, const QString &keywords,
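The new SELECT_CHUNKS_FTS_SQL relies on a property of SQLite's bm25() auxiliary function: it reports a lower (more negative) value for a better match, so ordering ascending and taking the first %1 rows yields the best BM25 hits. A minimal standalone sketch of that behavior, not part of the patch, assuming an SQLite build with FTS5 enabled; the table, column, and sample rows are made up, and the real chunks_fts table additionally uses external content (content='chunks'):

// illustrative sketch only: shows bm25() ordering with "order by score limit N"
#include <cstdio>
#include <sqlite3.h>

static int printRow(void *, int argc, char **argv, char **colName)
{
    for (int i = 0; i < argc; ++i)
        std::printf("%s=%s  ", colName[i], argv[i] ? argv[i] : "NULL");
    std::printf("\n");
    return 0;
}

int main()
{
    sqlite3 *db = nullptr;
    if (sqlite3_open(":memory:", &db) != SQLITE_OK)
        return 1;

    const char *sql =
        "create virtual table t using fts5(chunk_text, tokenize='porter');"
        "insert into t values ('the quick brown fox');"
        "insert into t values ('the quick brown fox jumps over the lazy dog');"
        "insert into t values ('completely unrelated text');"
        // best (most negative) bm25 score comes first, as in SELECT_CHUNKS_FTS_SQL
        "select rowid, bm25(t) as score from t where t match 'quick fox' order by score limit 2;";

    char *err = nullptr;
    if (sqlite3_exec(db, sql, printRow, nullptr, &err) != SQLITE_OK) {
        std::fprintf(stderr, "sqlite error: %s\n", err ? err : "unknown");
        sqlite3_free(err);
    }
    sqlite3_close(db);
    return 0;
}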
@@ -219,6 +248,18 @@ static bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, c
     if (!q.exec() || !q.next())
         return false;
     *chunk_id = q.value(0).toInt();
+
+    if (!q.prepare(INSERT_CHUNK_FTS_SQL))
+        return false;
+    q.addBindValue(document_id);
+    q.addBindValue(chunk_text);
+    q.addBindValue(file);
+    q.addBindValue(title);
+    q.addBindValue(author);
+    q.addBindValue(subject);
+    q.addBindValue(keywords);
+    if (!q.exec())
+        return false;
     return true;
 }
 
@@ -424,6 +465,7 @@ static bool selectAllFromCollections(QSqlQuery &q, QList<CollectionItem> *collec
             return false;
         break;
     case 2:
+    case 3:
         if (!q.prepare(SELECT_COLLECTIONS_SQL_V2))
             return false;
         break;
@@ -770,6 +812,12 @@ static const QString GET_COLLECTION_EMBEDDINGS_SQL = uR"(
     where co.name in ('%1');
 )"_s;
 
+static const QString GET_CHUNK_EMBEDDINGS_SQL = uR"(
+    select e.chunk_id, e.embedding
+    from embeddings e
+    where e.chunk_id in (%1);
+)"_s;
+
 static const QString GET_CHUNK_FILE_SQL = uR"(
     select file from chunks where id = ?;
 )"_s;
@@ -1858,19 +1906,13 @@ void Database::removeFolderFromWatch(const QString &path)
     m_watchedPaths -= QSet<QString>(children.begin(), children.end());
 }
 
-QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections, int nNeighbors)
+QList<int> Database::searchEmbeddingsHelper(const std::vector<float> &query, QSqlQuery &q, int nNeighbors)
 {
     constexpr int BATCH_SIZE = 2048;
 
     const int n_embd = query.size();
     const us::metric_punned_t metric(n_embd, us::metric_kind_t::ip_k); // inner product
 
-    QSqlQuery q(m_db);
-    if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) {
-        qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
-        return {};
-    }
-
     us::executor_default_t executor(std::thread::hardware_concurrency());
     us::exact_search_t search;
 
@@ -1882,6 +1924,7 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
     struct Result { int chunkId; us::distance_punned_t dist; };
     QList<Result> results;
 
+    // The q parameter is expected to be the result of a QSqlQuery returning (chunk_id, embedding) pairs
     while (q.at() != QSql::AfterLastRow) { // batches
         batchChunkIds.clear();
         batchEmbeddings.clear();
@@ -1937,6 +1980,223 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
     return chunkIds;
 }
 
+QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections,
+    int nNeighbors)
+{
+    QSqlQuery q(m_db);
+    if (!q.exec(GET_COLLECTION_EMBEDDINGS_SQL.arg(collections.join("', '")))) {
+        qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
+        return {};
+    }
+    return searchEmbeddingsHelper(query, q, nNeighbors);
+}
+
+QList<int> Database::scoreChunks(const std::vector<float> &query, const QList<int> &chunks)
+{
+    QList<QString> chunkStrings;
+    for (int id : chunks)
+        chunkStrings << QString::number(id);
+    QSqlQuery q(m_db);
+    if (!q.exec(GET_CHUNK_EMBEDDINGS_SQL.arg(chunkStrings.join(", ")))) {
+        qWarning() << "Database ERROR: Failed to exec embeddings query:" << q.lastError();
+        return {};
+    }
+    return searchEmbeddingsHelper(query, q, chunks.size());
+}
+
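For orientation: searchEmbeddingsHelper() performs an exact (brute-force) nearest-neighbor search, scoring every (chunk_id, embedding) row produced by the caller's query against the query embedding by inner product, in batches, and returning the top nNeighbors ids; scoreChunks() reuses it to rank an arbitrary id list. A rough STL-only sketch of the idea, not part of the patch, without the usearch batching; topKByInnerProduct is a hypothetical name, and it assumes the embeddings are normalized so inner product behaves like cosine similarity:

// illustrative sketch only: brute-force top-k by inner product
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

std::vector<int> topKByInnerProduct(const std::vector<float> &query,
                                    const std::vector<std::pair<int, std::vector<float>>> &chunks,
                                    std::size_t k)
{
    std::vector<std::pair<float, int>> scored; // (score, chunkId)
    scored.reserve(chunks.size());
    for (const auto &[chunkId, embd] : chunks) {
        float dot = 0.0f;
        for (std::size_t i = 0; i < query.size() && i < embd.size(); ++i)
            dot += query[i] * embd[i];
        scored.emplace_back(dot, chunkId);
    }

    k = std::min(k, scored.size());
    // higher inner product means more similar, so sort descending and keep the first k
    std::partial_sort(scored.begin(), scored.begin() + k, scored.end(),
                      [](const auto &a, const auto &b) { return a.first > b.first; });

    std::vector<int> ids;
    ids.reserve(k);
    for (std::size_t i = 0; i < k; ++i)
        ids.push_back(scored[i].second);
    return ids;
}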
+QList<Database::BM25Query> Database::queriesForFTS5(const QString &input)
+{
+    // Escape double quotes by adding a second double quote
+    QString escapedInput = input;
+    escapedInput.replace("\"", "\"\"");
+
+    static QRegularExpression spaces("\\s+");
+    QStringList oWords = escapedInput.split(spaces, Qt::SkipEmptyParts);
+
+    QList<BM25Query> queries;
+
+    // Start by trying to match the entire input
+    BM25Query e;
+    e.isExact = true;
+    e.input = oWords.join(" ");
+    e.query = "\"" + oWords.join(" ") + "\"";
+    e.qlength = oWords.size();
+    e.ilength = oWords.size();
+    queries << e;
+
+    // https://github.com/igorbrigadir/stopwords?tab=readme-ov-file
+    // Lucene, Solr, Elasticsearch
+    static const QSet<QString> stopWords = {
+        "a", "an", "and", "are", "as", "at", "be", "but", "by",
+        "for", "if", "in", "into", "is", "it", "no", "not", "of",
+        "on", "or", "such", "that", "the", "their", "then", "there",
+        "these", "they", "this", "to", "was", "will", "with"
+    };
+
+    QStringList quotedWords;
+    for (const QString &w : oWords)
+        if (!stopWords.contains(w.toLower()))
+            quotedWords << "\"" + w + "\"";
+
+    BM25Query b;
+    b.input = oWords.join(" ");
+    b.query = "(" + quotedWords.join(" OR ") + ")";
+    b.qlength = 1; // length of phrase
+    b.ilength = oWords.size();
+    b.rlength = oWords.size() - quotedWords.size();
+    queries << b;
+    return queries;
+}
+
+QList<int> Database::searchBM25(const QString &query, const QList<QString> &collections, BM25Query &bm25q, int k)
+{
+    struct SearchResult { int chunkId; float score; };
+    QList<BM25Query> bm25Queries = queriesForFTS5(query);
+
+    QSqlQuery sqlQuery(m_db);
+    sqlQuery.prepare(SELECT_CHUNKS_FTS_SQL.arg(k));
+
+    QList<SearchResult> results;
+    for (auto &bm25Query : std::as_const(bm25Queries)) {
+        sqlQuery.addBindValue(bm25Query.query);
+
+        if (!sqlQuery.exec()) {
+            qWarning() << "Database ERROR: Failed to execute BM25 query:" << sqlQuery.lastError();
+            return {};
+        }
+
+        if (sqlQuery.next()) {
+            // Save the query that was used to produce results
+            bm25q = bm25Query;
+            break;
+        }
+    }
+
+    do {
+        const int chunkId = sqlQuery.value(0).toInt();
+        const float score = sqlQuery.value(1).toFloat();
+        results.append({chunkId, score});
+    } while (sqlQuery.next());
+
+    k = qMin(k, results.size());
+    std::partial_sort(
+        results.begin(), results.begin() + k, results.end(),
+        [](const SearchResult &a, const SearchResult &b) { return a.score < b.score; }
+    );
+
+    QList<int> chunkIds;
+    chunkIds.reserve(k);
+    for (int i = 0; i < k; i++)
+        chunkIds << results[i].chunkId;
+    return chunkIds;
+}
+
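queriesForFTS5() produces two candidate FTS5 MATCH strings and searchBM25() tries them in order, keeping the first that returns rows: an exact quoted phrase of the whole input, then an OR of the individually quoted non-stop-words. A small standalone example, not part of the patch, that re-implements the construction with the STL purely to show what the two strings look like for a hypothetical input:

// illustrative sketch only: the two FTS5 query strings for one sample input
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

int main()
{
    const std::string input = "what is the best way to index documents";
    const std::unordered_set<std::string> stopWords =
        {"a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into",
         "is", "it", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then",
         "there", "these", "they", "this", "to", "was", "will", "with"};

    std::vector<std::string> words;
    std::istringstream ss(input);
    for (std::string w; ss >> w; )
        words.push_back(w);

    // Pass 1: the whole input as one quoted phrase -> "what is the best way to index documents"
    std::string exactQuery = "\"" + input + "\"";

    // Pass 2: OR the non-stop-words -> ("what" OR "best" OR "way" OR "index" OR "documents")
    std::string orQuery = "(";
    bool first = true;
    for (const auto &w : words) {
        if (stopWords.count(w))
            continue;
        if (!first) orQuery += " OR ";
        orQuery += "\"" + w + "\"";
        first = false;
    }
    orQuery += ")";

    std::cout << exactQuery << "\n" << orQuery << "\n";
    // ilength = 8 words, 3 of them stop words, so rlength = 3; computeBM25Weight() below then
    // gives queryLengthWeight = 1/(8-3)^2 = 0.04 and bmWeight = 0.25 + 0.04 * 0.5 = 0.27
    return 0;
}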
+float Database::computeBM25Weight(const Database::BM25Query &bm25q)
+{
+    float bmWeight = 0.0f;
+    if (bm25q.isExact) {
+        bmWeight = 0.9f; // the highest we give
+    } else {
+        // qlength is the length of the phrases in the query by number of distinct words
+        // ilength is the length of the natural language query by number of distinct words
+        // rlength is the number of stop words removed from the natural language query to form the query
+
+        // calculate the query length weight based on the ratio of query terms to meaningful terms.
+        // this formula adjusts the weight with the empirically determined insight that BM25's
+        // effectiveness decreases as query length increases.
+        float queryLengthWeight = 1 / powf(float(bm25q.ilength - bm25q.rlength), 2);
+        queryLengthWeight = qBound(0.0f, queryLengthWeight, 1.0f);
+
+        // the weighting is bound between 1/4 and 3/4 which was determined empirically to work well
+        // with the beir nfcorpus, scifact, fiqa and trec-covid datasets along with our embedding
+        // model
+        bmWeight = 0.25f + queryLengthWeight * 0.50f;
+    }
+
+#if 0
+    qDebug()
+        << "bm25q.type" << bm25q.type
+        << "bm25q.qlength" << bm25q.qlength
+        << "bm25q.ilength" << bm25q.ilength
+        << "bm25q.rlength" << bm25q.rlength
+        << "bmWeight" << bmWeight;
+#endif
+
+    return bmWeight;
+}
+
+QList<int> Database::reciprocalRankFusion(const std::vector<float> &query, const QList<int> &embeddingResults,
+    const QList<int> &bm25Results, const BM25Query &bm25q, int k)
+{
+    // We default to the embedding results and augment with bm25 if any
+    QList<int> results = embeddingResults;
+
+    QList<int> missingScores;
+    QHash<int, int> bm25Ranks;
+    for (int i = 0; i < bm25Results.size(); ++i) {
+        if (!results.contains(bm25Results[i]))
+            missingScores.append(bm25Results[i]);
+        bm25Ranks[bm25Results[i]] = i + 1;
+    }
+
+    if (!missingScores.isEmpty()) {
+        QList<int> scored = scoreChunks(query, missingScores);
+        results << scored;
+    }
+
+    QHash<int, int> embeddingRanks;
+    for (int i = 0; i < results.size(); ++i)
+        embeddingRanks[results[i]] = i + 1;
+
+    const float bmWeight = bm25Results.isEmpty() ? 0 : computeBM25Weight(bm25q);
+
+    // From the paper: "Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"
+    // doi: 10.1145/1571941.1572114
+    const int fusion_k = 60;
+
+    std::stable_sort(
+        results.begin(), results.end(),
+        [&](const int &a, const int &b) {
+            // Reciprocal Rank Fusion (RRF)
+            const int aBm25Rank = bm25Ranks.value(a, bm25Results.size() + 1);
+            const int aEmbeddingRank = embeddingRanks.value(a, embeddingResults.size() + 1);
+            Q_ASSERT(embeddingRanks.contains(a));
+
+            const int bBm25Rank = bm25Ranks.value(b, bm25Results.size() + 1);
+            const int bEmbeddingRank = embeddingRanks.value(b, embeddingResults.size() + 1);
+            Q_ASSERT(embeddingRanks.contains(b));
+
+            const float aBm25Score = 1.0f / (fusion_k + aBm25Rank);
+            const float bBm25Score = 1.0f / (fusion_k + bBm25Rank);
+            const float aEmbeddingScore = 1.0f / (fusion_k + aEmbeddingRank);
+            const float bEmbeddingScore = 1.0f / (fusion_k + bEmbeddingRank);
+            const float aWeightedScore = bmWeight * aBm25Score + (1.f - bmWeight) * aEmbeddingScore;
+            const float bWeightedScore = bmWeight * bBm25Score + (1.f - bmWeight) * bEmbeddingScore;
+
+            // Higher RRF score means better ranking, so we use greater than for sorting
+            return aWeightedScore > bWeightedScore;
+        }
+    );
+
+    k = qMin(k, results.size());
+    results.resize(k);
+    return results;
+}
+
+QList<int> Database::searchDatabase(const QString &query, const QList<QString> &collections, int k)
+{
+    std::vector<float> queryEmbd = m_embLLM->generateQueryEmbedding(query);
+    if (queryEmbd.empty()) {
+        qDebug() << "ERROR: generating embeddings returned a null result";
+        return { };
+    }
+
+    const QList<int> embeddingResults = searchEmbeddings(queryEmbd, collections, k);
+    BM25Query bm25q;
+    const QList<int> bm25Results = searchBM25(query, collections, bm25q, k);
+    return reciprocalRankFusion(queryEmbd, embeddingResults, bm25Results, bm25q, k);
+}
+
 void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
                               QList<ResultInfo> *results)
 {
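reciprocalRankFusion() blends the two rankings with weighted RRF: each chunk contributes 1/(60 + rank) from each ranking, the BM25 side is scaled by computeBM25Weight() (0.9 for an exact-phrase hit, otherwise 0.25 + 0.5/(ilength - rlength)^2), and ids absent from a ranking are treated as ranked one past the end of that list. A toy standalone run, not part of the patch, with made-up chunk ids and a made-up weight; the real code additionally re-ranks bm25-only hits against the query embedding via scoreChunks() so they get a proper embedding rank:

// illustrative sketch only: weighted Reciprocal Rank Fusion on toy data
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <map>
#include <vector>

int main()
{
    const std::vector<int> embeddingResults = {11, 22, 33};   // hypothetical chunk ids, best first
    const std::vector<int> bm25Results      = {33, 44};
    const float bmWeight = 0.27f;                             // e.g. from a 5-keyword query
    const int fusionK = 60;

    std::map<int, int> embRank, bmRank;
    for (std::size_t i = 0; i < embeddingResults.size(); ++i) embRank[embeddingResults[i]] = int(i) + 1;
    for (std::size_t i = 0; i < bm25Results.size(); ++i)      bmRank[bm25Results[i]] = int(i) + 1;

    // candidate pool: embedding results plus the bm25-only hit (44)
    std::vector<int> candidates = {11, 22, 33, 44};

    auto score = [&](int id) {
        // missing ids are penalized with rank = listSize + 1, as in the patch
        const int eRank = embRank.count(id) ? embRank[id] : int(embeddingResults.size()) + 1;
        const int bRank = bmRank.count(id) ? bmRank[id] : int(bm25Results.size()) + 1;
        return bmWeight * (1.0f / (fusionK + bRank)) + (1.0f - bmWeight) * (1.0f / (fusionK + eRank));
    };

    std::stable_sort(candidates.begin(), candidates.end(),
                     [&](int a, int b) { return score(a) > score(b); });

    for (int id : candidates)
        std::printf("chunk %d  fused score %.5f\n", id, score(id));
    return 0;
}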
@@ -1944,13 +2204,7 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
     qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
 #endif
 
-    std::vector<float> queryEmbd = m_embLLM->generateQueryEmbedding(text);
-    if (queryEmbd.empty()) {
-        qDebug() << "ERROR: generating embeddings returned a null result";
-        return;
-    }
-
-    QList<int> searchResults = searchEmbeddings(queryEmbd, collections, retrievalSize);
+    QList<int> searchResults = searchDatabase(text, collections, retrievalSize);
     if (searchResults.isEmpty())
         return;
 
@@ -1960,10 +2214,9 @@
         return;
     }
 
+    QHash<int, ResultInfo> tempResults;
     while (q.next()) {
-#if defined(DEBUG)
         const int rowid = q.value(0).toInt();
-#endif
         const QString document_path = q.value(2).toString();
         const QString chunk_text = q.value(3).toString();
         const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
@@ -1985,12 +2238,16 @@
         info.page = page;
         info.from = from;
         info.to = to;
-        results->append(info);
+        tempResults.insert(rowid, info);
 #if defined(DEBUG)
         qDebug() << "retrieve rowid:" << rowid << "chunk_text:" << chunk_text;
 #endif
     }
+
+    for (int id : searchResults)
+        if (tempResults.contains(id))
+            results->append(tempResults.value(id));
 }
 
 // FIXME This is very slow and non-interruptible and when we close the application and we're
diff --git a/gpt4all-chat/src/database.h b/gpt4all-chat/src/database.h
index 90312290..c5a43b22 100644
--- a/gpt4all-chat/src/database.h
+++ b/gpt4all-chat/src/database.h
@@ -35,7 +35,7 @@ class QTimer;
 // minimum supported version
 static const int LOCALDOCS_MIN_VER = 1;
 // current version
-static const int LOCALDOCS_VERSION = 2;
+static const int LOCALDOCS_VERSION = 3;
 
 struct DocumentInfo
 {
@@ -206,7 +206,24 @@ private:
     bool cleanDB();
     void addFolderToWatch(const QString &path);
    void removeFolderFromWatch(const QString &path);
-    QList<int> searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections, int nNeighbors);
+    static QList<int> searchEmbeddingsHelper(const std::vector<float> &query, QSqlQuery &q, int nNeighbors);
+    QList<int> searchEmbeddings(const std::vector<float> &query, const QList<QString> &collections,
+        int nNeighbors);
+    struct BM25Query {
+        QString input;
+        QString query;
+        bool isExact = false;
+        int qlength = 0;
+        int ilength = 0;
+        int rlength = 0;
+    };
+    QList<BM25Query> queriesForFTS5(const QString &input);
+    QList<int> searchBM25(const QString &query, const QList<QString> &collections, BM25Query &bm25q, int k);
+    QList<int> scoreChunks(const std::vector<float> &query, const QList<int> &chunks);
+    float computeBM25Weight(const BM25Query &bm25q);
+    QList<int> reciprocalRankFusion(const std::vector<float> &query, const QList<int> &embeddingResults,
+        const QList<int> &bm25Results, const BM25Query &bm25q, int k);
+    QList<int> searchDatabase(const QString &query, const QList<QString> &collections, int k);
     void setStartUpdateTime(CollectionItem &item);
     void setLastUpdateTime(CollectionItem &item);
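One consequence of the retrieveFromDB() change above: rows are no longer appended in whatever order SQL returns them; they are buffered by rowid and emitted in the order searchDatabase() produced, so the fused BM25/embedding ranking survives the fetch. A minimal STL sketch of that reordering step, not part of the patch; orderByRank and ResultInfoLite are hypothetical stand-ins for the real ResultInfo plumbing:

// illustrative sketch only: re-emit fetched rows in a given ranking order
#include <string>
#include <unordered_map>
#include <vector>

struct ResultInfoLite { int rowid = 0; std::string text; };

std::vector<ResultInfoLite> orderByRank(const std::vector<int> &rankedIds,
                                        const std::vector<ResultInfoLite> &fetchedRows)
{
    std::unordered_map<int, ResultInfoLite> byId;
    for (const auto &row : fetchedRows)
        byId.emplace(row.rowid, row);

    std::vector<ResultInfoLite> ordered;
    ordered.reserve(rankedIds.size());
    for (int id : rankedIds) {               // rankedIds plays the role of searchResults
        auto it = byId.find(id);
        if (it != byId.end())
            ordered.push_back(it->second);   // keep the hybrid ranking, skip ids with no row
    }
    return ordered;
}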