From 371e2a5cbc70be9d2e06d67c43d9adbee27da441 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 24 Oct 2023 12:13:32 -0400 Subject: [PATCH] LocalDocs version 2 with text embeddings. --- gpt4all-backend/bert.cpp | 3 + gpt4all-backend/llmodel_shared.h | 3 +- gpt4all-chat/CMakeLists.txt | 5 +- gpt4all-chat/chat.cpp | 3 + gpt4all-chat/chat.h | 4 + gpt4all-chat/database.cpp | 422 +++++-- gpt4all-chat/database.h | 44 +- gpt4all-chat/embeddings.cpp | 190 +++ gpt4all-chat/embeddings.h | 45 + gpt4all-chat/embllm.cpp | 64 + gpt4all-chat/embllm.h | 27 + gpt4all-chat/hnswlib/bruteforce.h | 167 +++ gpt4all-chat/hnswlib/hnswalg.h | 1271 ++++++++++++++++++++ gpt4all-chat/hnswlib/hnswlib.h | 199 +++ gpt4all-chat/hnswlib/space_ip.h | 375 ++++++ gpt4all-chat/hnswlib/space_l2.h | 324 +++++ gpt4all-chat/hnswlib/visited_list_pool.h | 78 ++ gpt4all-chat/localdocs.cpp | 38 +- gpt4all-chat/localdocs.h | 2 +- gpt4all-chat/localdocsmodel.cpp | 162 ++- gpt4all-chat/localdocsmodel.h | 42 +- gpt4all-chat/main.qml | 15 +- gpt4all-chat/metadata/models2.json | 10 +- gpt4all-chat/modellist.cpp | 63 +- gpt4all-chat/modellist.h | 25 + gpt4all-chat/qml/CollectionsDialog.qml | 45 +- gpt4all-chat/qml/LocalDocsSettings.qml | 45 +- gpt4all-chat/qml/ModelDownloaderDialog.qml | 8 + gpt4all-chat/qml/MySettingsTab.qml | 6 +- gpt4all-chat/qml/SettingsDialog.qml | 19 +- 30 files changed, 3540 insertions(+), 164 deletions(-) create mode 100644 gpt4all-chat/embeddings.cpp create mode 100644 gpt4all-chat/embeddings.h create mode 100644 gpt4all-chat/embllm.cpp create mode 100644 gpt4all-chat/embllm.h create mode 100644 gpt4all-chat/hnswlib/bruteforce.h create mode 100644 gpt4all-chat/hnswlib/hnswalg.h create mode 100644 gpt4all-chat/hnswlib/hnswlib.h create mode 100644 gpt4all-chat/hnswlib/space_ip.h create mode 100644 gpt4all-chat/hnswlib/space_l2.h create mode 100644 gpt4all-chat/hnswlib/visited_list_pool.h diff --git a/gpt4all-backend/bert.cpp b/gpt4all-backend/bert.cpp index ba92465d..f74e8554 100644 --- a/gpt4all-backend/bert.cpp +++ b/gpt4all-backend/bert.cpp @@ -490,6 +490,9 @@ struct bert_ctx * bert_load_from_file(const char *fname) #endif bert_ctx * new_bert = new bert_ctx; + new_bert->buf_compute.force_cpu = true; + new_bert->work_buf.force_cpu = true; + bert_model & model = new_bert->model; bert_vocab & vocab = new_bert->vocab; diff --git a/gpt4all-backend/llmodel_shared.h b/gpt4all-backend/llmodel_shared.h index 0c620c4e..c48f1fdf 100644 --- a/gpt4all-backend/llmodel_shared.h +++ b/gpt4all-backend/llmodel_shared.h @@ -10,13 +10,14 @@ struct llm_buffer { uint8_t * addr = NULL; size_t size = 0; ggml_vk_memory memory; + bool force_cpu = false; llm_buffer() = default; void resize(size_t size) { free(); - if (!ggml_vk_has_device()) { + if (!ggml_vk_has_device() || force_cpu) { this->addr = new uint8_t[size]; this->size = size; } else { diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 11167afe..f14aed3b 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -75,7 +75,9 @@ qt_add_executable(chat chatmodel.h chatlistmodel.h chatlistmodel.cpp chatgpt.h chatgpt.cpp database.h database.cpp + embeddings.h embeddings.cpp download.h download.cpp + embllm.cpp embllm.h localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp llm.h llm.cpp modellist.h modellist.cpp @@ -90,6 +92,7 @@ qt_add_executable(chat qt_add_qml_module(chat URI gpt4all VERSION 1.0 + NO_CACHEGEN QML_FILES main.qml qml/ChatDrawer.qml @@ -170,7 +173,7 @@ else() PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf) endif() target_link_libraries(chat - PRIVATE llmodel) + PRIVATE llmodel bert-default) set(COMPONENT_NAME_MAIN ${PROJECT_NAME}) set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install) diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 453f17f0..3e7f91c1 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -18,6 +18,7 @@ Chat::Chat(QObject *parent) , m_shouldDeleteLater(false) , m_isModelLoaded(false) , m_shouldLoadModelWhenInstalled(false) + , m_collectionModel(new LocalDocsCollectionsModel(this)) { connectLLM(); } @@ -35,6 +36,7 @@ Chat::Chat(bool isServer, QObject *parent) , m_shouldDeleteLater(false) , m_isModelLoaded(false) , m_shouldLoadModelWhenInstalled(false) + , m_collectionModel(new LocalDocsCollectionsModel(this)) { connectLLM(); } @@ -71,6 +73,7 @@ void Chat::connectLLM() connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection); connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection); + connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections); connect(ModelList::globalInstance()->installedModels(), &InstalledModels::countChanged, this, &Chat::handleModelInstalled, Qt::QueuedConnection); } diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index f13747dd..30308f42 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -27,6 +27,7 @@ class Chat : public QObject Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged); Q_PROPERTY(QString device READ device NOTIFY deviceChanged); Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged); + Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged) QML_ELEMENT QML_UNCREATABLE("Only creatable from c++!") @@ -83,6 +84,7 @@ public: bool isServer() const { return m_isServer; } QList collectionList() const; + LocalDocsCollectionsModel *collectionModel() const { return m_collectionModel; } Q_INVOKABLE bool hasCollection(const QString &collection) const; Q_INVOKABLE void addCollection(const QString &collection); @@ -123,6 +125,7 @@ Q_SIGNALS: void tokenSpeedChanged(); void deviceChanged(); void fallbackReasonChanged(); + void collectionModelChanged(); private Q_SLOTS: void handleResponseChanged(const QString &response); @@ -161,6 +164,7 @@ private: bool m_shouldDeleteLater; bool m_isModelLoaded; bool m_shouldLoadModelWhenInstalled; + LocalDocsCollectionsModel *m_collectionModel; }; #endif // CHAT_H diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index 54100a50..8369db5b 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -1,5 +1,7 @@ #include "database.h" #include "mysettings.h" +#include "embllm.h" +#include "embeddings.h" #include #include @@ -7,18 +9,18 @@ //#define DEBUG //#define DEBUG_EXAMPLE -#define LOCALDOCS_VERSION 0 +#define LOCALDOCS_VERSION 1 const auto INSERT_CHUNK_SQL = QLatin1String(R"( - insert into chunks(document_id, chunk_id, chunk_text, - file, title, author, subject, keywords, page, line_from, line_to, - embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + insert into chunks(document_id, chunk_text, + file, title, author, subject, keywords, page, line_from, line_to) + values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?); )"); const auto INSERT_CHUNK_FTS_SQL = QLatin1String(R"( insert into chunks_fts(document_id, chunk_id, chunk_text, - file, title, author, subject, keywords, page, line_from, line_to, - embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + file, title, author, subject, keywords, page, line_from, line_to) + values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); )"); const auto DELETE_CHUNKS_SQL = QLatin1String(R"( @@ -30,20 +32,33 @@ const auto DELETE_CHUNKS_FTS_SQL = QLatin1String(R"( )"); const auto CHUNKS_SQL = QLatin1String(R"( - create table chunks(document_id integer, chunk_id integer, chunk_text varchar, + create table chunks(document_id integer, chunk_id integer primary key autoincrement, chunk_text varchar, file varchar, title varchar, author varchar, subject varchar, keywords varchar, - page integer, line_from integer, line_to integer, - embedding_id integer, embedding_path varchar); + page integer, line_from integer, line_to integer); )"); const auto FTS_CHUNKS_SQL = QLatin1String(R"( create virtual table chunks_fts using fts5(document_id unindexed, chunk_id unindexed, chunk_text, - file, title, author, subject, keywords, page, line_from, line_to, - embedding_id unindexed, embedding_path unindexed, tokenize="trigram"); + file, title, author, subject, keywords, page, line_from, line_to, tokenize="trigram"); )"); -const auto SELECT_SQL = QLatin1String(R"( - select chunks_fts.rowid, documents.document_time, +const auto SELECT_CHUNKS_BY_DOCUMENT_SQL = QLatin1String(R"( + select chunk_id from chunks WHERE document_id = ?; + )"); + +const auto SELECT_CHUNKS_SQL = QLatin1String(R"( + select chunks.chunk_id, documents.document_time, + chunks.chunk_text, chunks.file, chunks.title, chunks.author, chunks.page, + chunks.line_from, chunks.line_to + from chunks + join documents ON chunks.document_id = documents.id + join folders ON documents.folder_id = folders.id + join collections ON folders.id = collections.folder_id + where chunks.chunk_id in (%1) and collections.collection_name in (%2); +)"); + +const auto SELECT_NGRAM_SQL = QLatin1String(R"( + select chunks_fts.chunk_id, documents.document_time, chunks_fts.chunk_text, chunks_fts.file, chunks_fts.title, chunks_fts.author, chunks_fts.page, chunks_fts.line_from, chunks_fts.line_to from chunks_fts @@ -55,16 +70,14 @@ const auto SELECT_SQL = QLatin1String(R"( limit %2; )"); -bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_text, +bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text, const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords, - int page, int from, int to, - int embedding_id, const QString &embedding_path) + int page, int from, int to, int *chunk_id) { { if (!q.prepare(INSERT_CHUNK_SQL)) return false; q.addBindValue(document_id); - q.addBindValue(chunk_id); q.addBindValue(chunk_text); q.addBindValue(file); q.addBindValue(title); @@ -74,16 +87,19 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_ q.addBindValue(page); q.addBindValue(from); q.addBindValue(to); - q.addBindValue(embedding_id); - q.addBindValue(embedding_path); if (!q.exec()) return false; } + if (!q.exec("select last_insert_rowid();")) + return false; + if (!q.next()) + return false; + *chunk_id = q.value(0).toInt(); { if (!q.prepare(INSERT_CHUNK_FTS_SQL)) return false; q.addBindValue(document_id); - q.addBindValue(chunk_id); + q.addBindValue(*chunk_id); q.addBindValue(chunk_text); q.addBindValue(file); q.addBindValue(title); @@ -93,8 +109,6 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_ q.addBindValue(page); q.addBindValue(from); q.addBindValue(to); - q.addBindValue(embedding_id); - q.addBindValue(embedding_path); if (!q.exec()) return false; } @@ -146,6 +160,18 @@ QStringList generateGrams(const QString &input, int N) return ngrams; } +bool selectChunk(QSqlQuery &q, const QList &collection_names, const std::vector &chunk_ids, int retrievalSize) +{ + QString chunk_ids_str = QString::number(chunk_ids[0]); + for (size_t i = 1; i < chunk_ids.size(); ++i) + chunk_ids_str += "," + QString::number(chunk_ids[i]); + const QString collection_names_str = collection_names.join("', '"); + const QString formatted_query = SELECT_CHUNKS_SQL.arg(chunk_ids_str).arg("'" + collection_names_str + "'"); + if (!q.prepare(formatted_query)) + return false; + return q.exec(); +} + bool selectChunk(QSqlQuery &q, const QList &collection_names, const QString &chunk_text, int retrievalSize) { static QRegularExpression spaces("\\s+"); @@ -155,7 +181,7 @@ bool selectChunk(QSqlQuery &q, const QList &collection_names, const QSt QList text = generateGrams(chunk_text, N); QString orText = text.join(" OR "); const QString collection_names_str = collection_names.join("', '"); - const QString formatted_query = SELECT_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize)); + const QString formatted_query = SELECT_NGRAM_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize)); if (!q.prepare(formatted_query)) return false; q.addBindValue(orText); @@ -248,7 +274,8 @@ bool selectAllFromCollections(QSqlQuery &q, QList *collections) CollectionItem i; i.collection = q.value(0).toString(); i.folder_path = q.value(1).toString(); - i.folder_id = q.value(0).toInt(); + i.folder_id = q.value(2).toInt(); + i.indexing = false; i.installed = true; collections->append(i); } @@ -459,6 +486,12 @@ QSqlError initDb() return q.lastError(); } + CollectionItem i; + i.collection = collection_name; + i.folder_path = folder_path; + i.folder_id = folder_id; + emit addCollectionItem(i); + // Add a document int document_time = 123456789; int document_id; @@ -504,6 +537,8 @@ Database::Database(int chunkSize) : QObject(nullptr) , m_watcher(new QFileSystemWatcher(this)) , m_chunkSize(chunkSize) + , m_embLLM(new EmbeddingLLM) + , m_embeddings(new Embeddings(this)) { moveToThread(&m_dbThread); connect(&m_dbThread, &QThread::started, this, &Database::start); @@ -511,22 +546,39 @@ Database::Database(int chunkSize) m_dbThread.start(); } -void Database::handleDocumentErrorAndScheduleNext(const QString &errorMessage, - int document_id, const QString &document_path, const QSqlError &error) +Database::~Database() { - qWarning() << errorMessage << document_id << document_path << error.text(); + m_dbThread.quit(); + m_dbThread.wait(); +} + +void Database::scheduleNext(int folder_id, size_t countForFolder) +{ + emit updateCurrentDocsToIndex(folder_id, countForFolder); + if (!countForFolder) { + emit updateIndexing(folder_id, false); + emit updateInstalled(folder_id, true); + m_embeddings->save(); + } if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue); } -void Database::chunkStream(QTextStream &stream, int document_id, const QString &file, - const QString &title, const QString &author, const QString &subject, const QString &keywords, int page) +void Database::handleDocumentError(const QString &errorMessage, + int document_id, const QString &document_path, const QSqlError &error) +{ + qWarning() << errorMessage << document_id << document_path << error.text(); +} + +size_t Database::chunkStream(QTextStream &stream, int document_id, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, int page, + int maxChunks) { - int chunk_id = 0; int charCount = 0; int line_from = -1; int line_to = -1; QList words; + int chunks = 0; while (!stream.atEnd()) { QString word; @@ -536,9 +588,9 @@ void Database::chunkStream(QTextStream &stream, int document_id, const QString & if (charCount + words.size() - 1 >= m_chunkSize || stream.atEnd()) { const QString chunk = words.join(" "); QSqlQuery q; + int chunk_id = 0; if (!addChunk(q, document_id, - ++chunk_id, chunk, file, title, @@ -548,15 +600,111 @@ void Database::chunkStream(QTextStream &stream, int document_id, const QString & page, line_from, line_to, - 0 /*embedding_id*/, - QString() /*embedding_path*/ + &chunk_id )) { qWarning() << "ERROR: Could not insert chunk into db" << q.lastError(); } + + const std::vector result = m_embLLM->generateEmbeddings(chunk); + if (!m_embeddings->add(result, chunk_id)) + qWarning() << "ERROR: Cannot add point to embeddings index"; + + ++chunks; + words.clear(); charCount = 0; + + if (maxChunks > 0 && chunks == maxChunks) + return stream.pos(); } } + return stream.pos(); +} + +void Database::removeEmbeddingsByDocumentId(int document_id) +{ + QSqlQuery q; + + if (!q.prepare(SELECT_CHUNKS_BY_DOCUMENT_SQL)) { + qWarning() << "ERROR: Cannot prepare sql for select chunks by document" << q.lastError(); + return; + } + + q.addBindValue(document_id); + + if (!q.exec()) { + qWarning() << "ERROR: Cannot exec sql for select chunks by document" << q.lastError(); + return; + } + + while (q.next()) { + const int chunk_id = q.value(0).toInt(); + m_embeddings->remove(chunk_id); + } + m_embeddings->save(); +} + +size_t Database::countOfDocuments(int folder_id) const +{ + if (!m_docsToScan.contains(folder_id)) + return 0; + return m_docsToScan.value(folder_id).size(); +} + +size_t Database::countOfBytes(int folder_id) const +{ + if (!m_docsToScan.contains(folder_id)) + return 0; + size_t totalBytes = 0; + const QQueue &docs = m_docsToScan.value(folder_id); + for (const DocumentInfo &f : docs) + totalBytes += f.doc.size(); + return totalBytes; +} + +DocumentInfo Database::dequeueDocument() +{ + Q_ASSERT(!m_docsToScan.isEmpty()); + const int firstKey = m_docsToScan.firstKey(); + QQueue &queue = m_docsToScan[firstKey]; + Q_ASSERT(!queue.isEmpty()); + DocumentInfo result = queue.dequeue(); + if (queue.isEmpty()) + m_docsToScan.remove(firstKey); + return result; +} + +void Database::removeFolderFromDocumentQueue(int folder_id) +{ + if (!m_docsToScan.contains(folder_id)) + return; + m_docsToScan.remove(folder_id); + emit removeFolderById(folder_id); + emit docsToScanChanged(); +} + +void Database::enqueueDocumentInternal(const DocumentInfo &info, bool prepend) +{ + const int key = info.folder; + if (!m_docsToScan.contains(key)) + m_docsToScan[key] = QQueue(); + if (prepend) + m_docsToScan[key].prepend(info); + else + m_docsToScan[key].enqueue(info); +} + +void Database::enqueueDocuments(int folder_id, const QVector &infos) +{ + for (int i = 0; i < infos.size(); ++i) + enqueueDocumentInternal(infos[i]); + const size_t count = countOfDocuments(folder_id); + emit updateCurrentDocsToIndex(folder_id, count); + emit updateTotalDocsToIndex(folder_id, count); + const size_t bytes = countOfBytes(folder_id); + emit updateCurrentBytesToIndex(folder_id, bytes); + emit updateTotalBytesToIndex(folder_id, bytes); + emit docsToScanChanged(); } void Database::scanQueue() @@ -564,7 +712,9 @@ void Database::scanQueue() if (m_docsToScan.isEmpty()) return; - DocumentInfo info = m_docsToScan.dequeue(); + DocumentInfo info = dequeueDocument(); + const size_t countForFolder = countOfDocuments(info.folder); + const int folder_id = info.folder; // Update info info.doc.stat(); @@ -572,99 +722,127 @@ void Database::scanQueue() // If the doc has since been deleted or no longer readable, then we schedule more work and return // leaving the cleanup for the cleanup handler if (!info.doc.exists() || !info.doc.isReadable()) { - if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue); - return; + return scheduleNext(folder_id, countForFolder); } - const int folder_id = info.folder; const qint64 document_time = info.doc.fileTime(QFile::FileModificationTime).toMSecsSinceEpoch(); const QString document_path = info.doc.canonicalFilePath(); - -#if defined(DEBUG) - qDebug() << "scanning document" << document_path; -#endif + const bool currentlyProcessing = info.currentlyProcessing; // Check and see if we already have this document QSqlQuery q; int existing_id = -1; qint64 existing_time = -1; if (!selectDocument(q, document_path, &existing_id, &existing_time)) { - return handleDocumentErrorAndScheduleNext("ERROR: Cannot select document", + handleDocumentError("ERROR: Cannot select document", existing_id, document_path, q.lastError()); + return scheduleNext(folder_id, countForFolder); } // If we have the document, we need to compare the last modification time and if it is newer // we must rescan the document, otherwise return - if (existing_id != -1) { + if (existing_id != -1 && !currentlyProcessing) { Q_ASSERT(existing_time != -1); if (document_time == existing_time) { // No need to rescan, but we do have to schedule next - if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue); - return; + return scheduleNext(folder_id, countForFolder); } else { + removeEmbeddingsByDocumentId(existing_id); if (!removeChunksByDocumentId(q, existing_id)) { - return handleDocumentErrorAndScheduleNext("ERROR: Cannot remove chunks of document", + handleDocumentError("ERROR: Cannot remove chunks of document", existing_id, document_path, q.lastError()); + return scheduleNext(folder_id, countForFolder); } } } // Update the document_time for an existing document, or add it for the first time now int document_id = existing_id; - if (document_id != -1) { - if (!updateDocument(q, document_id, document_time)) { - return handleDocumentErrorAndScheduleNext("ERROR: Could not update document_time", - document_id, document_path, q.lastError()); - } - } else { - if (!addDocument(q, folder_id, document_time, document_path, &document_id)) { - return handleDocumentErrorAndScheduleNext("ERROR: Could not add document", - document_id, document_path, q.lastError()); + if (!currentlyProcessing) { + if (document_id != -1) { + if (!updateDocument(q, document_id, document_time)) { + handleDocumentError("ERROR: Could not update document_time", + document_id, document_path, q.lastError()); + return scheduleNext(folder_id, countForFolder); + } + } else { + if (!addDocument(q, folder_id, document_time, document_path, &document_id)) { + handleDocumentError("ERROR: Could not add document", + document_id, document_path, q.lastError()); + return scheduleNext(folder_id, countForFolder); + } } } - QElapsedTimer timer; - timer.start(); - QSqlDatabase::database().transaction(); Q_ASSERT(document_id != -1); - if (info.doc.suffix() == QLatin1String("pdf")) { + if (info.isPdf()) { QPdfDocument doc; if (QPdfDocument::Error::None != doc.load(info.doc.canonicalFilePath())) { - return handleDocumentErrorAndScheduleNext("ERROR: Could not load pdf", + handleDocumentError("ERROR: Could not load pdf", document_id, document_path, q.lastError()); - return; + return scheduleNext(folder_id, countForFolder); } - for (int i = 0; i < doc.pageCount(); ++i) { - const QPdfSelection selection = doc.getAllText(i); - QString text = selection.text(); - QTextStream stream(&text); - chunkStream(stream, document_id, info.doc.fileName(), - doc.metaData(QPdfDocument::MetaDataField::Title).toString(), - doc.metaData(QPdfDocument::MetaDataField::Author).toString(), - doc.metaData(QPdfDocument::MetaDataField::Subject).toString(), - doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(), - i + 1 - ); + const size_t bytes = info.doc.size(); + const size_t bytesPerPage = std::floor(bytes / doc.pageCount()); + const int pageIndex = info.currentPage; +#if defined(DEBUG) + qDebug() << "scanning page" << pageIndex << "of" << doc.pageCount() << document_path; +#endif + const QPdfSelection selection = doc.getAllText(pageIndex); + QString text = selection.text(); + QTextStream stream(&text); + chunkStream(stream, document_id, info.doc.fileName(), + doc.metaData(QPdfDocument::MetaDataField::Title).toString(), + doc.metaData(QPdfDocument::MetaDataField::Author).toString(), + doc.metaData(QPdfDocument::MetaDataField::Subject).toString(), + doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(), + pageIndex + 1 + ); + m_embeddings->save(); + emit subtractCurrentBytesToIndex(info.folder, bytesPerPage); + if (info.currentPage < doc.pageCount()) { + info.currentPage += 1; + info.currentlyProcessing = true; + enqueueDocumentInternal(info, true /*prepend*/); + return scheduleNext(folder_id, countForFolder + 1); + } else { + emit subtractCurrentBytesToIndex(info.folder, bytes - (bytesPerPage * doc.pageCount())); } } else { QFile file(document_path); - if (!file.open( QIODevice::ReadOnly)) { - return handleDocumentErrorAndScheduleNext("ERROR: Cannot open file for scanning", - existing_id, document_path, q.lastError()); + if (!file.open(QIODevice::ReadOnly)) { + handleDocumentError("ERROR: Cannot open file for scanning", + existing_id, document_path, q.lastError()); + return scheduleNext(folder_id, countForFolder); } + + const size_t bytes = info.doc.size(); QTextStream stream(&file); - chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/, - QString() /*subject*/, QString() /*keywords*/, -1 /*page*/); + const size_t byteIndex = info.currentPosition; + if (!stream.seek(byteIndex)) { + handleDocumentError("ERROR: Cannot seek to pos for scanning", + existing_id, document_path, q.lastError()); + return scheduleNext(folder_id, countForFolder); + } +#if defined(DEBUG) + qDebug() << "scanning byteIndex" << byteIndex << "of" << bytes << document_path; +#endif + int pos = chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/, + QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, 5 /*maxChunks*/); + m_embeddings->save(); file.close(); + const size_t bytesChunked = pos - byteIndex; + emit subtractCurrentBytesToIndex(info.folder, bytesChunked); + if (info.currentPosition < bytes) { + info.currentPosition = pos; + info.currentlyProcessing = true; + enqueueDocumentInternal(info, true /*prepend*/); + return scheduleNext(folder_id, countForFolder + 1); + } } QSqlDatabase::database().commit(); - -#if defined(DEBUG) - qDebug() << "chunking" << document_path << "took" << timer.elapsed() << "ms"; -#endif - - if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue); + return scheduleNext(folder_id, countForFolder); } void Database::scanDocuments(int folder_id, const QString &folder_path) @@ -687,6 +865,7 @@ void Database::scanDocuments(int folder_id, const QString &folder_path) Q_ASSERT(dir.exists()); Q_ASSERT(dir.isReadable()); QDirIterator it(folder_path, QDir::Readable | QDir::Files, QDirIterator::Subdirectories); + QVector infos; while (it.hasNext()) { it.next(); QFileInfo fileInfo = it.fileInfo(); @@ -701,9 +880,13 @@ void Database::scanDocuments(int folder_id, const QString &folder_path) DocumentInfo info; info.folder = folder_id; info.doc = fileInfo; - m_docsToScan.enqueue(info); + infos.append(info); + } + + if (!infos.isEmpty()) { + emit updateIndexing(folder_id, true); + enqueueDocuments(folder_id, infos); } - emit docsToScanChanged(); } void Database::start() @@ -717,6 +900,10 @@ void Database::start() if (err.type() != QSqlError::NoError) qWarning() << "ERROR: initializing db" << err.text(); } + + if (m_embeddings->fileExists() && !m_embeddings->load()) + qWarning() << "ERROR: Could not load embeddings"; + addCurrentFolders(); } @@ -733,25 +920,12 @@ void Database::addCurrentFolders() return; } + emit collectionListUpdated(collections); + for (const auto &i : collections) addFolder(i.collection, i.folder_path); } -void Database::updateCollectionList() -{ -#if defined(DEBUG) - qDebug() << "updateCollectionList"; -#endif - - QSqlQuery q; - QList collections; - if (!selectAllFromCollections(q, &collections)) { - qWarning() << "ERROR: Cannot select collections" << q.lastError(); - return; - } - emit collectionListUpdated(collections); -} - void Database::addFolder(const QString &collection, const QString &path) { QFileInfo info(path); @@ -784,14 +958,21 @@ void Database::addFolder(const QString &collection, const QString &path) return; } - if (!folders.contains(folder_id) && !addCollection(q, collection, folder_id)) { - qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError(); - return; + if (!folders.contains(folder_id)) { + if (!addCollection(q, collection, folder_id)) { + qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError(); + return; + } + + CollectionItem i; + i.collection = collection; + i.folder_path = path; + i.folder_id = folder_id; + emit addCollectionItem(i); } addFolderToWatch(path); scanDocuments(folder_id, path); - updateCollectionList(); } void Database::removeFolder(const QString &collection, const QString &path) @@ -840,15 +1021,8 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co if (collections.count() > 1) return; - // First remove all upcoming jobs associated with this folder by performing an opt-in filter - QQueue docsToScan; - for (const DocumentInfo &info : m_docsToScan) { - if (info.folder == folder_id) - continue; - docsToScan.append(info); - } - m_docsToScan = docsToScan; - emit docsToScanChanged(); + // First remove all upcoming jobs associated with this folder + removeFolderFromDocumentQueue(folder_id); // Get a list of all documents associated with folder QList documentIds; @@ -859,6 +1033,7 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co // Remove all chunks and documents associated with this folder for (int document_id : documentIds) { + removeEmbeddingsByDocumentId(document_id); if (!removeChunksByDocumentId(q, document_id)) { qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << q.lastError(); return; @@ -875,8 +1050,9 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co return; } + emit removeFolderById(folder_id); + removeFolderFromWatch(path); - updateCollectionList(); } bool Database::addFolderToWatch(const QString &path) @@ -903,9 +1079,18 @@ void Database::retrieveFromDB(const QList &collections, const QString & #endif QSqlQuery q; - if (!selectChunk(q, collections, text, retrievalSize)) { - qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); - return; + if (m_embeddings->isLoaded()) { + std::vector result = m_embLLM->generateEmbeddings(text); + std::vector embeddings = m_embeddings->search(result, retrievalSize); + if (!selectChunk(q, collections, embeddings, retrievalSize)) { + qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); + return; + } + } else { + if (!selectChunk(q, collections, text, retrievalSize)) { + qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); + return; + } } while (q.next()) { @@ -986,6 +1171,7 @@ void Database::cleanDB() // Remove all chunks and documents that either don't exist or have become unreadable QSqlQuery query; + removeEmbeddingsByDocumentId(document_id); if (!removeChunksByDocumentId(query, document_id)) { qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); } @@ -994,7 +1180,6 @@ void Database::cleanDB() qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError(); } } - updateCollectionList(); } void Database::changeChunkSize(int chunkSize) @@ -1024,6 +1209,7 @@ void Database::changeChunkSize(int chunkSize) int document_id = q.value(0).toInt(); // Remove all chunks and documents to change the chunk size QSqlQuery query; + removeEmbeddingsByDocumentId(document_id); if (!removeChunksByDocumentId(query, document_id)) { qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); } diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index b0e90472..b217758b 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -8,10 +8,18 @@ #include #include +class Embeddings; +class EmbeddingLLM; struct DocumentInfo { int folder; QFileInfo doc; + int currentPage = 0; + size_t currentPosition = 0; + bool currentlyProcessing = false; + bool isPdf() const { + return doc.suffix() == QLatin1String("pdf"); + } }; struct ResultInfo { @@ -30,6 +38,11 @@ struct CollectionItem { QString folder_path; int folder_id = -1; bool installed = false; + bool indexing = false; + int currentDocsToIndex = 0; + int totalDocsToIndex = 0; + size_t currentBytesToIndex = 0; + size_t totalBytesToIndex = 0; }; Q_DECLARE_METATYPE(CollectionItem) @@ -38,6 +51,7 @@ class Database : public QObject Q_OBJECT public: Database(int chunkSize); + virtual ~Database(); public Q_SLOTS: void scanQueue(); @@ -50,6 +64,16 @@ public Q_SLOTS: Q_SIGNALS: void docsToScanChanged(); + void updateInstalled(int folder_id, bool b); + void updateIndexing(int folder_id, bool b); + void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex); + void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex); + void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes); + void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex); + void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex); + void addCollectionItem(const CollectionItem &item); + void removeFolderById(int folder_id); + void removeCollectionItem(const QString &collectionName); void collectionListUpdated(const QList &collectionList); private Q_SLOTS: @@ -58,21 +82,31 @@ private Q_SLOTS: bool addFolderToWatch(const QString &path); bool removeFolderFromWatch(const QString &path); void addCurrentFolders(); - void updateCollectionList(); private: void removeFolderInternal(const QString &collection, int folder_id, const QString &path); - void chunkStream(QTextStream &stream, int document_id, const QString &file, - const QString &title, const QString &author, const QString &subject, const QString &keywords, int page); - void handleDocumentErrorAndScheduleNext(const QString &errorMessage, + size_t chunkStream(QTextStream &stream, int document_id, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, int page, + int maxChunks = -1); + void removeEmbeddingsByDocumentId(int document_id); + void scheduleNext(int folder_id, size_t countForFolder); + void handleDocumentError(const QString &errorMessage, int document_id, const QString &document_path, const QSqlError &error); + size_t countOfDocuments(int folder_id) const; + size_t countOfBytes(int folder_id) const; + DocumentInfo dequeueDocument(); + void removeFolderFromDocumentQueue(int folder_id); + void enqueueDocumentInternal(const DocumentInfo &info, bool prepend = false); + void enqueueDocuments(int folder_id, const QVector &infos); private: int m_chunkSize; - QQueue m_docsToScan; + QMap> m_docsToScan; QList m_retrieve; QThread m_dbThread; QFileSystemWatcher *m_watcher; + EmbeddingLLM *m_embLLM; + Embeddings *m_embeddings; }; #endif // DATABASE_H diff --git a/gpt4all-chat/embeddings.cpp b/gpt4all-chat/embeddings.cpp new file mode 100644 index 00000000..58137809 --- /dev/null +++ b/gpt4all-chat/embeddings.cpp @@ -0,0 +1,190 @@ +#include "embeddings.h" + +#include +#include +#include + +#include "mysettings.h" +#include "hnswlib/hnswlib.h" + +#define EMBEDDINGS_VERSION 0 + +const int s_dim = 384; // Dimension of the elements +const int s_ef_construction = 200; // Controls index search speed/build speed tradeoff +const int s_M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + +Embeddings::Embeddings(QObject *parent) + : QObject(parent) + , m_space(nullptr) + , m_hnsw(nullptr) +{ + m_filePath = MySettings::globalInstance()->modelPath() + + QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION); +} + +Embeddings::~Embeddings() +{ + delete m_hnsw; + m_hnsw = nullptr; + delete m_space; + m_space = nullptr; +} + +bool Embeddings::load() +{ + QFileInfo info(m_filePath); + if (!info.exists()) { + qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath; + return false; + } + + if (!info.isReadable()) { + qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath; + return false; + } + + if (!info.isWritable()) { + qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath; + return false; + } + + try { + m_space = new hnswlib::InnerProductSpace(s_dim); + m_hnsw = new hnswlib::HierarchicalNSW(m_space, m_filePath.toStdString(), s_M, s_ef_construction); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not load hnswlib index:" << e.what(); + return false; + } + return isLoaded(); +} + +bool Embeddings::load(qint64 maxElements) +{ + try { + m_space = new hnswlib::InnerProductSpace(s_dim); + m_hnsw = new hnswlib::HierarchicalNSW(m_space, maxElements, s_M, s_ef_construction); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not create hnswlib index:" << e.what(); + return false; + } + return isLoaded(); +} + +bool Embeddings::save() +{ + if (!isLoaded()) + return false; + try { + m_hnsw->saveIndex(m_filePath.toStdString()); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not save hnswlib index:" << e.what(); + return false; + } + return true; +} + +bool Embeddings::isLoaded() const +{ + return m_hnsw != nullptr; +} + +bool Embeddings::fileExists() const +{ + QFileInfo info(m_filePath); + return info.exists(); +} + +bool Embeddings::resize(qint64 size) +{ + if (!isLoaded()) { + qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!"; + return false; + } + + Q_ASSERT(m_hnsw); + try { + m_hnsw->resizeIndex(size); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not resize hnswlib index:" << e.what(); + return false; + } + return true; +} + +bool Embeddings::add(const std::vector &embedding, qint64 label) +{ + if (!isLoaded()) { + bool success = load(500); + if (!success) { + qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!"; + return false; + } + } + + Q_ASSERT(m_hnsw); + if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) { + if (!resize(m_hnsw->max_elements_ + 500)) { + return false; + } + } + + try { + m_hnsw->addPoint(embedding.data(), label, false); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what(); + return false; + } + return true; +} + +void Embeddings::remove(qint64 label) +{ + if (!isLoaded()) { + qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!"; + return; + } + + Q_ASSERT(m_hnsw); + try { + m_hnsw->markDelete(label); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not add remove embedding from hnswlib index:" << e.what(); + } +} + +void Embeddings::clear() +{ + delete m_hnsw; + m_hnsw = nullptr; + delete m_space; + m_space = nullptr; +} + +std::vector Embeddings::search(const std::vector &embedding, int K) +{ + if (!isLoaded()) + return std::vector(); + + Q_ASSERT(m_hnsw); + std::priority_queue> result; + try { + result = m_hnsw->searchKnn(embedding.data(), K); + } catch (const std::exception &e) { + qWarning() << "ERROR: could not search hnswlib index:" << e.what(); + return std::vector(); + } + + std::vector neighbors; + neighbors.reserve(K); + + while(!result.empty()) { + neighbors.push_back(result.top().second); + result.pop(); + } + + // Reverse the neighbors, as the top of the priority queue is the farthest neighbor. + std::reverse(neighbors.begin(), neighbors.end()); + + return neighbors; +} diff --git a/gpt4all-chat/embeddings.h b/gpt4all-chat/embeddings.h new file mode 100644 index 00000000..88f87579 --- /dev/null +++ b/gpt4all-chat/embeddings.h @@ -0,0 +1,45 @@ +#ifndef EMBEDDINGS_H +#define EMBEDDINGS_H + +#include + +namespace hnswlib { + template + class HierarchicalNSW; + class InnerProductSpace; +} + +class Embeddings : public QObject +{ + Q_OBJECT +public: + Embeddings(QObject *parent); + virtual ~Embeddings(); + + bool load(); + bool load(qint64 maxElements); + bool save(); + bool isLoaded() const; + bool fileExists() const; + bool resize(qint64 size); + + // Adds the embedding and returns the label used + bool add(const std::vector &embedding, qint64 label); + + // Removes the embedding at label by marking it as unused + void remove(qint64 label); + + // Clears the embeddings + void clear(); + + // Performs a nearest neighbor search of the embeddings and returns a vector of labels + // for the K nearest neighbors of the given embedding + std::vector search(const std::vector &embedding, int K); + +private: + QString m_filePath; + hnswlib::InnerProductSpace *m_space; + hnswlib::HierarchicalNSW *m_hnsw; +}; + +#endif // EMBEDDINGS_H diff --git a/gpt4all-chat/embllm.cpp b/gpt4all-chat/embllm.cpp new file mode 100644 index 00000000..09bb20f8 --- /dev/null +++ b/gpt4all-chat/embllm.cpp @@ -0,0 +1,64 @@ +#include "embllm.h" +#include "modellist.h" + +EmbeddingLLM::EmbeddingLLM() + : QObject{nullptr} + , m_model{nullptr} +{ +} + +EmbeddingLLM::~EmbeddingLLM() +{ + delete m_model; + m_model = nullptr; +} + +bool EmbeddingLLM::loadModel() +{ + const EmbeddingModels *embeddingModels = ModelList::globalInstance()->embeddingModels(); + if (!embeddingModels->count()) + return false; + + const ModelInfo defaultModel = embeddingModels->defaultModelInfo(); + + QString filePath = defaultModel.dirpath + defaultModel.filename(); + QFileInfo fileInfo(filePath); + if (!fileInfo.exists()) { + qWarning() << "WARNING: Could not load sbert because file does not exist"; + m_model = nullptr; + return false; + } + + m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto"); + bool success = m_model->loadModel(filePath.toStdString()); + if (!success) { + qWarning() << "WARNING: Could not load sbert"; + delete m_model; + m_model = nullptr; + return false; + } + + if (m_model->implementation().modelType()[0] != 'B') { + qWarning() << "WARNING: Model type is not sbert"; + delete m_model; + m_model = nullptr; + return false; + } + return true; +} + +bool EmbeddingLLM::hasModel() const +{ + return m_model; +} + +std::vector EmbeddingLLM::generateEmbeddings(const QString &text) +{ + if (!hasModel() && !loadModel()) { + qWarning() << "WARNING: Could not load sbert model for embeddings"; + return std::vector(); + } + + Q_ASSERT(hasModel()); + return m_model->embedding(text.toStdString()); +} diff --git a/gpt4all-chat/embllm.h b/gpt4all-chat/embllm.h new file mode 100644 index 00000000..29148546 --- /dev/null +++ b/gpt4all-chat/embllm.h @@ -0,0 +1,27 @@ +#ifndef EMBLLM_H +#define EMBLLM_H + +#include +#include +#include "../gpt4all-backend/llmodel.h" + +class EmbeddingLLM : public QObject +{ + Q_OBJECT +public: + EmbeddingLLM(); + virtual ~EmbeddingLLM(); + + bool hasModel() const; + +public Q_SLOTS: + std::vector generateEmbeddings(const QString &text); + +private: + bool loadModel(); + +private: + LLModel *m_model = nullptr; +}; + +#endif // EMBLLM_H diff --git a/gpt4all-chat/hnswlib/bruteforce.h b/gpt4all-chat/hnswlib/bruteforce.h new file mode 100644 index 00000000..30b33ae9 --- /dev/null +++ b/gpt4all-chat/hnswlib/bruteforce.h @@ -0,0 +1,167 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace hnswlib { +template +class BruteforceSearch : public AlgorithmInterface { + public: + char *data_; + size_t maxelements_; + size_t cur_element_count; + size_t size_per_element_; + + size_t data_size_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::mutex index_lock; + + std::unordered_map dict_external_to_internal; + + + BruteforceSearch(SpaceInterface *s) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + } + + + BruteforceSearch(SpaceInterface *s, const std::string &location) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + loadIndex(location, s); + } + + + BruteforceSearch(SpaceInterface *s, size_t maxElements) { + maxelements_ = maxElements; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxElements * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + cur_element_count = 0; + } + + + ~BruteforceSearch() { + free(data_); + } + + + void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { + int idx; + { + std::unique_lock lock(index_lock); + + auto search = dict_external_to_internal.find(label); + if (search != dict_external_to_internal.end()) { + idx = search->second; + } else { + if (cur_element_count >= maxelements_) { + throw std::runtime_error("The number of elements exceeds the specified limit\n"); + } + idx = cur_element_count; + dict_external_to_internal[label] = idx; + cur_element_count++; + } + } + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + } + + + void removePoint(labeltype cur_external) { + size_t cur_c = dict_external_to_internal[cur_external]; + + dict_external_to_internal.erase(cur_external); + + labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + dict_external_to_internal[label] = cur_c; + memcpy(data_ + size_per_element_ * cur_c, + data_ + size_per_element_ * (cur_element_count-1), + data_size_+sizeof(labeltype)); + cur_element_count--; + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + assert(k <= cur_element_count); + std::priority_queue> topResults; + if (cur_element_count == 0) return topResults; + for (int i = 0; i < k; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + if ((!isIdAllowed) || (*isIdAllowed)(label)) { + topResults.push(std::pair(dist, label)); + } + } + dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; + for (int i = k; i < cur_element_count; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + if (dist <= lastdist) { + labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + if ((!isIdAllowed) || (*isIdAllowed)(label)) { + topResults.push(std::pair(dist, label)); + } + if (topResults.size() > k) + topResults.pop(); + + if (!topResults.empty()) { + lastdist = topResults.top().first; + } + } + } + return topResults; + } + + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, maxelements_); + writeBinaryPOD(output, size_per_element_); + writeBinaryPOD(output, cur_element_count); + + output.write(data_, maxelements_ * size_per_element_); + + output.close(); + } + + + void loadIndex(const std::string &location, SpaceInterface *s) { + std::ifstream input(location, std::ios::binary); + std::streampos position; + + readBinaryPOD(input, maxelements_); + readBinaryPOD(input, size_per_element_); + readBinaryPOD(input, cur_element_count); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxelements_ * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + + input.read(data_, maxelements_ * size_per_element_); + + input.close(); + } +}; +} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/hnswalg.h b/gpt4all-chat/hnswlib/hnswalg.h new file mode 100644 index 00000000..bef00170 --- /dev/null +++ b/gpt4all-chat/hnswlib/hnswalg.h @@ -0,0 +1,1271 @@ +#pragma once + +#include "visited_list_pool.h" +#include "hnswlib.h" +#include +#include +#include +#include +#include +#include + +namespace hnswlib { +typedef unsigned int tableint; +typedef unsigned int linklistsizeint; + +template +class HierarchicalNSW : public AlgorithmInterface { + public: + static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; + static const unsigned char DELETE_MARK = 0x01; + + size_t max_elements_{0}; + mutable std::atomic cur_element_count{0}; // current number of elements + size_t size_data_per_element_{0}; + size_t size_links_per_element_{0}; + mutable std::atomic num_deleted_{0}; // number of deleted elements + size_t M_{0}; + size_t maxM_{0}; + size_t maxM0_{0}; + size_t ef_construction_{0}; + size_t ef_{ 0 }; + + double mult_{0.0}, revSize_{0.0}; + int maxlevel_{0}; + + VisitedListPool *visited_list_pool_{nullptr}; + + // Locks operations with element by label value + mutable std::vector label_op_locks_; + + std::mutex global; + std::vector link_list_locks_; + + tableint enterpoint_node_{0}; + + size_t size_links_level0_{0}; + size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; + + char *data_level0_memory_{nullptr}; + char **linkLists_{nullptr}; + std::vector element_levels_; // keeps level of each element + + size_t data_size_{0}; + + DISTFUNC fstdistfunc_; + void *dist_func_param_{nullptr}; + + mutable std::mutex label_lookup_lock; // lock for label_lookup_ + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; + + mutable std::atomic metric_distance_computations{0}; + mutable std::atomic metric_hops{0}; + + bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions + + std::mutex deleted_elements_lock; // lock for deleted_elements + std::unordered_set deleted_elements; // contains internal ids of deleted elements + + + HierarchicalNSW(SpaceInterface *s) { + } + + + HierarchicalNSW( + SpaceInterface *s, + const std::string &location, + bool nmslib = false, + size_t max_elements = 0, + bool allow_replace_deleted = false) + : allow_replace_deleted_(allow_replace_deleted) { + loadIndex(location, s, max_elements); + } + + + HierarchicalNSW( + SpaceInterface *s, + size_t max_elements, + size_t M = 16, + size_t ef_construction = 200, + size_t random_seed = 100, + bool allow_replace_deleted = false) + : link_list_locks_(max_elements), + label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + element_levels_(max_elements), + allow_replace_deleted_(allow_replace_deleted) { + max_elements_ = max_elements; + num_deleted_ = 0; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction, M_); + ef_ = 10; + + level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + // initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + + ~HierarchicalNSW() { + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + } + + + struct CompareByFirst { + constexpr bool operator()(std::pair const& a, + std::pair const& b) const noexcept { + return a.first < b.first; + } + }; + + + void setEf(size_t ef) { + ef_ = ef; + } + + + inline std::mutex& getLabelOpMutex(labeltype label) const { + // calculate hash + size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + return label_op_locks_[lock_id]; + } + + + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } + + + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } + + + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } + + + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } + + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + size_t getMaxElements() { + return max_elements_; + } + + size_t getCurrentElementCount() { + return cur_element_count; + } + + size_t getDeletedCount() { + return num_deleted_; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); +// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lowerBound && + (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); +// bool cur_node_deleted = isMarkedDeleted(current_node_id); + if (collect_metrics) { + metric_hops++; + metric_distance_computations+=size; + } + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); +// if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0); //////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + visited_array[candidate_id] = visited_array_tag; + + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_, /////////// + _MM_HINT_T0); //////////////////////// +#endif + + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > ef) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } + + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_); + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + } + } + + for (std::pair curent_pair : return_list) { + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } + + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } + + + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + } + + + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + } + + + tableint mutuallyConnectNewElement( + const void *data_point, + tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, + bool isUpdate) { + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + tableint next_closest_entry_point = selectedNeighbors.back(); + + { + // lock only during the update + // because during the addition the lock for cur_c is already acquired + std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); + if (isUpdate) { + lock.lock(); + } + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur && !isUpdate) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur, selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx] && !isUpdate) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + } + } + + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + + bool is_cur_c_present = false; + if (isUpdate) { + for (size_t j = 0; j < sz_link_list_other; j++) { + if (data[j] == cur_c) { + is_cur_c_present = true; + break; + } + } + } + + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + } + } + + return next_closest_entry_point; + } + + + void resizeIndex(size_t new_max_elements) { + if (new_max_elements < cur_element_count) + throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); + + delete visited_list_pool_; + visited_list_pool_ = new VisitedListPool(1, new_max_elements); + + element_levels_.resize(new_max_elements); + + std::vector(new_max_elements).swap(link_list_locks_); + + // Reallocate base layer + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + linkLists_ = linkLists_new; + + max_elements_ = new_max_elements; + } + + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: + input.seekg(0, input.end); + std::streampos total_filesize = input.tellg(); + input.seekg(0, input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements = max_elements_i; + if (max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos = input.tellg(); + + /// Optional - check if index is ok: + input.seekg(cur_element_count * size_data_per_element_, input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if (input.tellg() < 0 || input.tellg() >= total_filesize) { + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize, input.cur); + } + } + + // throw exception if it either corrupted or old index + if (input.tellg() != total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + /// Optional check end + + input.seekg(pos, input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)] = i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + for (size_t i = 0; i < cur_element_count; i++) { + if (isMarkedDeleted(i)) { + num_deleted_ += 1; + if (allow_replace_deleted_) deleted_elements.insert(i); + } + } + + input.close(); + + return; + } + + + template + std::vector getDataByLabel(labeltype label) const { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + char* data_ptrv = getDataByInternalId(internalId); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + + + /* + * Marks an element with the given label deleted, does NOT really change the current graph. + */ + void markDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + markDeletedInternal(internalId); + } + + + /* + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ + void markDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); + } + } else { + throw std::runtime_error("The requested to delete element is already deleted"); + } + } + + + /* + * Removes the deleted mark of the node, does NOT really change the current graph. + * + * Note: the method is not safe to use when replacement of deleted elements is enabled, + * because elements marked as deleted can be completely removed by addPoint + */ + void unmarkDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + unmarkDeletedInternal(internalId); + } + + + + /* + * Remove the deleted mark of the node. + */ + void unmarkDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); + } + } else { + throw std::runtime_error("The requested to undelete element is not deleted"); + } + } + + + /* + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; + return *ll_cur & DELETE_MARK; + } + + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + + /* + * Adds point. Updates the point if it is already in the index. + * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point + */ + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { + throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); + } + + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + if (!replace_deleted) { + addPoint(data_point, label, -1); + return; + } + // check if there is vacant place + tableint internal_id_replaced; + std::unique_lock lock_deleted_elements(deleted_elements_lock); + bool is_vacant_place = !deleted_elements.empty(); + if (is_vacant_place) { + internal_id_replaced = *deleted_elements.begin(); + deleted_elements.erase(internal_id_replaced); + } + lock_deleted_elements.unlock(); + + // if there is no vacant place then add or update point + // else add point to vacant place + if (!is_vacant_place) { + addPoint(data_point, label, -1); + } else { + // we assume that there are no concurrent operations on deleted element + labeltype label_replaced = getExternalLabel(internal_id_replaced); + setExternalLabel(internal_id_replaced, label); + + std::unique_lock lock_table(label_lookup_lock); + label_lookup_.erase(label_replaced); + label_lookup_[label] = internal_id_replaced; + lock_table.unlock(); + + unmarkDeletedInternal(internal_id_replaced); + updatePoint(data_point, internal_id_replaced, 1.0); + } + } + + + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); + + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element then just return. + if (entryPointCopy == internalId && cur_element_count == 1) + return; + + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set sCand; + std::unordered_set sNeigh; + std::vector listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; + + sCand.insert(internalId); + + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); + + if (distribution(update_probability_generator_) > updateNeighborProbability) + continue; + + sNeigh.insert(elOneHop); + + std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); + } + } + + for (auto&& neigh : sNeigh) { + // if (neigh == internalId) + // continue; + + std::priority_queue, std::vector>, CompareByFirst> candidates; + size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 + size_t elementsToKeep = std::min(ef_construction_, size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; + + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); + candidates.emplace(distance, cand); + } + } + } + + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); + + { + std::unique_lock lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + size_t candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); + } + } + } + } + + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + } + + + void repairConnectionsForUpdate( + const void *dataPoint, + tableint entryPointInternalId, + tableint dataPointInternalId, + int dataPointLevel, + int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); +#endif + for (int i = 0; i < size; i++) { +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); +#endif + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); + + std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if (topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); + + topCandidates.pop(); + } + + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); + } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + } + } + } + + + std::vector getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll, size * sizeof(tableint)); + return result; + } + + + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + if (allow_replace_deleted_) { + if (isMarkedDeleted(existingInternalId)) { + throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); + } + } + lock_table.unlock(); + + if (isMarkedDeleted(existingInternalId)) { + unmarkDeletedInternal(existingInternalId); + } + updatePoint(data_point, existingInternalId, 1.0); + + return existingInternalId; + } + + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + } + + cur_c = cur_element_count; + cur_element_count++; + label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + if ((signed)currObj != -1) { + if (curlevel < maxlevelcopy) { + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj, level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); + } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } + + // Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (num_deleted_) { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } else { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + } + + + void checkIntegrity() { + int connections_checked = 0; + std::vector inbound_connections_num(cur_element_count, 0); + for (int i = 0; i < cur_element_count; i++) { + for (int l = 0; l <= element_levels_[i]; l++) { + linklistsizeint *ll_cur = get_linklist_at_level(i, l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set s; + for (int j = 0; j < size; j++) { + assert(data[j] > 0); + assert(data[j] < cur_element_count); + assert(data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; + } + assert(s.size() == size); + } + } + if (cur_element_count > 1) { + int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; + for (int i=0; i < cur_element_count; i++) { + assert(inbound_connections_num[i] > 0); + min1 = std::min(inbound_connections_num[i], min1); + max1 = std::max(inbound_connections_num[i], max1); + } + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + } +}; +} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/hnswlib.h b/gpt4all-chat/hnswlib/hnswlib.h new file mode 100644 index 00000000..fb7118fa --- /dev/null +++ b/gpt4all-chat/hnswlib/hnswlib.h @@ -0,0 +1,199 @@ +#pragma once +#ifndef NO_MANUAL_VECTORIZATION +#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#ifdef __AVX512F__ +#define USE_AVX512 +#endif +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { + __cpuidex(out, eax, ecx); +} +static __int64 xgetbv(unsigned int x) { + return _xgetbv(x); +} +#else +#include +#include +#include +static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { + __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); +} +static uint64_t xgetbv(unsigned int index) { + uint32_t eax, edx; + __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); + return ((uint64_t)edx << 32) | eax; +} +#endif + +#if defined(USE_AVX512) +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#define PORTABLE_ALIGN64 __attribute__((aligned(64))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#define PORTABLE_ALIGN64 __declspec(align(64)) +#endif + +// Adapted from https://github.com/Mysticial/FeatureDetector +#define _XCR_XFEATURE_ENABLED_MASK 0 + +static bool AVXCapable() { + int cpuInfo[4]; + + // CPU support + cpuid(cpuInfo, 0, 0); + int nIds = cpuInfo[0]; + + bool HW_AVX = false; + if (nIds >= 0x00000001) { + cpuid(cpuInfo, 0x00000001, 0); + HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0; + } + + // OS support + cpuid(cpuInfo, 1, 0); + + bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; + bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; + + bool avxSupported = false; + if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { + uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); + avxSupported = (xcrFeatureMask & 0x6) == 0x6; + } + return HW_AVX && avxSupported; +} + +static bool AVX512Capable() { + if (!AVXCapable()) return false; + + int cpuInfo[4]; + + // CPU support + cpuid(cpuInfo, 0, 0); + int nIds = cpuInfo[0]; + + bool HW_AVX512F = false; + if (nIds >= 0x00000007) { // AVX512 Foundation + cpuid(cpuInfo, 0x00000007, 0); + HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; + } + + // OS support + cpuid(cpuInfo, 1, 0); + + bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; + bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; + + bool avx512Supported = false; + if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { + uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); + avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6; + } + return HW_AVX512F && avx512Supported; +} +#endif + +#include +#include +#include +#include + +namespace hnswlib { +typedef size_t labeltype; + +// This can be extended to store state for filtering (e.g. from a std::set) +class BaseFilterFunctor { + public: + virtual bool operator()(hnswlib::labeltype id) { return true; } +}; + +template +class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } +}; + +template +static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); +} + +template +static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); +} + +template +using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + +template +class SpaceInterface { + public: + // virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + virtual ~SpaceInterface() {} +}; + +template +class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; + + virtual std::priority_queue> + searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; + + // Return k nearest neighbor in the order of closer fist + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; + + virtual void saveIndex(const std::string &location) = 0; + virtual ~AlgorithmInterface(){ + } +}; + +template +std::vector> +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + BaseFilterFunctor* isIdAllowed) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k, isIdAllowed); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); + } + } + + return result; +} +} // namespace hnswlib + +#include "space_l2.h" +#include "space_ip.h" +#include "bruteforce.h" +#include "hnswalg.h" diff --git a/gpt4all-chat/hnswlib/space_ip.h b/gpt4all-chat/hnswlib/space_ip.h new file mode 100644 index 00000000..2b1c359e --- /dev/null +++ b/gpt4all-chat/hnswlib/space_ip.h @@ -0,0 +1,375 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + +static float +InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; + } + return res; +} + +static float +InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { + return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); +} + +#if defined(USE_AVX) + +// Favor using AVX if available. +static float +InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return sum; +} + +static float +InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_SSE) + +static float +InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return sum; +} + +static float +InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); +} + +#endif + + +#if defined(USE_AVX512) + +static float +InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN64 TmpRes[16]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m512 sum512 = _mm512_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m512 v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); + } + + _mm512_store_ps(TmpRes, sum512); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; + + return sum; +} + +static float +InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_AVX) + +static float +InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } + + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + + return sum; +} + +static float +InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_SSE) + +static float +InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + + return sum; +} + +static float +InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); +} + +#endif + +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; +static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; +static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; +static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; + +static float +InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + return 1.0f - (res + res_tail); +} + +static float +InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + + return 1.0f - (res + res_tail); +} +#endif + +class InnerProductSpace : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #endif + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; + } + #endif + + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + +~InnerProductSpace() {} +}; + +} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/space_l2.h b/gpt4all-chat/hnswlib/space_l2.h new file mode 100644 index 00000000..834d19f7 --- /dev/null +++ b/gpt4all-chat/hnswlib/space_l2.h @@ -0,0 +1,324 @@ +#pragma once +#include "hnswlib.h" + +namespace hnswlib { + +static float +L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + float res = 0; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; + res += t * t; + } + return (res); +} + +#if defined(USE_AVX512) + +// Favor using AVX512 if available. +static float +L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN64 TmpRes[16]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m512 diff, v1, v2; + __m512 sum = _mm512_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + diff = _mm512_sub_ps(v1, v2); + // sum = _mm512_fmadd_ps(diff, diff, sum); + sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); + } + + _mm512_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + + TmpRes[13] + TmpRes[14] + TmpRes[15]; + + return (res); +} +#endif + +#if defined(USE_AVX) + +// Favor using AVX if available. +static float +L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + } + + _mm256_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; +} + +#endif + +#if defined(USE_SSE) + +static float +L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + + _mm_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +} +#endif + +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; + +static float +L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + return (res + res_tail); +} +#endif + + +#if defined(USE_SSE) +static float +L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + + size_t qty4 = qty >> 2; + + const float *pEnd1 = pVect1 + (qty4 << 2); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + } + _mm_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +} + +static float +L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + + return (res + res_tail); +} +#endif + +class L2Space : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + L2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; + else if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #elif defined(USE_AVX) + if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #endif + + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2Space() {} +}; + +static int +L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + int res = 0; + unsigned char *a = (unsigned char *) pVect1; + unsigned char *b = (unsigned char *) pVect2; + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + } + return (res); +} + +static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + int res = 0; + unsigned char* a = (unsigned char*)pVect1; + unsigned char* b = (unsigned char*)pVect2; + + for (size_t i = 0; i < qty; i++) { + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + } + return (res); +} + +class L2SpaceI : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + L2SpaceI(size_t dim) { + if (dim % 4 == 0) { + fstdistfunc_ = L2SqrI4x; + } else { + fstdistfunc_ = L2SqrI; + } + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + + void *get_dist_func_param() { + return &dim_; + } + + ~L2SpaceI() {} +}; +} // namespace hnswlib diff --git a/gpt4all-chat/hnswlib/visited_list_pool.h b/gpt4all-chat/hnswlib/visited_list_pool.h new file mode 100644 index 00000000..2e201ec4 --- /dev/null +++ b/gpt4all-chat/hnswlib/visited_list_pool.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include + +namespace hnswlib { +typedef unsigned short int vl_type; + +class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; + + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } + + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); + curV++; + } + } + + ~VisitedList() { delete[] mass; } +}; +/////////////////////////////////////////////////////////// +// +// Class for multi-threaded pool-management of VisitedLists +// +///////////////////////////////////////////////////////// + +class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; + + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } + + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { + std::unique_lock lock(poolguard); + if (pool.size() > 0) { + rez = pool.front(); + pool.pop_front(); + } else { + rez = new VisitedList(numelements); + } + } + rez->reset(); + return rez; + } + + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + } + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + } +}; +} // namespace hnswlib diff --git a/gpt4all-chat/localdocs.cpp b/gpt4all-chat/localdocs.cpp index 2bbb0cec..1f37a207 100644 --- a/gpt4all-chat/localdocs.cpp +++ b/gpt4all-chat/localdocs.cpp @@ -24,24 +24,50 @@ LocalDocs::LocalDocs() &Database::removeFolder, Qt::QueuedConnection); connect(this, &LocalDocs::requestChunkSizeChange, m_database, &Database::changeChunkSize, Qt::QueuedConnection); + + // Connections for modifying the model and keeping it updated with the database + connect(m_database, &Database::updateInstalled, + m_localDocsModel, &LocalDocsModel::updateInstalled, Qt::QueuedConnection); + connect(m_database, &Database::updateIndexing, + m_localDocsModel, &LocalDocsModel::updateIndexing, Qt::QueuedConnection); + connect(m_database, &Database::updateCurrentDocsToIndex, + m_localDocsModel, &LocalDocsModel::updateCurrentDocsToIndex, Qt::QueuedConnection); + connect(m_database, &Database::updateTotalDocsToIndex, + m_localDocsModel, &LocalDocsModel::updateTotalDocsToIndex, Qt::QueuedConnection); + connect(m_database, &Database::subtractCurrentBytesToIndex, + m_localDocsModel, &LocalDocsModel::subtractCurrentBytesToIndex, Qt::QueuedConnection); + connect(m_database, &Database::updateCurrentBytesToIndex, + m_localDocsModel, &LocalDocsModel::updateCurrentBytesToIndex, Qt::QueuedConnection); + connect(m_database, &Database::updateTotalBytesToIndex, + m_localDocsModel, &LocalDocsModel::updateTotalBytesToIndex, Qt::QueuedConnection); + connect(m_database, &Database::addCollectionItem, + m_localDocsModel, &LocalDocsModel::addCollectionItem, Qt::QueuedConnection); + connect(m_database, &Database::removeFolderById, + m_localDocsModel, &LocalDocsModel::removeFolderById, Qt::QueuedConnection); + connect(m_database, &Database::removeCollectionItem, + m_localDocsModel, &LocalDocsModel::removeCollectionItem, Qt::QueuedConnection); connect(m_database, &Database::collectionListUpdated, - m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection); + m_localDocsModel, &LocalDocsModel::collectionListUpdated, Qt::QueuedConnection); + + connect(qApp, &QCoreApplication::aboutToQuit, this, &LocalDocs::aboutToQuit); +} + +void LocalDocs::aboutToQuit() +{ + delete m_database; + m_database = nullptr; } void LocalDocs::addFolder(const QString &collection, const QString &path) { const QUrl url(path); const QString localPath = url.isLocalFile() ? url.toLocalFile() : path; - // Add a placeholder collection that is not installed yet - CollectionItem i; - i.collection = collection; - i.folder_path = localPath; - m_localDocsModel->addCollectionItem(i); emit requestAddFolder(collection, localPath); } void LocalDocs::removeFolder(const QString &collection, const QString &path) { + m_localDocsModel->removeCollectionPath(collection, path); emit requestRemoveFolder(collection, path); } diff --git a/gpt4all-chat/localdocs.h b/gpt4all-chat/localdocs.h index 07814fc4..dd08987a 100644 --- a/gpt4all-chat/localdocs.h +++ b/gpt4all-chat/localdocs.h @@ -23,6 +23,7 @@ public: public Q_SLOTS: void handleChunkSizeChanged(); + void aboutToQuit(); Q_SIGNALS: void requestAddFolder(const QString &collection, const QString &path); @@ -36,7 +37,6 @@ private: private: explicit LocalDocs(); - ~LocalDocs() {} friend class MyLocalDocs; }; diff --git a/gpt4all-chat/localdocsmodel.cpp b/gpt4all-chat/localdocsmodel.cpp index 8bd4fffd..e12e3773 100644 --- a/gpt4all-chat/localdocsmodel.cpp +++ b/gpt4all-chat/localdocsmodel.cpp @@ -1,5 +1,27 @@ #include "localdocsmodel.h" +#include "localdocs.h" + +LocalDocsCollectionsModel::LocalDocsCollectionsModel(QObject *parent) + : QSortFilterProxyModel(parent) +{ + setSourceModel(LocalDocs::globalInstance()->localDocsModel()); +} + +bool LocalDocsCollectionsModel::filterAcceptsRow(int sourceRow, + const QModelIndex &sourceParent) const +{ + QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); + const QString collection = sourceModel()->data(index, LocalDocsModel::CollectionRole).toString(); + return m_collections.contains(collection); +} + +void LocalDocsCollectionsModel::setCollections(const QList &collections) +{ + m_collections = collections; + invalidateFilter(); +} + LocalDocsModel::LocalDocsModel(QObject *parent) : QAbstractListModel(parent) { @@ -24,6 +46,16 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const return item.folder_path; case InstalledRole: return item.installed; + case IndexingRole: + return item.indexing; + case CurrentDocsToIndexRole: + return item.currentDocsToIndex; + case TotalDocsToIndexRole: + return item.totalDocsToIndex; + case CurrentBytesToIndexRole: + return quint64(item.currentBytesToIndex); + case TotalBytesToIndexRole: + return quint64(item.totalBytesToIndex); } return QVariant(); @@ -35,9 +67,98 @@ QHash LocalDocsModel::roleNames() const roles[CollectionRole] = "collection"; roles[FolderPathRole] = "folder_path"; roles[InstalledRole] = "installed"; + roles[IndexingRole] = "indexing"; + roles[CurrentDocsToIndexRole] = "currentDocsToIndex"; + roles[TotalDocsToIndexRole] = "totalDocsToIndex"; + roles[CurrentBytesToIndexRole] = "currentBytesToIndex"; + roles[TotalBytesToIndexRole] = "totalBytesToIndex"; return roles; } +void LocalDocsModel::updateInstalled(int folder_id, bool b) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].installed = b; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {InstalledRole}); + } +} + +void LocalDocsModel::updateIndexing(int folder_id, bool b) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].indexing = b; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {IndexingRole}); + } +} + +void LocalDocsModel::updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].currentDocsToIndex = currentDocsToIndex; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {CurrentDocsToIndexRole}); + } +} + +void LocalDocsModel::updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].totalDocsToIndex = totalDocsToIndex; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {TotalDocsToIndexRole}); + } +} + +void LocalDocsModel::subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].currentBytesToIndex -= subtractedBytes; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {CurrentBytesToIndexRole}); + } +} + +void LocalDocsModel::updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].currentBytesToIndex = currentBytesToIndex; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {CurrentBytesToIndexRole}); + } +} + +void LocalDocsModel::updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex) +{ + for (int i = 0; i < m_collectionList.size(); ++i) { + if (m_collectionList.at(i).folder_id != folder_id) + continue; + + m_collectionList[i].totalBytesToIndex = totalBytesToIndex; + emit collectionItemUpdated(i, m_collectionList[i]); + emit dataChanged(this->index(i), this->index(i), {TotalBytesToIndexRole}); + } +} + void LocalDocsModel::addCollectionItem(const CollectionItem &item) { beginInsertRows(QModelIndex(), m_collectionList.size(), m_collectionList.size()); @@ -45,7 +166,46 @@ void LocalDocsModel::addCollectionItem(const CollectionItem &item) endInsertRows(); } -void LocalDocsModel::handleCollectionListUpdated(const QList &collectionList) +void LocalDocsModel::removeFolderById(int folder_id) +{ + for (int i = 0; i < m_collectionList.size();) { + if (m_collectionList.at(i).folder_id == folder_id) { + beginRemoveRows(QModelIndex(), i, i); + m_collectionList.removeAt(i); + endRemoveRows(); + } else { + ++i; + } + } +} + +void LocalDocsModel::removeCollectionPath(const QString &name, const QString &path) +{ + for (int i = 0; i < m_collectionList.size();) { + if (m_collectionList.at(i).collection == name && m_collectionList.at(i).folder_path == path) { + beginRemoveRows(QModelIndex(), i, i); + m_collectionList.removeAt(i); + endRemoveRows(); + } else { + ++i; + } + } +} + +void LocalDocsModel::removeCollectionItem(const QString &collectionName) +{ + for (int i = 0; i < m_collectionList.size();) { + if (m_collectionList.at(i).collection == collectionName) { + beginRemoveRows(QModelIndex(), i, i); + m_collectionList.removeAt(i); + endRemoveRows(); + } else { + ++i; + } + } +} + +void LocalDocsModel::collectionListUpdated(const QList &collectionList) { beginResetModel(); m_collectionList = collectionList; diff --git a/gpt4all-chat/localdocsmodel.h b/gpt4all-chat/localdocsmodel.h index f3b30e6b..41d37bd2 100644 --- a/gpt4all-chat/localdocsmodel.h +++ b/gpt4all-chat/localdocsmodel.h @@ -4,6 +4,22 @@ #include #include "database.h" +class LocalDocsCollectionsModel : public QSortFilterProxyModel +{ + Q_OBJECT +public: + explicit LocalDocsCollectionsModel(QObject *parent); + +public Q_SLOTS: + void setCollections(const QList &collections); + +protected: + bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override; + +private: + QList m_collections; +}; + class LocalDocsModel : public QAbstractListModel { Q_OBJECT @@ -12,7 +28,13 @@ public: enum Roles { CollectionRole = Qt::UserRole + 1, FolderPathRole, - InstalledRole + InstalledRole, + IndexingRole, + EmbeddingRole, + CurrentDocsToIndexRole, + TotalDocsToIndexRole, + CurrentBytesToIndexRole, + TotalBytesToIndexRole }; explicit LocalDocsModel(QObject *parent = nullptr); @@ -20,9 +42,25 @@ public: QVariant data(const QModelIndex &index, int role) const override; QHash roleNames() const override; +Q_SIGNALS: + void collectionItemUpdated(int index, const CollectionItem& item); + public Q_SLOTS: + void updateInstalled(int folder_id, bool b); + void updateIndexing(int folder_id, bool b); + void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex); + void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex); + void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes); + void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex); + void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex); void addCollectionItem(const CollectionItem &item); - void handleCollectionListUpdated(const QList &collectionList); + void removeFolderById(int folder_id); + void removeCollectionPath(const QString &name, const QString &path); + void removeCollectionItem(const QString &collectionName); + void collectionListUpdated(const QList &collectionList); + +private: + void updateItem(int index, const CollectionItem& item); private: QList m_collectionList; diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml index fceee50a..fa3e8430 100644 --- a/gpt4all-chat/main.qml +++ b/gpt4all-chat/main.qml @@ -325,6 +325,10 @@ Window { anchors.centerIn: parent width: Math.min(1280, window.width - (window.width * .1)) height: window.height - (window.height * .1) + onDownloadClicked: { + downloadNewModels.showEmbeddingModels = true + downloadNewModels.open() + } } Button { @@ -652,6 +656,7 @@ Window { width: Math.min(600, 0.3 * window.width) height: window.height - y onDownloadClicked: { + downloadNewModels.showEmbeddingModels = false downloadNewModels.open() } onAboutClicked: { @@ -818,11 +823,11 @@ Window { color: theme.textAccent text: { switch (currentChat.responseState) { - case Chat.ResponseStopped: return "response stopped ..."; - case Chat.LocalDocsRetrieval: return "retrieving " + currentChat.collectionList.join(", ") + " ..."; - case Chat.LocalDocsProcessing: return "processing " + currentChat.collectionList.join(", ") + " ..."; - case Chat.PromptProcessing: return "processing ..." - case Chat.ResponseGeneration: return "generating response ..."; + case Chat.ResponseStopped: return qsTr("response stopped ..."); + case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: ") + currentChat.collectionList.join(", ") + " ..."; + case Chat.LocalDocsProcessing: return qsTr("searching localdocs: ") + currentChat.collectionList.join(", ") + " ..."; + case Chat.PromptProcessing: return qsTr("processing ...") + case Chat.ResponseGeneration: return qsTr("generating response ..."); default: return ""; // handle unexpected values } } diff --git a/gpt4all-chat/metadata/models2.json b/gpt4all-chat/metadata/models2.json index 757774bf..75b8006f 100644 --- a/gpt4all-chat/metadata/models2.json +++ b/gpt4all-chat/metadata/models2.json @@ -138,7 +138,7 @@ "type": "Replit", "systemPrompt": " ", "promptTemplate": "%1", - "description": "Trained on subset of the Stack
  • Code completion based
  • Licensed for commercial use
", + "description": "Trained on subset of the Stack
  • Code completion based
  • Licensed for commercial use
  • WARNING: Not available for chat GUI
", "url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-q4_0.gguf" }, { @@ -155,7 +155,7 @@ "type": "Starcoder", "systemPrompt": " ", "promptTemplate": "%1", - "description": "Trained on subset of the Stack
  • Code completion based
", + "description": "Trained on subset of the Stack
  • Code completion based
  • WARNING: Not available for chat GUI
", "url": "https://gpt4all.io/models/gguf/starcoder-q4_0.gguf" }, { @@ -172,7 +172,7 @@ "type": "LLaMA", "systemPrompt": " ", "promptTemplate": "%1", - "description": "Code completion model", + "description": "Trained on collection of Python and TypeScript
  • Code completion based
  • WARNING: Not available for chat GUI
  • ", "url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf" }, { @@ -184,11 +184,11 @@ "filesize": "45887744", "requires": "2.5.0", "ramrequired": "1", - "parameters": "1 million", + "parameters": "40 million", "quant": "f16", "type": "Bert", "systemPrompt": " ", - "description": "Sbert
    • For embeddings", + "description": "LocalDocs text embeddings model
      • Necessary for LocalDocs feature
      • Used for retrieval augmented generation (RAG)", "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf" }, { diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 63df2905..34e3d823 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -139,6 +139,50 @@ void ModelInfo::setSystemPrompt(const QString &p) m_systemPrompt = p; } +EmbeddingModels::EmbeddingModels(QObject *parent) + : QSortFilterProxyModel(parent) +{ + connect(this, &EmbeddingModels::rowsInserted, this, &EmbeddingModels::countChanged); + connect(this, &EmbeddingModels::rowsRemoved, this, &EmbeddingModels::countChanged); + connect(this, &EmbeddingModels::modelReset, this, &EmbeddingModels::countChanged); + connect(this, &EmbeddingModels::layoutChanged, this, &EmbeddingModels::countChanged); +} + +bool EmbeddingModels::filterAcceptsRow(int sourceRow, + const QModelIndex &sourceParent) const +{ + QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); + bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool(); + bool isEmbedding = sourceModel()->data(index, ModelList::FilenameRole).toString() == "all-MiniLM-L6-v2-f16.gguf"; + return isInstalled && isEmbedding; +} + +int EmbeddingModels::count() const +{ + return rowCount(); +} + +ModelInfo EmbeddingModels::defaultModelInfo() const +{ + if (!sourceModel()) + return ModelInfo(); + + const ModelList *sourceListModel = qobject_cast(sourceModel()); + if (!sourceListModel) + return ModelInfo(); + + const int rows = sourceListModel->rowCount(); + for (int i = 0; i < rows; ++i) { + QModelIndex sourceIndex = sourceListModel->index(i, 0); + if (filterAcceptsRow(i, sourceIndex.parent())) { + const QString id = sourceListModel->data(sourceIndex, ModelList::IdRole).toString(); + return sourceListModel->modelInfo(id); + } + } + + return ModelInfo(); +} + InstalledModels::InstalledModels(QObject *parent) : QSortFilterProxyModel(parent) { @@ -153,7 +197,8 @@ bool InstalledModels::filterAcceptsRow(int sourceRow, { QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool(); - return isInstalled; + bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool(); + return isInstalled && showInGUI; } int InstalledModels::count() const @@ -178,8 +223,7 @@ bool DownloadableModels::filterAcceptsRow(int sourceRow, bool withinLimit = sourceRow < (m_expanded ? sourceModel()->rowCount() : m_limit); QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); bool isDownloadable = !sourceModel()->data(index, ModelList::DescriptionRole).toString().isEmpty(); - bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool(); - return withinLimit && isDownloadable && showInGUI; + return withinLimit && isDownloadable; } int DownloadableModels::count() const @@ -210,10 +254,12 @@ ModelList *ModelList::globalInstance() ModelList::ModelList() : QAbstractListModel(nullptr) + , m_embeddingModels(new EmbeddingModels(this)) , m_installedModels(new InstalledModels(this)) , m_downloadableModels(new DownloadableModels(this)) , m_asyncModelRequestOngoing(false) { + m_embeddingModels->setSourceModel(this); m_installedModels->setSourceModel(this); m_downloadableModels->setSourceModel(this); m_watcher = new QFileSystemWatcher(this); @@ -280,6 +326,17 @@ const QList ModelList::userDefaultModelList() const return models; } +int ModelList::defaultEmbeddingModelIndex() const +{ + QMutexLocker locker(&m_mutex); + for (int i = 0; i < m_models.size(); ++i) { + const ModelInfo *info = m_models.at(i); + const bool isEmbedding = info->filename() == "all-MiniLM-L6-v2-f16.gguf"; + if (isEmbedding) return i; + } + return -1; +} + ModelInfo ModelList::defaultModelInfo() const { QMutexLocker locker(&m_mutex); diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index b17016c7..d1884266 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -120,6 +120,24 @@ private: }; Q_DECLARE_METATYPE(ModelInfo) +class EmbeddingModels : public QSortFilterProxyModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) +public: + explicit EmbeddingModels(QObject *parent); + int count() const; + + ModelInfo defaultModelInfo() const; + +Q_SIGNALS: + void countChanged(); + void defaultModelIndexChanged(); + +protected: + bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override; +}; + class InstalledModels : public QSortFilterProxyModel { Q_OBJECT @@ -165,6 +183,8 @@ class ModelList : public QAbstractListModel { Q_OBJECT Q_PROPERTY(int count READ count NOTIFY countChanged) + Q_PROPERTY(int defaultEmbeddingModelIndex READ defaultEmbeddingModelIndex NOTIFY defaultEmbeddingModelIndexChanged) + Q_PROPERTY(EmbeddingModels* embeddingModels READ embeddingModels NOTIFY embeddingModelsChanged) Q_PROPERTY(InstalledModels* installedModels READ installedModels NOTIFY installedModelsChanged) Q_PROPERTY(DownloadableModels* downloadableModels READ downloadableModels NOTIFY downloadableModelsChanged) Q_PROPERTY(QList userDefaultModelList READ userDefaultModelList NOTIFY userDefaultModelListChanged) @@ -273,6 +293,7 @@ public: Q_INVOKABLE QString clone(const ModelInfo &model); Q_INVOKABLE void remove(const ModelInfo &model); ModelInfo defaultModelInfo() const; + int defaultEmbeddingModelIndex() const; void addModel(const QString &id); void changeId(const QString &oldId, const QString &newId); @@ -280,6 +301,7 @@ public: const QList exportModelList() const; const QList userDefaultModelList() const; + EmbeddingModels *embeddingModels() const { return m_embeddingModels; } InstalledModels *installedModels() const { return m_installedModels; } DownloadableModels *downloadableModels() const { return m_downloadableModels; } @@ -300,10 +322,12 @@ public: Q_SIGNALS: void countChanged(); + void embeddingModelsChanged(); void installedModelsChanged(); void downloadableModelsChanged(); void userDefaultModelListChanged(); void asyncModelRequestOngoingChanged(); + void defaultEmbeddingModelIndexChanged(); private Q_SLOTS: void updateModelsFromJson(); @@ -326,6 +350,7 @@ private: private: mutable QMutex m_mutex; QNetworkAccessManager m_networkManager; + EmbeddingModels *m_embeddingModels; InstalledModels *m_installedModels; DownloadableModels *m_downloadableModels; QList m_models; diff --git a/gpt4all-chat/qml/CollectionsDialog.qml b/gpt4all-chat/qml/CollectionsDialog.qml index 095c20b0..f5a8b538 100644 --- a/gpt4all-chat/qml/CollectionsDialog.qml +++ b/gpt4all-chat/qml/CollectionsDialog.qml @@ -21,7 +21,7 @@ MyDialog { id: listLabel anchors.top: parent.top anchors.left: parent.left - text: "Available LocalDocs Collections:" + text: qsTr("Local Documents:") font.pixelSize: theme.fontSizeLarge color: theme.textColor } @@ -63,17 +63,60 @@ MyDialog { currentChat.removeCollection(collection) } } + ToolTip.text: qsTr("Warning: searching collections while indexing can return incomplete results") + ToolTip.visible: hovered && model.indexing } Text { id: collectionId anchors.verticalCenter: parent.verticalCenter anchors.left: checkBox.right anchors.margins: 20 + anchors.leftMargin: 10 text: collection font.pixelSize: theme.fontSizeLarge elide: Text.ElideRight color: theme.textColor } + ProgressBar { + id: itemProgressBar + anchors.verticalCenter: parent.verticalCenter + anchors.left: collectionId.right + anchors.right: parent.right + anchors.margins: 20 + anchors.leftMargin: 40 + visible: model.indexing + value: (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex + background: Rectangle { + implicitHeight: 45 + color: theme.backgroundDarkest + radius: 3 + } + contentItem: Item { + implicitHeight: 40 + + Rectangle { + width: itemProgressBar.visualPosition * parent.width + height: parent.height + radius: 2 + color: theme.assistantColor + } + } + Accessible.role: Accessible.ProgressBar + Accessible.name: qsTr("Indexing progressBar") + Accessible.description: qsTr("Shows the progress made in the indexing") + } + Label { + id: speedLabel + color: theme.textColor + visible: model.indexing + anchors.verticalCenter: itemProgressBar.verticalCenter + anchors.left: itemProgressBar.left + anchors.right: itemProgressBar.right + horizontalAlignment: Text.AlignHCenter + text: qsTr("indexing...") + elide: Text.ElideRight + font.pixelSize: theme.fontSizeLarge + } } } } diff --git a/gpt4all-chat/qml/LocalDocsSettings.qml b/gpt4all-chat/qml/LocalDocsSettings.qml index 04ce3c11..9c0ab8b1 100644 --- a/gpt4all-chat/qml/LocalDocsSettings.qml +++ b/gpt4all-chat/qml/LocalDocsSettings.qml @@ -5,6 +5,7 @@ import QtQuick.Controls.Basic import QtQuick.Layouts import QtQuick.Dialogs import localdocs +import modellist import mysettings import network @@ -13,7 +14,11 @@ MySettingsTab { MySettings.restoreLocalDocsDefaults(); } - title: qsTr("LocalDocs Plugin (BETA)") + property bool hasEmbeddingModel: ModelList.embeddingModels.count !== 0 + showAdvancedSettingsButton: hasEmbeddingModel + showRestoreDefaultsButton: hasEmbeddingModel + + title: qsTr("LocalDocs") contentItem: ColumnLayout { id: root spacing: 10 @@ -21,7 +26,30 @@ MySettingsTab { property alias collection: collection.text property alias folder_path: folderEdit.text + Label { + id: downloadLabel + Layout.fillWidth: true + Layout.maximumWidth: parent.width + wrapMode: Text.Wrap + visible: !hasEmbeddingModel + Layout.alignment: Qt.AlignLeft + text: qsTr("This feature requires the download of a text embedding model in order to index documents for later search. Please download the SBert text embedding model from the download dialog to proceed.") + font.pixelSize: theme.fontSizeLarger + } + + MyButton { + visible: !hasEmbeddingModel + Layout.topMargin: 20 + Layout.alignment: Qt.AlignLeft + text: qsTr("Download") + font.pixelSize: theme.fontSizeLarger + onClicked: { + downloadClicked() + } + } + Item { + visible: hasEmbeddingModel Layout.fillWidth: true height: row.height RowLayout { @@ -106,6 +134,7 @@ MySettingsTab { } ColumnLayout { + visible: hasEmbeddingModel spacing: 0 Repeater { model: LocalDocs.localDocsModel @@ -145,29 +174,25 @@ MySettingsTab { anchors.right: parent.right anchors.verticalCenter: parent.verticalCenter anchors.margins: 20 - width: Math.max(removeButton.width, busyIndicator.width) - height: Math.max(removeButton.height, busyIndicator.height) + width: removeButton.width + height:removeButton.height MyButton { id: removeButton anchors.centerIn: parent text: qsTr("Remove") - visible: !item.removing && installed + visible: !item.removing onClicked: { item.removing = true LocalDocs.removeFolder(collection, folder_path) } } - MyBusyIndicator { - id: busyIndicator - anchors.centerIn: parent - visible: item.removing || !installed - } } } } } RowLayout { + visible: hasEmbeddingModel Label { id: showReferencesLabel text: qsTr("Show references:") @@ -186,6 +211,7 @@ MySettingsTab { } Rectangle { + visible: hasEmbeddingModel Layout.fillWidth: true height: 1 color: theme.tabBorder @@ -196,6 +222,7 @@ MySettingsTab { columns: 3 rowSpacing: 10 columnSpacing: 10 + visible: hasEmbeddingModel Rectangle { Layout.row: 3 diff --git a/gpt4all-chat/qml/ModelDownloaderDialog.qml b/gpt4all-chat/qml/ModelDownloaderDialog.qml index d0da3cde..efb4f20c 100644 --- a/gpt4all-chat/qml/ModelDownloaderDialog.qml +++ b/gpt4all-chat/qml/ModelDownloaderDialog.qml @@ -16,9 +16,17 @@ MyDialog { modal: true closePolicy: Popup.CloseOnEscape | Popup.CloseOnPressOutside padding: 10 + property bool showEmbeddingModels: false onOpened: { Network.sendModelDownloaderDialog(); + + if (showEmbeddingModels) { + ModelList.downloadableModels.expanded = true + var targetModelIndex = ModelList.defaultEmbeddingModelIndex + console.log("targetModelIndex " + targetModelIndex) + modelListView.positionViewAtIndex(targetModelIndex, ListView.Contain); + } } PopupDialog { diff --git a/gpt4all-chat/qml/MySettingsTab.qml b/gpt4all-chat/qml/MySettingsTab.qml index 83d94c5a..7d47eb4f 100644 --- a/gpt4all-chat/qml/MySettingsTab.qml +++ b/gpt4all-chat/qml/MySettingsTab.qml @@ -9,8 +9,11 @@ Item { property string title: "" property Item contentItem: null property Item advancedSettings: null + property bool showAdvancedSettingsButton: true + property bool showRestoreDefaultsButton: true property var openFolderDialog signal restoreDefaultsClicked + signal downloadClicked onContentItemChanged: function() { if (contentItem) { @@ -64,6 +67,7 @@ Item { MyButton { id: restoreDefaultsButton anchors.left: parent.left + visible: showRestoreDefaultsButton width: implicitWidth text: qsTr("Restore Defaults") font.pixelSize: theme.fontSizeLarge @@ -77,7 +81,7 @@ Item { MyButton { id: advancedSettingsButton anchors.right: parent.right - visible: root.advancedSettings + visible: root.advancedSettings && showAdvancedSettingsButton width: implicitWidth text: !advancedInner.visible ? qsTr("Advanced Settings") : qsTr("Hide Advanced Settings") font.pixelSize: theme.fontSizeLarge diff --git a/gpt4all-chat/qml/SettingsDialog.qml b/gpt4all-chat/qml/SettingsDialog.qml index 65496f5d..dd68d3c8 100644 --- a/gpt4all-chat/qml/SettingsDialog.qml +++ b/gpt4all-chat/qml/SettingsDialog.qml @@ -19,6 +19,8 @@ MyDialog { Network.sendSettingsDialog(); } + signal downloadClicked + Item { Accessible.role: Accessible.Dialog Accessible.name: qsTr("Settings") @@ -28,13 +30,13 @@ MyDialog { ListModel { id: stacksModel ListElement { - title: "Models" + title: qsTr("Models") } ListElement { - title: "Application" + title: qsTr("Application") } ListElement { - title: "Plugins" + title: qsTr("LocalDocs") } } @@ -107,9 +109,16 @@ MyDialog { } MySettingsStack { - title: qsTr("LocalDocs Plugin (BETA) Settings") + title: qsTr("Local Document Collections") tabs: [ - Component { LocalDocsSettings { } } + Component { + LocalDocsSettings { + id: localDocsSettings + Component.onCompleted: { + localDocsSettings.downloadClicked.connect(settingsDialog.downloadClicked); + } + } + } ] } }