From d14b95f4bde46e671e51695aa3bb4e9354f2b29f Mon Sep 17 00:00:00 2001
From: Adam Treat <treat.adam@gmail.com>
Date: Mon, 22 Jan 2024 12:36:01 -0500
Subject: [PATCH] Add Nomic Embed model for atlas with localdocs.

---
 gpt4all-chat/chatllm.cpp                   |   2 +-
 gpt4all-chat/database.cpp                  |  61 ++++-
 gpt4all-chat/database.h                    |  13 +-
 gpt4all-chat/embeddings.cpp                |   3 +
 gpt4all-chat/embllm.cpp                    | 270 +++++++++++++++++++--
 gpt4all-chat/embllm.h                      |  68 +++++-
 gpt4all-chat/localdocs.cpp                 |   6 +
 gpt4all-chat/localdocsmodel.cpp            |  27 +++
 gpt4all-chat/localdocsmodel.h              |   9 +-
 gpt4all-chat/main.qml                      |   2 +-
 gpt4all-chat/metadata/models2.json         |   2 +-
 gpt4all-chat/modellist.cpp                 |  57 ++++-
 gpt4all-chat/modellist.h                   |  10 +-
 gpt4all-chat/qml/CollectionsDialog.qml     |  14 +-
 gpt4all-chat/qml/ModelDownloaderDialog.qml |  40 +--
 15 files changed, 506 insertions(+), 78 deletions(-)

diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index eac7b0c2..11cc2559 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -156,7 +156,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (isModelLoaded() && this->modelInfo() == modelInfo)
         return true;
 
-    bool isChatGPT = modelInfo.isChatGPT;
+    bool isChatGPT = modelInfo.isOnline; // right now only chatgpt is offered for online chat models...
     QString filePath = modelInfo.dirpath + modelInfo.filename();
     QFileInfo fileInfo(filePath);
 
diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp
index 8369db5b..3fdde3ac 100644
--- a/gpt4all-chat/database.cpp
+++ b/gpt4all-chat/database.cpp
@@ -558,7 +558,6 @@ void Database::scheduleNext(int folder_id, size_t countForFolder)
     if (!countForFolder) {
         emit updateIndexing(folder_id, false);
         emit updateInstalled(folder_id, true);
-        m_embeddings->save();
     }
     if (!m_docsToScan.isEmpty())
         QTimer::singleShot(0, this, &Database::scanQueue);
@@ -570,7 +569,7 @@ void Database::handleDocumentError(const QString &errorMessage,
     qWarning() << errorMessage << document_id << document_path << error.text();
 }
 
-size_t Database::chunkStream(QTextStream &stream, int document_id, const QString &file,
+size_t Database::chunkStream(QTextStream &stream, int folder_id, int document_id, const QString &file,
     const QString &title, const QString &author, const QString &subject, const QString &keywords, int page, int maxChunks)
 {
@@ -580,6 +579,8 @@ size_t Database::chunkStream(QTextStream &stream, int document_id, const QString
     QList<QString> words;
     int chunks = 0;
 
+    QVector<EmbeddingChunk> chunkList;
+
     while (!stream.atEnd()) {
         QString word;
         stream >> word;
@@ -605,9 +606,22 @@ size_t Database::chunkStream(QTextStream &stream, int document_id, const QString
                 qWarning() << "ERROR: Could not insert chunk into db" << q.lastError();
             }
 
+#if 1
+            EmbeddingChunk toEmbed;
+            toEmbed.folder_id = folder_id;
+            toEmbed.chunk_id = chunk_id;
+            toEmbed.chunk = chunk;
+            chunkList << toEmbed;
+            if (chunkList.count() == 100) {
+                m_embLLM->generateAsyncEmbeddings(chunkList);
+                emit updateTotalEmbeddingsToIndex(folder_id, 100);
+                chunkList.clear();
+            }
+#else
             const std::vector<float> result = m_embLLM->generateEmbeddings(chunk);
             if (!m_embeddings->add(result, chunk_id))
                 qWarning() << "ERROR: Cannot add point to embeddings index";
+#endif
 
             ++chunks;
 
@@ -615,12 +629,39 @@ size_t Database::chunkStream(QTextStream &stream, int document_id, const QString
             charCount = 0;
 
             if (maxChunks > 0 && chunks == maxChunks)
-                return stream.pos();
+                break;
         }
     }
+
+    if (!chunkList.isEmpty()) {
+        m_embLLM->generateAsyncEmbeddings(chunkList);
+        emit updateTotalEmbeddingsToIndex(folder_id, chunkList.count());
+        chunkList.clear();
+    }
+
     return stream.pos();
 }
 
+void Database::handleEmbeddingsGenerated(const QVector<EmbeddingResult> &embeddings)
+{
+    if (embeddings.isEmpty())
+        return;
+
+    int folder_id = 0;
+    for (auto e : embeddings) {
+        folder_id = e.folder_id;
+        if (!m_embeddings->add(e.embedding, e.chunk_id))
+            qWarning() << "ERROR: Cannot add point to embeddings index";
+    }
+    emit updateCurrentEmbeddingsToIndex(folder_id, embeddings.count());
+    m_embeddings->save();
+}
+
+void Database::handleErrorGenerated(int folder_id, const QString &error)
+{
+    emit updateError(folder_id, error);
+}
+
 void Database::removeEmbeddingsByDocumentId(int document_id)
 {
     QSqlQuery q;
@@ -792,14 +833,13 @@ void Database::scanQueue()
             const QPdfSelection selection = doc.getAllText(pageIndex);
             QString text = selection.text();
             QTextStream stream(&text);
-            chunkStream(stream, document_id, info.doc.fileName(),
+            chunkStream(stream, info.folder, document_id, info.doc.fileName(),
                 doc.metaData(QPdfDocument::MetaDataField::Title).toString(),
                 doc.metaData(QPdfDocument::MetaDataField::Author).toString(),
                 doc.metaData(QPdfDocument::MetaDataField::Subject).toString(),
                 doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(),
                 pageIndex + 1
             );
-            m_embeddings->save();
             emit subtractCurrentBytesToIndex(info.folder, bytesPerPage);
             if (info.currentPage < doc.pageCount()) {
                 info.currentPage += 1;
@@ -828,9 +868,8 @@ void Database::scanQueue()
 #if defined(DEBUG)
         qDebug() << "scanning byteIndex" << byteIndex << "of" << bytes << document_path;
 #endif
-        int pos = chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/,
-            QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, 5 /*maxChunks*/);
-        m_embeddings->save();
+        int pos = chunkStream(stream, info.folder, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/,
+            QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, 100 /*maxChunks*/);
         file.close();
         const size_t bytesChunked = pos - byteIndex;
         emit subtractCurrentBytesToIndex(info.folder, bytesChunked);
@@ -892,6 +931,8 @@ void Database::scanDocuments(int folder_id, const QString &folder_path)
 void Database::start()
 {
     connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Database::directoryChanged);
+    connect(m_embLLM, &EmbeddingLLM::embeddingsGenerated, this, &Database::handleEmbeddingsGenerated);
+    connect(m_embLLM, &EmbeddingLLM::errorGenerated, this, &Database::handleErrorGenerated);
     connect(this, &Database::docsToScanChanged, this, &Database::scanQueue);
     if (!QSqlDatabase::drivers().contains("QSQLITE")) {
         qWarning() << "ERROR: missing sqllite driver";
@@ -1081,6 +1122,10 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
     QSqlQuery q;
     if (m_embeddings->isLoaded()) {
         std::vector<float> result = m_embLLM->generateEmbeddings(text);
+        if (result.empty()) {
+            qDebug() << "ERROR: generating embeddings returned a null result";
+            return;
+        }
         std::vector<qint64> embeddings = m_embeddings->search(result, retrievalSize);
         if (!selectChunk(q, collections, embeddings, retrievalSize)) {
            qDebug() << "ERROR: selecting chunks:" << q.lastError().text();
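The hunks above change LocalDocs indexing from one synchronous embedding per chunk to batched asynchronous requests: chunkStream() queues EmbeddingChunk records and flushes them to EmbeddingLLM::generateAsyncEmbeddings() every 100 chunks plus one final partial batch, while handleEmbeddingsGenerated() later feeds results into the hnswlib index. A self-contained sketch of the batch/progress arithmetic (illustrative only, not part of the patch):

    // batch_accounting.cpp -- illustrative sketch, not part of the patch.
    // Mirrors chunkStream()'s dispatch of embedding requests in batches of 100.
    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int totalChunks = 250; // e.g. chunks produced for one folder
        const int batchSize = 100;   // hard-coded batch size in chunkStream()
        int total = 0;               // totalEmbeddingsToIndex, accumulated per dispatch
        for (int remaining = totalChunks; remaining > 0;) {
            const int n = std::min(batchSize, remaining);
            total += n; // emit updateTotalEmbeddingsToIndex(folder_id, n)
            remaining -= n;
            std::printf("dispatched batch of %d, total now %d\n", n, total);
        }
        // As each batch returns, handleEmbeddingsGenerated() advances
        // currentEmbeddingsToIndex by the batch size, so the UI shows
        // currentEmbeddingsToIndex / totalEmbeddingsToIndex (200/250 after
        // two of the three batches here have come back).
        return 0;
    }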
diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h
index b217758b..9d10fd00 100644
--- a/gpt4all-chat/database.h
+++ b/gpt4all-chat/database.h
@@ -8,8 +8,9 @@
 #include 
 #include 
 
+#include "embllm.h"
+
 class Embeddings;
-class EmbeddingLLM;
 
 struct DocumentInfo {
     int folder;
@@ -39,10 +40,13 @@ struct CollectionItem {
     int folder_id = -1;
     bool installed = false;
     bool indexing = false;
+    QString error;
     int currentDocsToIndex = 0;
     int totalDocsToIndex = 0;
     size_t currentBytesToIndex = 0;
     size_t totalBytesToIndex = 0;
+    size_t currentEmbeddingsToIndex = 0;
+    size_t totalEmbeddingsToIndex = 0;
 };
 
 Q_DECLARE_METATYPE(CollectionItem)
@@ -66,11 +70,14 @@ Q_SIGNALS:
     void docsToScanChanged();
     void updateInstalled(int folder_id, bool b);
     void updateIndexing(int folder_id, bool b);
+    void updateError(int folder_id, const QString &error);
     void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex);
     void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex);
     void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes);
     void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex);
     void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex);
+    void updateCurrentEmbeddingsToIndex(int folder_id, size_t currentEmbeddingsToIndex);
+    void updateTotalEmbeddingsToIndex(int folder_id, size_t totalEmbeddingsToIndex);
     void addCollectionItem(const CollectionItem &item);
     void removeFolderById(int folder_id);
     void removeCollectionItem(const QString &collectionName);
@@ -82,10 +89,12 @@ private Q_SLOTS:
     bool addFolderToWatch(const QString &path);
     bool removeFolderFromWatch(const QString &path);
     void addCurrentFolders();
+    void handleEmbeddingsGenerated(const QVector<EmbeddingResult> &embeddings);
+    void handleErrorGenerated(int folder_id, const QString &error);
 
 private:
     void removeFolderInternal(const QString &collection, int folder_id, const QString &path);
-    size_t chunkStream(QTextStream &stream, int document_id, const QString &file,
+    size_t chunkStream(QTextStream &stream, int folder_id, int document_id, const QString &file,
         const QString &title, const QString &author, const QString &subject, const QString &keywords,
         int page, int maxChunks = -1);
     void removeEmbeddingsByDocumentId(int document_id);
diff --git a/gpt4all-chat/embeddings.cpp b/gpt4all-chat/embeddings.cpp
index 6bfe256a..cd157033 100644
--- a/gpt4all-chat/embeddings.cpp
+++ b/gpt4all-chat/embeddings.cpp
@@ -129,6 +129,9 @@ bool Embeddings::add(const std::vector<float> &embedding, qint64 label)
         }
     }
 
+    if (embedding.empty())
+        return false;
+
     try {
         m_hnsw->addPoint(embedding.data(), label, false);
     } catch (const std::exception &e) {
diff --git a/gpt4all-chat/embllm.cpp b/gpt4all-chat/embllm.cpp
index bc7b2ef9..7a50fa01 100644
--- a/gpt4all-chat/embllm.cpp
+++ b/gpt4all-chat/embllm.cpp
@@ -1,19 +1,31 @@
 #include "embllm.h"
 #include "modellist.h"
 
-EmbeddingLLM::EmbeddingLLM()
-    : QObject{nullptr}
-    , m_model{nullptr}
+EmbeddingLLMWorker::EmbeddingLLMWorker()
+    : QObject(nullptr)
+    , m_networkManager(new QNetworkAccessManager(this))
+    , m_model(nullptr)
 {
+    moveToThread(&m_workerThread);
+    connect(this, &EmbeddingLLMWorker::finished, &m_workerThread, &QThread::quit, Qt::DirectConnection);
+    m_workerThread.setObjectName("embedding");
+    m_workerThread.start();
 }
 
-EmbeddingLLM::~EmbeddingLLM()
+EmbeddingLLMWorker::~EmbeddingLLMWorker()
 {
-    delete m_model;
-    m_model = nullptr;
+    if (m_model) {
+        delete m_model;
+        m_model = nullptr;
+    }
 }
 
-bool EmbeddingLLM::loadModel()
+void EmbeddingLLMWorker::wait()
+{
+    m_workerThread.wait();
+}
+
+bool EmbeddingLLMWorker::loadModel()
 {
     const EmbeddingModels *embeddingModels = ModelList::globalInstance()->embeddingModels();
     if (!embeddingModels->count())
         return false;
@@ -29,6 +41,16 @@
         return false;
     }
 
+    bool isNomic = fileInfo.fileName().startsWith("nomic");
+    if (isNomic) {
+        QFile file(filePath);
+        file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
+        QTextStream stream(&file);
+        m_nomicAPIKey = stream.readAll();
+        file.close();
+        return true;
+    }
+
     m_model = LLModel::Implementation::construct(filePath.toStdString());
     bool success = m_model->loadModel(filePath.toStdString(), 2048, 0);
     if (!success) {
@@ -47,18 +69,236 @@
     return true;
 }
 
-bool EmbeddingLLM::hasModel() const
+bool EmbeddingLLMWorker::hasModel() const
 {
-    return m_model;
+    return m_model || !m_nomicAPIKey.isEmpty();
+}
+
+bool EmbeddingLLMWorker::isNomic() const
+{
+    return !m_nomicAPIKey.isEmpty();
+}
+
+std::vector<float> EmbeddingLLMWorker::generateSyncEmbedding(const QString &text)
+{
+    if (!hasModel() && !loadModel()) {
+        qWarning() << "WARNING: Could not load model for embeddings";
+        return std::vector<float>();
+    }
+
+    if (isNomic()) {
+        qWarning() << "WARNING: Request to generate sync embeddings for non-local model invalid";
+        return std::vector<float>();
+    }
+
+    return m_model->embedding(text.toStdString());
+}
+
+void EmbeddingLLMWorker::requestSyncEmbedding(const QString &text)
+{
+    if (!hasModel() && !loadModel()) {
+        qWarning() << "WARNING: Could not load model for embeddings";
+        return;
+    }
+
+    if (!isNomic()) {
+        qWarning() << "WARNING: Request to generate sync embeddings for local model invalid";
+        return;
+    }
+
+    Q_ASSERT(hasModel());
+
+    QJsonObject root;
+    root.insert("model", "nomic-embed-text-v1");
+    QJsonArray texts;
+    texts.append(text);
+    root.insert("texts", texts);
+
+    QJsonDocument doc(root);
+
+    QUrl nomicUrl("https://api-atlas.nomic.ai/v1/embedding/text");
+    const QString authorization = QString("Bearer %1").arg(m_nomicAPIKey).trimmed();
+    QNetworkRequest request(nomicUrl);
+    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
+    request.setRawHeader("Authorization", authorization.toUtf8());
+    QNetworkReply *reply = m_networkManager->post(request, doc.toJson(QJsonDocument::Compact));
+    connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
+    connect(reply, &QNetworkReply::finished, this, &EmbeddingLLMWorker::handleFinished);
+}
+
+void EmbeddingLLMWorker::requestAsyncEmbedding(const QVector<EmbeddingChunk> &chunks)
+{
+    if (!hasModel() && !loadModel()) {
+        qWarning() << "WARNING: Could not load model for embeddings";
+        return;
+    }
+
+    if (m_nomicAPIKey.isEmpty()) {
+        QVector<EmbeddingResult> results;
+        results.reserve(chunks.size());
+        for (auto c : chunks) {
+            EmbeddingResult result;
+            result.folder_id = c.folder_id;
+            result.chunk_id = c.chunk_id;
+            result.embedding = m_model->embedding(c.chunk.toStdString());
+            results << result;
+        }
+        emit embeddingsGenerated(results);
+        return;
+    }
+
+    QJsonObject root;
+    root.insert("model", "nomic-embed-text-v1");
+    QJsonArray texts;
+
+    for (auto c : chunks)
+        texts.append(c.chunk);
+    root.insert("texts", texts);
+
+    QJsonDocument doc(root);
+
+    QUrl nomicUrl("https://api-atlas.nomic.ai/v1/embedding/text");
+    const QString authorization = QString("Bearer %1").arg(m_nomicAPIKey).trimmed();
+    QNetworkRequest request(nomicUrl);
+    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
+    request.setRawHeader("Authorization", authorization.toUtf8());
+    request.setAttribute(QNetworkRequest::User, QVariant::fromValue(chunks));
+
+    QNetworkReply *reply = m_networkManager->post(request, doc.toJson(QJsonDocument::Compact));
+    connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
+    connect(reply, &QNetworkReply::finished, this, &EmbeddingLLMWorker::handleFinished);
+}
+
+std::vector<float> jsonArrayToVector(const QJsonArray &jsonArray) {
+    std::vector<float> result;
+
+    for (const QJsonValue &innerValue : jsonArray) {
+        if (innerValue.isArray()) {
+            QJsonArray innerArray = innerValue.toArray();
+            result.reserve(result.size() + innerArray.size());
+            for (const QJsonValue &value : innerArray) {
+                result.push_back(static_cast<float>(value.toDouble()));
+            }
+        }
+    }
+
+    return result;
+}
+
+QVector<EmbeddingResult> jsonArrayToEmbeddingResults(const QVector<EmbeddingChunk>& chunks, const QJsonArray& embeddings) {
+    QVector<EmbeddingResult> results;
+
+    if (chunks.size() != embeddings.size()) {
+        qWarning() << "WARNING: Size of json array result does not match input!";
+        return results;
+    }
+
+    for (int i = 0; i < chunks.size(); ++i) {
+        const EmbeddingChunk& chunk = chunks.at(i);
+        const QJsonArray embeddingArray = embeddings.at(i).toArray();
+
+        std::vector<float> embeddingVector;
+        for (const QJsonValue& value : embeddingArray)
+            embeddingVector.push_back(static_cast<float>(value.toDouble()));
+
+        EmbeddingResult result;
+        result.folder_id = chunk.folder_id;
+        result.chunk_id = chunk.chunk_id;
+        result.embedding = std::move(embeddingVector);
+        results.push_back(std::move(result));
+    }
+
+    return results;
+}
+
+void EmbeddingLLMWorker::handleFinished()
+{
+    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
+    if (!reply)
+        return;
+
+    QVariant retrievedData = reply->request().attribute(QNetworkRequest::User);
+    QVector<EmbeddingChunk> chunks;
+    if (retrievedData.isValid() && retrievedData.canConvert<QVector<EmbeddingChunk>>())
+        chunks = retrievedData.value<QVector<EmbeddingChunk>>();
+
+    int folder_id = 0;
+    if (!chunks.isEmpty())
+        folder_id = chunks.first().folder_id;
+
+    QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
+    Q_ASSERT(response.isValid());
+    bool ok;
+    int code = response.toInt(&ok);
+    if (!ok || code != 200) {
+        QString errorDetails;
+        QString replyErrorString = reply->errorString().trimmed();
+        QByteArray replyContent = reply->readAll().trimmed();
+        errorDetails = QString("ERROR: Nomic Atlas responded with error code \"%1\"").arg(code);
+        if (!replyErrorString.isEmpty())
+            errorDetails += QString(". Error Details: \"%1\"").arg(replyErrorString);
+        if (!replyContent.isEmpty())
+            errorDetails += QString(". Response Content: \"%1\"").arg(QString::fromUtf8(replyContent));
+        qWarning() << errorDetails;
+        emit errorGenerated(folder_id, errorDetails);
+        return;
+    }
+
+    QByteArray jsonData = reply->readAll();
+
+    QJsonParseError err;
+    QJsonDocument document = QJsonDocument::fromJson(jsonData, &err);
+    if (err.error != QJsonParseError::NoError) {
+        qWarning() << "ERROR: Couldn't parse Nomic Atlas response: " << jsonData << err.errorString();
+        return;
+    }
+
+    const QJsonObject root = document.object();
+    const QJsonArray embeddings = root.value("embeddings").toArray();
+
+    if (!chunks.isEmpty()) {
+        emit embeddingsGenerated(jsonArrayToEmbeddingResults(chunks, embeddings));
+    } else {
+        m_lastResponse = jsonArrayToVector(embeddings);
+        emit finished();
+    }
+
+    reply->deleteLater();
+}
+
+EmbeddingLLM::EmbeddingLLM()
+    : QObject(nullptr)
+    , m_embeddingWorker(new EmbeddingLLMWorker)
+{
+    connect(this, &EmbeddingLLM::requestAsyncEmbedding, m_embeddingWorker,
+        &EmbeddingLLMWorker::requestAsyncEmbedding, Qt::QueuedConnection);
+    connect(m_embeddingWorker, &EmbeddingLLMWorker::embeddingsGenerated, this,
+        &EmbeddingLLM::embeddingsGenerated, Qt::QueuedConnection);
+    connect(m_embeddingWorker, &EmbeddingLLMWorker::errorGenerated, this,
+        &EmbeddingLLM::errorGenerated, Qt::QueuedConnection);
+}
+
+EmbeddingLLM::~EmbeddingLLM()
+{
+    delete m_embeddingWorker;
+    m_embeddingWorker = nullptr;
 }
 
 std::vector<float> EmbeddingLLM::generateEmbeddings(const QString &text)
 {
-    if (!hasModel() && !loadModel()) {
-        qWarning() << "WARNING: Could not load sbert model for embeddings";
-        return std::vector<float>();
+    if (!m_embeddingWorker->isNomic()) {
+        return m_embeddingWorker->generateSyncEmbedding(text);
+    } else {
+        EmbeddingLLMWorker worker;
+        connect(this, &EmbeddingLLM::requestSyncEmbedding, &worker,
+            &EmbeddingLLMWorker::requestSyncEmbedding, Qt::QueuedConnection);
+        emit requestSyncEmbedding(text);
+        worker.wait();
+        return worker.lastResponse();
     }
-
-    Q_ASSERT(hasModel());
-    return m_model->embedding(text.toStdString());
+}
+
+void EmbeddingLLM::generateAsyncEmbeddings(const QVector<EmbeddingChunk> &chunks)
+{
+    emit requestAsyncEmbedding(chunks);
 }
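For reference, the wire format the worker above assumes for Nomic Atlas, reconstructed from requestAsyncEmbedding() and handleFinished(); this is a sketch of what the code sends and expects, not official API documentation:

    POST https://api-atlas.nomic.ai/v1/embedding/text
    Authorization: Bearer <contents of nomic-embed-text-v1.txt>
    Content-Type: application/json

    {"model": "nomic-embed-text-v1", "texts": ["first chunk", "second chunk"]}

    On HTTP 200 the body is parsed as one float array per input text, in input
    order (jsonArrayToEmbeddingResults pairs chunks and arrays by index):

    {"embeddings": [[0.018, -0.043, ...], [0.007, 0.112, ...]]}

    Any other status code is reported through errorGenerated(folder_id, details).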
diff --git a/gpt4all-chat/embllm.h b/gpt4all-chat/embllm.h
index 29148546..cde30c60 100644
--- a/gpt4all-chat/embllm.h
+++ b/gpt4all-chat/embllm.h
@@ -3,8 +3,61 @@
 #include <QObject>
 #include <vector>
+#include <QNetworkAccessManager>
+#include <QThread>
+
 #include "../gpt4all-backend/llmodel.h"
 
+struct EmbeddingChunk {
+    int folder_id;
+    int chunk_id;
+    QString chunk;
+};
+
+Q_DECLARE_METATYPE(EmbeddingChunk)
+
+struct EmbeddingResult {
+    int folder_id;
+    int chunk_id;
+    std::vector<float> embedding;
+};
+
+class EmbeddingLLMWorker : public QObject {
+    Q_OBJECT
+public:
+    EmbeddingLLMWorker();
+    virtual ~EmbeddingLLMWorker();
+
+    void wait();
+
+    std::vector<float> lastResponse() const { return m_lastResponse; }
+
+    bool loadModel();
+    bool hasModel() const;
+    bool isNomic() const;
+
+    std::vector<float> generateSyncEmbedding(const QString &text);
+
+public Q_SLOTS:
+    void requestSyncEmbedding(const QString &text);
+    void requestAsyncEmbedding(const QVector<EmbeddingChunk> &chunks);
+
+Q_SIGNALS:
+    void embeddingsGenerated(const QVector<EmbeddingResult> &embeddings);
+    void errorGenerated(int folder_id, const QString &error);
+    void finished();
+
+private Q_SLOTS:
+    void handleFinished();
+
+private:
+    QString m_nomicAPIKey;
+    QNetworkAccessManager *m_networkManager;
+    std::vector<float> m_lastResponse;
+    LLModel *m_model = nullptr;
+    QThread m_workerThread;
+};
+
 class EmbeddingLLM : public QObject
 {
     Q_OBJECT
@@ -12,16 +65,21 @@ public:
     EmbeddingLLM();
     virtual ~EmbeddingLLM();
 
+    bool loadModel();
     bool hasModel() const;
 
 public Q_SLOTS:
-    std::vector<float> generateEmbeddings(const QString &text);
+    std::vector<float> generateEmbeddings(const QString &text); // synchronous
+    void generateAsyncEmbeddings(const QVector<EmbeddingChunk> &chunks);
+
+Q_SIGNALS:
+    void requestSyncEmbedding(const QString &text);
+    void requestAsyncEmbedding(const QVector<EmbeddingChunk> &chunks);
+    void embeddingsGenerated(const QVector<EmbeddingResult> &embeddings);
+    void errorGenerated(int folder_id, const QString &error);
 
 private:
-    bool loadModel();
-
-private:
-    LLModel *m_model = nullptr;
+    EmbeddingLLMWorker *m_embeddingWorker;
 };
 
 #endif // EMBLLM_H
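A minimal caller-side sketch of the new public API declared above (hypothetical usage; in the patch itself, Database::start() performs the equivalent wiring):

    // usage_sketch.cpp -- hypothetical illustration, not part of the patch.
    #include <QObject>
    #include <QDebug>
    #include "embllm.h"

    void wireEmbeddings(QObject *owner, EmbeddingLLM *embLLM)
    {
        // Results arrive on the caller's thread via queued connections
        // from the internal worker thread.
        QObject::connect(embLLM, &EmbeddingLLM::embeddingsGenerated, owner,
            [](const QVector<EmbeddingResult> &results) {
                for (const EmbeddingResult &r : results)
                    qDebug() << "folder" << r.folder_id << "chunk" << r.chunk_id
                             << "dim" << r.embedding.size();
            });
        QObject::connect(embLLM, &EmbeddingLLM::errorGenerated, owner,
            [](int folder_id, const QString &error) {
                qWarning() << "embedding error for folder" << folder_id << error;
            });

        EmbeddingChunk c;
        c.folder_id = 1;  // hypothetical ids for illustration
        c.chunk_id = 42;
        c.chunk = QStringLiteral("some text to embed");
        QVector<EmbeddingChunk> batch{c};
        embLLM->generateAsyncEmbeddings(batch); // returns immediately
    }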
diff --git a/gpt4all-chat/localdocs.cpp b/gpt4all-chat/localdocs.cpp
index 1f37a207..3baaa983 100644
--- a/gpt4all-chat/localdocs.cpp
+++ b/gpt4all-chat/localdocs.cpp
@@ -30,6 +30,8 @@ LocalDocs::LocalDocs()
         m_localDocsModel, &LocalDocsModel::updateInstalled, Qt::QueuedConnection);
     connect(m_database, &Database::updateIndexing,
         m_localDocsModel, &LocalDocsModel::updateIndexing, Qt::QueuedConnection);
+    connect(m_database, &Database::updateError,
+        m_localDocsModel, &LocalDocsModel::updateError, Qt::QueuedConnection);
     connect(m_database, &Database::updateCurrentDocsToIndex,
         m_localDocsModel, &LocalDocsModel::updateCurrentDocsToIndex, Qt::QueuedConnection);
     connect(m_database, &Database::updateTotalDocsToIndex,
@@ -40,6 +42,10 @@ LocalDocs::LocalDocs()
         m_localDocsModel, &LocalDocsModel::updateCurrentBytesToIndex, Qt::QueuedConnection);
     connect(m_database, &Database::updateTotalBytesToIndex,
         m_localDocsModel, &LocalDocsModel::updateTotalBytesToIndex, Qt::QueuedConnection);
+    connect(m_database, &Database::updateCurrentEmbeddingsToIndex,
+        m_localDocsModel, &LocalDocsModel::updateCurrentEmbeddingsToIndex, Qt::QueuedConnection);
+    connect(m_database, &Database::updateTotalEmbeddingsToIndex,
+        m_localDocsModel, &LocalDocsModel::updateTotalEmbeddingsToIndex, Qt::QueuedConnection);
     connect(m_database, &Database::addCollectionItem,
         m_localDocsModel, &LocalDocsModel::addCollectionItem, Qt::QueuedConnection);
     connect(m_database, &Database::removeFolderById,
diff --git a/gpt4all-chat/localdocsmodel.cpp b/gpt4all-chat/localdocsmodel.cpp
index b2bea0fa..56730169 100644
--- a/gpt4all-chat/localdocsmodel.cpp
+++ b/gpt4all-chat/localdocsmodel.cpp
@@ -48,6 +48,8 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const
         return item.installed;
     case IndexingRole:
         return item.indexing;
+    case ErrorRole:
+        return item.error;
     case CurrentDocsToIndexRole:
         return item.currentDocsToIndex;
     case TotalDocsToIndexRole:
@@ -56,6 +58,10 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const
         return quint64(item.currentBytesToIndex);
     case TotalBytesToIndexRole:
         return quint64(item.totalBytesToIndex);
+    case CurrentEmbeddingsToIndexRole:
+        return quint64(item.currentEmbeddingsToIndex);
+    case TotalEmbeddingsToIndexRole:
+        return quint64(item.totalEmbeddingsToIndex);
     }
 
     return QVariant();
@@ -68,10 +74,13 @@ QHash<int, QByteArray> LocalDocsModel::roleNames() const
     roles[FolderPathRole] = "folder_path";
     roles[InstalledRole] = "installed";
     roles[IndexingRole] = "indexing";
+    roles[ErrorRole] = "error";
     roles[CurrentDocsToIndexRole] = "currentDocsToIndex";
     roles[TotalDocsToIndexRole] = "totalDocsToIndex";
     roles[CurrentBytesToIndexRole] = "currentBytesToIndex";
     roles[TotalBytesToIndexRole] = "totalBytesToIndex";
+    roles[CurrentEmbeddingsToIndexRole] = "currentEmbeddingsToIndex";
+    roles[TotalEmbeddingsToIndexRole] = "totalEmbeddingsToIndex";
 
     return roles;
 }
@@ -101,6 +110,12 @@ void LocalDocsModel::updateIndexing(int folder_id, bool b)
         [](CollectionItem& item, bool val) { item.indexing = val; }, {IndexingRole});
 }
 
+void LocalDocsModel::updateError(int folder_id, const QString &error)
+{
+    updateField(folder_id, error,
+        [](CollectionItem& item, QString val) { item.error = val; }, {ErrorRole});
+}
+
 void LocalDocsModel::updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex)
 {
     updateField(folder_id, currentDocsToIndex,
@@ -131,6 +146,18 @@ void LocalDocsModel::updateTotalBytesToIndex(int folder_id, size_t totalBytesToI
         [](CollectionItem& item, size_t val) { item.totalBytesToIndex = val; }, {TotalBytesToIndexRole});
 }
 
+void LocalDocsModel::updateCurrentEmbeddingsToIndex(int folder_id, size_t currentEmbeddingsToIndex)
+{
+    updateField(folder_id, currentEmbeddingsToIndex,
+        [](CollectionItem& item, size_t val) { item.currentEmbeddingsToIndex += val; }, {CurrentEmbeddingsToIndexRole});
+}
+
+void LocalDocsModel::updateTotalEmbeddingsToIndex(int folder_id, size_t totalEmbeddingsToIndex)
+{
+    updateField(folder_id, totalEmbeddingsToIndex,
+        [](CollectionItem& item, size_t val) { item.totalEmbeddingsToIndex += val; }, {TotalEmbeddingsToIndexRole});
+}
+
 void LocalDocsModel::addCollectionItem(const CollectionItem &item)
 {
     beginInsertRows(QModelIndex(), m_collectionList.size(), m_collectionList.size());
diff --git a/gpt4all-chat/localdocsmodel.h b/gpt4all-chat/localdocsmodel.h
index 47997143..4db836a3 100644
--- a/gpt4all-chat/localdocsmodel.h
+++ b/gpt4all-chat/localdocsmodel.h
@@ -30,11 +30,13 @@ public:
         FolderPathRole,
         InstalledRole,
         IndexingRole,
-        EmbeddingRole,
+        ErrorRole,
         CurrentDocsToIndexRole,
         TotalDocsToIndexRole,
         CurrentBytesToIndexRole,
-        TotalBytesToIndexRole
+        TotalBytesToIndexRole,
+        CurrentEmbeddingsToIndexRole,
+        TotalEmbeddingsToIndexRole
     };
 
     explicit LocalDocsModel(QObject *parent = nullptr);
@@ -45,11 +47,14 @@ public Q_SLOTS:
     void updateInstalled(int folder_id, bool b);
     void updateIndexing(int folder_id, bool b);
+    void updateError(int folder_id, const QString &error);
     void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex);
     void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex);
     void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes);
     void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex);
     void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex);
+    void updateCurrentEmbeddingsToIndex(int folder_id, size_t currentEmbeddingsToIndex);
+    void updateTotalEmbeddingsToIndex(int folder_id, size_t totalEmbeddingsToIndex);
     void addCollectionItem(const CollectionItem &item);
     void removeFolderById(int folder_id);
     void removeCollectionPath(const QString &name, const QString &path);
diff --git a/gpt4all-chat/main.qml b/gpt4all-chat/main.qml
index 6e317062..ea79b1e0 100644
--- a/gpt4all-chat/main.qml
+++ b/gpt4all-chat/main.qml
@@ -1129,7 +1129,7 @@ Window {
         }
 
         Image {
-            visible: currentChat.isServer || currentChat.modelInfo.isChatGPT
+            visible: currentChat.isServer || currentChat.modelInfo.isOnline
             anchors.fill: parent
             sourceSize.width: 1024
             sourceSize.height: 1024
diff --git a/gpt4all-chat/metadata/models2.json b/gpt4all-chat/metadata/models2.json
index 5105b3eb..98bc4440 100644
--- a/gpt4all-chat/metadata/models2.json
+++ b/gpt4all-chat/metadata/models2.json
@@ -218,7 +218,7 @@
         "quant": "f16",
         "type": "Bert",
         "systemPrompt": " ",
-        "description": "<strong>LocalDocs text embeddings model</strong><br>• Necessary for LocalDocs feature<br>• Used for retrieval augmented generation (RAG)",
+        "description": "<strong>LocalDocs text embeddings model</strong><br>• For use with LocalDocs feature<br>• Used for retrieval augmented generation (RAG)",
         "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
     },
     {
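Net effect of the one-line models2.json change above, with the entry's visible fields assembled in one place (the description markup is reconstructed from the rendered text; remaining fields of the record are elided):

    {
        "quant": "f16",
        "type": "Bert",
        "systemPrompt": " ",
        "description": "<strong>LocalDocs text embeddings model</strong><br>• For use with LocalDocs feature<br>• Used for retrieval augmented generation (RAG)",
        "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
    }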
filename.startsWith("chatgpt-"))) { + || (filename.endsWith(".txt") && (filename.startsWith("chatgpt-") || filename.startsWith("nomic-")))) { QString filePath = it.filePath(); QFileInfo info(filePath); @@ -934,7 +936,8 @@ void ModelList::updateModelsFromDirectory() for (const QString &id : modelsById) { updateData(id, FilenameRole, filename); - updateData(id, ChatGPTRole, filename.startsWith("chatgpt-")); + // FIXME: WE should change this to use a consistent filename for online models + updateData(id, OnlineRole, filename.startsWith("chatgpt-") || filename.startsWith("nomic-")); updateData(id, DirpathRole, info.dir().absolutePath() + "/"); updateData(id, FilesizeRole, toFileSize(info.size())); } @@ -1195,7 +1198,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) updateData(id, ModelList::NameRole, modelName); updateData(id, ModelList::FilenameRole, modelFilename); updateData(id, ModelList::FilesizeRole, "minimal"); - updateData(id, ModelList::ChatGPTRole, true); + updateData(id, ModelList::OnlineRole, true); updateData(id, ModelList::DescriptionRole, tr("OpenAI's ChatGPT model GPT-3.5 Turbo
      ") + chatGPTDesc); updateData(id, ModelList::RequiresVersionRole, "2.4.2"); @@ -1219,7 +1222,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) updateData(id, ModelList::NameRole, modelName); updateData(id, ModelList::FilenameRole, modelFilename); updateData(id, ModelList::FilesizeRole, "minimal"); - updateData(id, ModelList::ChatGPTRole, true); + updateData(id, ModelList::OnlineRole, true); updateData(id, ModelList::DescriptionRole, tr("OpenAI's ChatGPT model GPT-4
      ") + chatGPTDesc + chatGPT4Warn); updateData(id, ModelList::RequiresVersionRole, "2.4.2"); @@ -1229,6 +1232,34 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) updateData(id, ModelList::QuantRole, "NA"); updateData(id, ModelList::TypeRole, "GPT"); } + + { + const QString nomicEmbedDesc = tr("
      • For use with LocalDocs feature
      • " + "
      • Used for retrieval augmented generation (RAG)
      • " + "
      • Requires personal Nomic API key.
      • " + "
      • WARNING: Will send your localdocs to Nomic Atlas!
      • " + "
      • You can apply for an API key with Nomic Atlas.
      • "); + const QString modelName = "Nomic Embed"; + const QString id = modelName; + const QString modelFilename = "nomic-embed-text-v1.txt"; + if (contains(modelFilename)) + changeId(modelFilename, id); + if (!contains(id)) + addModel(id); + updateData(id, ModelList::NameRole, modelName); + updateData(id, ModelList::FilenameRole, modelFilename); + updateData(id, ModelList::FilesizeRole, "minimal"); + updateData(id, ModelList::OnlineRole, true); + updateData(id, ModelList::DisableGUIRole, true); + updateData(id, ModelList::DescriptionRole, + tr("LocalDocs Nomic Atlas Embed
        ") + nomicEmbedDesc); + updateData(id, ModelList::RequiresVersionRole, "2.6.3"); + updateData(id, ModelList::OrderRole, "na"); + updateData(id, ModelList::RamrequiredRole, 0); + updateData(id, ModelList::ParametersRole, "?"); + updateData(id, ModelList::QuantRole, "NA"); + updateData(id, ModelList::TypeRole, "Bert"); + } } void ModelList::updateModelsFromSettings() diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 475d6a40..8ffd8163 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -16,7 +16,7 @@ struct ModelInfo { Q_PROPERTY(bool installed MEMBER installed) Q_PROPERTY(bool isDefault MEMBER isDefault) Q_PROPERTY(bool disableGUI MEMBER disableGUI) - Q_PROPERTY(bool isChatGPT MEMBER isChatGPT) + Q_PROPERTY(bool isOnline MEMBER isOnline) Q_PROPERTY(QString description MEMBER description) Q_PROPERTY(QString requiresVersion MEMBER requiresVersion) Q_PROPERTY(QString deprecatedVersion MEMBER deprecatedVersion) @@ -64,7 +64,7 @@ public: bool calcHash = false; bool installed = false; bool isDefault = false; - bool isChatGPT = false; + bool isOnline = false; bool disableGUI = false; QString description; QString requiresVersion; @@ -217,7 +217,7 @@ public: CalcHashRole, InstalledRole, DefaultRole, - ChatGPTRole, + OnlineRole, DisableGUIRole, DescriptionRole, RequiresVersionRole, @@ -261,7 +261,7 @@ public: roles[CalcHashRole] = "calcHash"; roles[InstalledRole] = "installed"; roles[DefaultRole] = "isDefault"; - roles[ChatGPTRole] = "isChatGPT"; + roles[OnlineRole] = "isOnline"; roles[DisableGUIRole] = "disableGUI"; roles[DescriptionRole] = "description"; roles[RequiresVersionRole] = "requiresVersion"; @@ -359,7 +359,7 @@ private Q_SLOTS: void handleSslErrors(QNetworkReply *reply, const QList &errors); private: - QString modelDirPath(const QString &modelName, bool isChatGPT); + QString modelDirPath(const QString &modelName, bool isOnline); int indexForModel(ModelInfo *model); QVariant dataInternal(const ModelInfo *info, int role) const; static bool lessThan(const ModelInfo* a, const ModelInfo* b); diff --git a/gpt4all-chat/qml/CollectionsDialog.qml b/gpt4all-chat/qml/CollectionsDialog.qml index a23db45e..2374fa3d 100644 --- a/gpt4all-chat/qml/CollectionsDialog.qml +++ b/gpt4all-chat/qml/CollectionsDialog.qml @@ -94,11 +94,13 @@ MyDialog { anchors.right: parent.right anchors.margins: 20 anchors.leftMargin: 40 - visible: model.indexing - value: (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex + visible: model.indexing || model.currentEmbeddingsToIndex !== model.totalEmbeddingsToIndex || model.error !== "" + value: model.error !== "" ? 0 : model.indexing ? + (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex : + (model.currentEmbeddingsToIndex / model.totalEmbeddingsToIndex) background: Rectangle { implicitHeight: 45 - color: theme.progressBackground + color: model.error ? 
diff --git a/gpt4all-chat/qml/CollectionsDialog.qml b/gpt4all-chat/qml/CollectionsDialog.qml
index a23db45e..2374fa3d 100644
--- a/gpt4all-chat/qml/CollectionsDialog.qml
+++ b/gpt4all-chat/qml/CollectionsDialog.qml
@@ -94,11 +94,13 @@ MyDialog {
             anchors.right: parent.right
             anchors.margins: 20
             anchors.leftMargin: 40
-            visible: model.indexing
-            value: (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex
+            visible: model.indexing || model.currentEmbeddingsToIndex !== model.totalEmbeddingsToIndex || model.error !== ""
+            value: model.error !== "" ? 0 : model.indexing ?
+                (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex :
+                (model.currentEmbeddingsToIndex / model.totalEmbeddingsToIndex)
             background: Rectangle {
                 implicitHeight: 45
-                color: theme.progressBackground
+                color: model.error ? theme.textErrorColor : theme.progressBackground
                 radius: 3
             }
             contentItem: Item {
@@ -114,16 +116,18 @@ MyDialog {
                 Accessible.role: Accessible.ProgressBar
                 Accessible.name: qsTr("Indexing progressBar")
                 Accessible.description: qsTr("Shows the progress made in the indexing")
+                ToolTip.text: model.error
+                ToolTip.visible: hovered && model.error !== ""
             }
             Label {
                 id: speedLabel
                 color: theme.textColor
-                visible: model.indexing
+                visible: model.indexing || model.currentEmbeddingsToIndex !== model.totalEmbeddingsToIndex
                 anchors.verticalCenter: itemProgressBar.verticalCenter
                 anchors.left: itemProgressBar.left
                 anchors.right: itemProgressBar.right
                 horizontalAlignment: Text.AlignHCenter
-                text: qsTr("indexing...")
+                text: model.error !== "" ? qsTr("error...") : (model.indexing ? qsTr("indexing...") : qsTr("embeddings..."))
                 elide: Text.ElideRight
                 font.pixelSize: theme.fontSizeLarge
             }
diff --git a/gpt4all-chat/qml/ModelDownloaderDialog.qml b/gpt4all-chat/qml/ModelDownloaderDialog.qml
index 8a79a7cb..dac1798b 100644
--- a/gpt4all-chat/qml/ModelDownloaderDialog.qml
+++ b/gpt4all-chat/qml/ModelDownloaderDialog.qml
@@ -135,10 +135,10 @@ MyDialog {
                 font.pixelSize: theme.fontSizeLarge
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
-                Layout.minimumWidth: openaiKey.width
+                Layout.minimumWidth: apiKey.width
                 Layout.fillWidth: true
                 Layout.alignment: Qt.AlignTop | Qt.AlignHCenter
-                visible: !isChatGPT && !installed && !calcHash && downloadError === ""
+                visible: !isOnline && !installed && !calcHash && downloadError === ""
                 Accessible.description: qsTr("Stop/restart/start the download")
                 onClicked: {
                     if (!isDownloading) {
@@ -154,7 +154,7 @@ MyDialog {
                 text: qsTr("Remove")
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
-                Layout.minimumWidth: openaiKey.width
+                Layout.minimumWidth: apiKey.width
                 Layout.fillWidth: true
                 Layout.alignment: Qt.AlignTop | Qt.AlignHCenter
                 visible: installed || downloadError !== ""
@@ -166,23 +166,23 @@ MyDialog {
 
             MySettingsButton {
                 id: installButton
-                visible: !installed && isChatGPT
+                visible: !installed && isOnline
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
-                Layout.minimumWidth: openaiKey.width
+                Layout.minimumWidth: apiKey.width
                 Layout.fillWidth: true
                 Layout.alignment: Qt.AlignTop | Qt.AlignHCenter
                 text: qsTr("Install")
                 font.pixelSize: theme.fontSizeLarge
                 onClicked: {
-                    if (openaiKey.text === "")
-                        openaiKey.showError();
+                    if (apiKey.text === "")
+                        apiKey.showError();
                     else
-                        Download.installModel(filename, openaiKey.text);
+                        Download.installModel(filename, apiKey.text);
                 }
                 Accessible.role: Accessible.Button
                 Accessible.name: qsTr("Install")
-                Accessible.description: qsTr("Install chatGPT model")
+                Accessible.description: qsTr("Install online model")
             }
 
             ColumnLayout {
@@ -238,7 +238,7 @@ MyDialog {
                 visible: LLM.systemTotalRAMInGB() < ramrequired
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
-                Layout.maximumWidth: openaiKey.width
+                Layout.maximumWidth: apiKey.width
                 textFormat: Text.StyledText
                 text: qsTr("WARNING: Not recommended for your hardware.")
                     + qsTr(" Model requires more memory (") + ramrequired
@@ -261,7 +261,7 @@ MyDialog {
                 visible: isDownloading && !calcHash
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
-                Layout.minimumWidth: openaiKey.width
+                Layout.minimumWidth: apiKey.width
                 Layout.fillWidth: true
                 Layout.alignment: Qt.AlignTop | Qt.AlignHCenter
                 spacing: 20
 
                 ProgressBar {
                     id: itemProgressBar
                     Layout.fillWidth: true
-                    width: openaiKey.width
+                    width: apiKey.width
                    value: bytesReceived / bytesTotal
                    background: Rectangle {
                        implicitHeight: 45
@@ -307,7 +307,7 @@ MyDialog {
                 visible: calcHash
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
-                Layout.minimumWidth: openaiKey.width
+                Layout.minimumWidth: apiKey.width
                 Layout.fillWidth: true
                 Layout.alignment: Qt.AlignTop | Qt.AlignHCenter
 
@@ -331,8 +331,8 @@ MyDialog {
             }
 
             MyTextField {
-                id: openaiKey
-                visible: !installed && isChatGPT
+                id: apiKey
+                visible: !installed && isOnline
                 Layout.topMargin: 20
                 Layout.leftMargin: 20
                 Layout.minimumWidth: 150
                 Layout.alignment: Qt.AlignTop | Qt.AlignHCenter
                 wrapMode: Text.WrapAnywhere
                 function showError() {
-                    openaiKey.placeholderTextColor = theme.textErrorColor
+                    apiKey.placeholderTextColor = theme.textErrorColor
                 }
                 onTextChanged: {
-                    openaiKey.placeholderTextColor = theme.mutedTextColor
+                    apiKey.placeholderTextColor = theme.mutedTextColor
                 }
-                placeholderText: qsTr("enter $OPENAI_API_KEY")
+                placeholderText: qsTr("enter $API_KEY")
                 Accessible.role: Accessible.EditableText
                 Accessible.name: placeholderText
                 Accessible.description: qsTr("Whether the file hash is being calculated")
                 TextMetrics {
                     id: textMetrics
-                    font: openaiKey.font
-                    text: openaiKey.placeholderText
+                    font: apiKey.font
+                    text: apiKey.placeholderText
                 }
             }
         }
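Putting the dialog above together with embllm.cpp: pressing Install calls Download.installModel(filename, apiKey.text), which for online models stores the entered key as the contents of the .txt "model" file (here nomic-embed-text-v1.txt), and EmbeddingLLMWorker::loadModel() later reads that file back wholesale as the Atlas API key. A hypothetical end-to-end sketch of the resulting synchronous path:

    // sync_usage.cpp -- hypothetical usage sketch, not part of the patch.
    // Assumes either the local all-MiniLM-L6-v2-f16.gguf model or a
    // nomic-embed-text-v1.txt file holding an Atlas API key is installed.
    #include <QCoreApplication>
    #include <QDebug>
    #include "embllm.h"

    int main(int argc, char *argv[])
    {
        QCoreApplication app(argc, argv);
        EmbeddingLLM embLLM;
        // Local model: embeds in-process on the worker. Nomic model: blocks
        // on a single Atlas round-trip via the temporary worker created in
        // EmbeddingLLM::generateEmbeddings().
        const std::vector<float> e = embLLM.generateEmbeddings(QStringLiteral("hello world"));
        qDebug() << "embedding dimension:" << e.size();
        return 0;
    }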