LocalDocs version 2 with text embeddings.

This commit is contained in:
Adam Treat 2023-10-24 12:13:32 -04:00 committed by AT
parent d4ce9f4a7c
commit 371e2a5cbc
30 changed files with 3540 additions and 164 deletions

View File

@ -490,6 +490,9 @@ struct bert_ctx * bert_load_from_file(const char *fname)
#endif #endif
bert_ctx * new_bert = new bert_ctx; bert_ctx * new_bert = new bert_ctx;
new_bert->buf_compute.force_cpu = true;
new_bert->work_buf.force_cpu = true;
bert_model & model = new_bert->model; bert_model & model = new_bert->model;
bert_vocab & vocab = new_bert->vocab; bert_vocab & vocab = new_bert->vocab;

View File

@ -10,13 +10,14 @@ struct llm_buffer {
uint8_t * addr = NULL; uint8_t * addr = NULL;
size_t size = 0; size_t size = 0;
ggml_vk_memory memory; ggml_vk_memory memory;
bool force_cpu = false;
llm_buffer() = default; llm_buffer() = default;
void resize(size_t size) { void resize(size_t size) {
free(); free();
if (!ggml_vk_has_device()) { if (!ggml_vk_has_device() || force_cpu) {
this->addr = new uint8_t[size]; this->addr = new uint8_t[size];
this->size = size; this->size = size;
} else { } else {

View File

@ -75,7 +75,9 @@ qt_add_executable(chat
chatmodel.h chatlistmodel.h chatlistmodel.cpp chatmodel.h chatlistmodel.h chatlistmodel.cpp
chatgpt.h chatgpt.cpp chatgpt.h chatgpt.cpp
database.h database.cpp database.h database.cpp
embeddings.h embeddings.cpp
download.h download.cpp download.h download.cpp
embllm.cpp embllm.h
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp
llm.h llm.cpp llm.h llm.cpp
modellist.h modellist.cpp modellist.h modellist.cpp
@ -90,6 +92,7 @@ qt_add_executable(chat
qt_add_qml_module(chat qt_add_qml_module(chat
URI gpt4all URI gpt4all
VERSION 1.0 VERSION 1.0
NO_CACHEGEN
QML_FILES QML_FILES
main.qml main.qml
qml/ChatDrawer.qml qml/ChatDrawer.qml
@ -170,7 +173,7 @@ else()
PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf) PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf)
endif() endif()
target_link_libraries(chat target_link_libraries(chat
PRIVATE llmodel) PRIVATE llmodel bert-default)
set(COMPONENT_NAME_MAIN ${PROJECT_NAME}) set(COMPONENT_NAME_MAIN ${PROJECT_NAME})
set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install) set(CMAKE_INSTALL_PREFIX ${CMAKE_BINARY_DIR}/install)

View File

@ -18,6 +18,7 @@ Chat::Chat(QObject *parent)
, m_shouldDeleteLater(false) , m_shouldDeleteLater(false)
, m_isModelLoaded(false) , m_isModelLoaded(false)
, m_shouldLoadModelWhenInstalled(false) , m_shouldLoadModelWhenInstalled(false)
, m_collectionModel(new LocalDocsCollectionsModel(this))
{ {
connectLLM(); connectLLM();
} }
@ -35,6 +36,7 @@ Chat::Chat(bool isServer, QObject *parent)
, m_shouldDeleteLater(false) , m_shouldDeleteLater(false)
, m_isModelLoaded(false) , m_isModelLoaded(false)
, m_shouldLoadModelWhenInstalled(false) , m_shouldLoadModelWhenInstalled(false)
, m_collectionModel(new LocalDocsCollectionsModel(this))
{ {
connectLLM(); connectLLM();
} }
@ -71,6 +73,7 @@ void Chat::connectLLM()
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection); connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection); connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
connect(ModelList::globalInstance()->installedModels(), &InstalledModels::countChanged, connect(ModelList::globalInstance()->installedModels(), &InstalledModels::countChanged,
this, &Chat::handleModelInstalled, Qt::QueuedConnection); this, &Chat::handleModelInstalled, Qt::QueuedConnection);
} }

View File

@ -27,6 +27,7 @@ class Chat : public QObject
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged); Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
Q_PROPERTY(QString device READ device NOTIFY deviceChanged); Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged); Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
QML_ELEMENT QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!") QML_UNCREATABLE("Only creatable from c++!")
@ -83,6 +84,7 @@ public:
bool isServer() const { return m_isServer; } bool isServer() const { return m_isServer; }
QList<QString> collectionList() const; QList<QString> collectionList() const;
LocalDocsCollectionsModel *collectionModel() const { return m_collectionModel; }
Q_INVOKABLE bool hasCollection(const QString &collection) const; Q_INVOKABLE bool hasCollection(const QString &collection) const;
Q_INVOKABLE void addCollection(const QString &collection); Q_INVOKABLE void addCollection(const QString &collection);
@ -123,6 +125,7 @@ Q_SIGNALS:
void tokenSpeedChanged(); void tokenSpeedChanged();
void deviceChanged(); void deviceChanged();
void fallbackReasonChanged(); void fallbackReasonChanged();
void collectionModelChanged();
private Q_SLOTS: private Q_SLOTS:
void handleResponseChanged(const QString &response); void handleResponseChanged(const QString &response);
@ -161,6 +164,7 @@ private:
bool m_shouldDeleteLater; bool m_shouldDeleteLater;
bool m_isModelLoaded; bool m_isModelLoaded;
bool m_shouldLoadModelWhenInstalled; bool m_shouldLoadModelWhenInstalled;
LocalDocsCollectionsModel *m_collectionModel;
}; };
#endif // CHAT_H #endif // CHAT_H

View File

@ -1,5 +1,7 @@
#include "database.h" #include "database.h"
#include "mysettings.h" #include "mysettings.h"
#include "embllm.h"
#include "embeddings.h"
#include <QTimer> #include <QTimer>
#include <QPdfDocument> #include <QPdfDocument>
@ -7,18 +9,18 @@
//#define DEBUG //#define DEBUG
//#define DEBUG_EXAMPLE //#define DEBUG_EXAMPLE
#define LOCALDOCS_VERSION 0 #define LOCALDOCS_VERSION 1
const auto INSERT_CHUNK_SQL = QLatin1String(R"( const auto INSERT_CHUNK_SQL = QLatin1String(R"(
insert into chunks(document_id, chunk_id, chunk_text, insert into chunks(document_id, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to, file, title, author, subject, keywords, page, line_from, line_to)
embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
)"); )");
const auto INSERT_CHUNK_FTS_SQL = QLatin1String(R"( const auto INSERT_CHUNK_FTS_SQL = QLatin1String(R"(
insert into chunks_fts(document_id, chunk_id, chunk_text, insert into chunks_fts(document_id, chunk_id, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to, file, title, author, subject, keywords, page, line_from, line_to)
embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
)"); )");
const auto DELETE_CHUNKS_SQL = QLatin1String(R"( const auto DELETE_CHUNKS_SQL = QLatin1String(R"(
@ -30,20 +32,33 @@ const auto DELETE_CHUNKS_FTS_SQL = QLatin1String(R"(
)"); )");
const auto CHUNKS_SQL = QLatin1String(R"( const auto CHUNKS_SQL = QLatin1String(R"(
create table chunks(document_id integer, chunk_id integer, chunk_text varchar, create table chunks(document_id integer, chunk_id integer primary key autoincrement, chunk_text varchar,
file varchar, title varchar, author varchar, subject varchar, keywords varchar, file varchar, title varchar, author varchar, subject varchar, keywords varchar,
page integer, line_from integer, line_to integer, page integer, line_from integer, line_to integer);
embedding_id integer, embedding_path varchar);
)"); )");
const auto FTS_CHUNKS_SQL = QLatin1String(R"( const auto FTS_CHUNKS_SQL = QLatin1String(R"(
create virtual table chunks_fts using fts5(document_id unindexed, chunk_id unindexed, chunk_text, create virtual table chunks_fts using fts5(document_id unindexed, chunk_id unindexed, chunk_text,
file, title, author, subject, keywords, page, line_from, line_to, file, title, author, subject, keywords, page, line_from, line_to, tokenize="trigram");
embedding_id unindexed, embedding_path unindexed, tokenize="trigram");
)"); )");
const auto SELECT_SQL = QLatin1String(R"( const auto SELECT_CHUNKS_BY_DOCUMENT_SQL = QLatin1String(R"(
select chunks_fts.rowid, documents.document_time, select chunk_id from chunks WHERE document_id = ?;
)");
const auto SELECT_CHUNKS_SQL = QLatin1String(R"(
select chunks.chunk_id, documents.document_time,
chunks.chunk_text, chunks.file, chunks.title, chunks.author, chunks.page,
chunks.line_from, chunks.line_to
from chunks
join documents ON chunks.document_id = documents.id
join folders ON documents.folder_id = folders.id
join collections ON folders.id = collections.folder_id
where chunks.chunk_id in (%1) and collections.collection_name in (%2);
)");
const auto SELECT_NGRAM_SQL = QLatin1String(R"(
select chunks_fts.chunk_id, documents.document_time,
chunks_fts.chunk_text, chunks_fts.file, chunks_fts.title, chunks_fts.author, chunks_fts.page, chunks_fts.chunk_text, chunks_fts.file, chunks_fts.title, chunks_fts.author, chunks_fts.page,
chunks_fts.line_from, chunks_fts.line_to chunks_fts.line_from, chunks_fts.line_to
from chunks_fts from chunks_fts
@ -55,16 +70,14 @@ const auto SELECT_SQL = QLatin1String(R"(
limit %2; limit %2;
)"); )");
bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_text, bool addChunk(QSqlQuery &q, int document_id, const QString &chunk_text,
const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords, const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords,
int page, int from, int to, int page, int from, int to, int *chunk_id)
int embedding_id, const QString &embedding_path)
{ {
{ {
if (!q.prepare(INSERT_CHUNK_SQL)) if (!q.prepare(INSERT_CHUNK_SQL))
return false; return false;
q.addBindValue(document_id); q.addBindValue(document_id);
q.addBindValue(chunk_id);
q.addBindValue(chunk_text); q.addBindValue(chunk_text);
q.addBindValue(file); q.addBindValue(file);
q.addBindValue(title); q.addBindValue(title);
@ -74,16 +87,19 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_
q.addBindValue(page); q.addBindValue(page);
q.addBindValue(from); q.addBindValue(from);
q.addBindValue(to); q.addBindValue(to);
q.addBindValue(embedding_id);
q.addBindValue(embedding_path);
if (!q.exec()) if (!q.exec())
return false; return false;
} }
if (!q.exec("select last_insert_rowid();"))
return false;
if (!q.next())
return false;
*chunk_id = q.value(0).toInt();
{ {
if (!q.prepare(INSERT_CHUNK_FTS_SQL)) if (!q.prepare(INSERT_CHUNK_FTS_SQL))
return false; return false;
q.addBindValue(document_id); q.addBindValue(document_id);
q.addBindValue(chunk_id); q.addBindValue(*chunk_id);
q.addBindValue(chunk_text); q.addBindValue(chunk_text);
q.addBindValue(file); q.addBindValue(file);
q.addBindValue(title); q.addBindValue(title);
@ -93,8 +109,6 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_
q.addBindValue(page); q.addBindValue(page);
q.addBindValue(from); q.addBindValue(from);
q.addBindValue(to); q.addBindValue(to);
q.addBindValue(embedding_id);
q.addBindValue(embedding_path);
if (!q.exec()) if (!q.exec())
return false; return false;
} }
@ -146,6 +160,18 @@ QStringList generateGrams(const QString &input, int N)
return ngrams; return ngrams;
} }
bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const std::vector<qint64> &chunk_ids, int retrievalSize)
{
QString chunk_ids_str = QString::number(chunk_ids[0]);
for (size_t i = 1; i < chunk_ids.size(); ++i)
chunk_ids_str += "," + QString::number(chunk_ids[i]);
const QString collection_names_str = collection_names.join("', '");
const QString formatted_query = SELECT_CHUNKS_SQL.arg(chunk_ids_str).arg("'" + collection_names_str + "'");
if (!q.prepare(formatted_query))
return false;
return q.exec();
}
bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const QString &chunk_text, int retrievalSize) bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const QString &chunk_text, int retrievalSize)
{ {
static QRegularExpression spaces("\\s+"); static QRegularExpression spaces("\\s+");
@ -155,7 +181,7 @@ bool selectChunk(QSqlQuery &q, const QList<QString> &collection_names, const QSt
QList<QString> text = generateGrams(chunk_text, N); QList<QString> text = generateGrams(chunk_text, N);
QString orText = text.join(" OR "); QString orText = text.join(" OR ");
const QString collection_names_str = collection_names.join("', '"); const QString collection_names_str = collection_names.join("', '");
const QString formatted_query = SELECT_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize)); const QString formatted_query = SELECT_NGRAM_SQL.arg("'" + collection_names_str + "'").arg(QString::number(retrievalSize));
if (!q.prepare(formatted_query)) if (!q.prepare(formatted_query))
return false; return false;
q.addBindValue(orText); q.addBindValue(orText);
@ -248,7 +274,8 @@ bool selectAllFromCollections(QSqlQuery &q, QList<CollectionItem> *collections)
CollectionItem i; CollectionItem i;
i.collection = q.value(0).toString(); i.collection = q.value(0).toString();
i.folder_path = q.value(1).toString(); i.folder_path = q.value(1).toString();
i.folder_id = q.value(0).toInt(); i.folder_id = q.value(2).toInt();
i.indexing = false;
i.installed = true; i.installed = true;
collections->append(i); collections->append(i);
} }
@ -459,6 +486,12 @@ QSqlError initDb()
return q.lastError(); return q.lastError();
} }
CollectionItem i;
i.collection = collection_name;
i.folder_path = folder_path;
i.folder_id = folder_id;
emit addCollectionItem(i);
// Add a document // Add a document
int document_time = 123456789; int document_time = 123456789;
int document_id; int document_id;
@ -504,6 +537,8 @@ Database::Database(int chunkSize)
: QObject(nullptr) : QObject(nullptr)
, m_watcher(new QFileSystemWatcher(this)) , m_watcher(new QFileSystemWatcher(this))
, m_chunkSize(chunkSize) , m_chunkSize(chunkSize)
, m_embLLM(new EmbeddingLLM)
, m_embeddings(new Embeddings(this))
{ {
moveToThread(&m_dbThread); moveToThread(&m_dbThread);
connect(&m_dbThread, &QThread::started, this, &Database::start); connect(&m_dbThread, &QThread::started, this, &Database::start);
@ -511,22 +546,39 @@ Database::Database(int chunkSize)
m_dbThread.start(); m_dbThread.start();
} }
void Database::handleDocumentErrorAndScheduleNext(const QString &errorMessage, Database::~Database()
int document_id, const QString &document_path, const QSqlError &error)
{ {
qWarning() << errorMessage << document_id << document_path << error.text(); m_dbThread.quit();
m_dbThread.wait();
}
void Database::scheduleNext(int folder_id, size_t countForFolder)
{
emit updateCurrentDocsToIndex(folder_id, countForFolder);
if (!countForFolder) {
emit updateIndexing(folder_id, false);
emit updateInstalled(folder_id, true);
m_embeddings->save();
}
if (!m_docsToScan.isEmpty()) if (!m_docsToScan.isEmpty())
QTimer::singleShot(0, this, &Database::scanQueue); QTimer::singleShot(0, this, &Database::scanQueue);
} }
void Database::chunkStream(QTextStream &stream, int document_id, const QString &file, void Database::handleDocumentError(const QString &errorMessage,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page) int document_id, const QString &document_path, const QSqlError &error)
{
qWarning() << errorMessage << document_id << document_path << error.text();
}
size_t Database::chunkStream(QTextStream &stream, int document_id, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page,
int maxChunks)
{ {
int chunk_id = 0;
int charCount = 0; int charCount = 0;
int line_from = -1; int line_from = -1;
int line_to = -1; int line_to = -1;
QList<QString> words; QList<QString> words;
int chunks = 0;
while (!stream.atEnd()) { while (!stream.atEnd()) {
QString word; QString word;
@ -536,9 +588,9 @@ void Database::chunkStream(QTextStream &stream, int document_id, const QString &
if (charCount + words.size() - 1 >= m_chunkSize || stream.atEnd()) { if (charCount + words.size() - 1 >= m_chunkSize || stream.atEnd()) {
const QString chunk = words.join(" "); const QString chunk = words.join(" ");
QSqlQuery q; QSqlQuery q;
int chunk_id = 0;
if (!addChunk(q, if (!addChunk(q,
document_id, document_id,
++chunk_id,
chunk, chunk,
file, file,
title, title,
@ -548,15 +600,111 @@ void Database::chunkStream(QTextStream &stream, int document_id, const QString &
page, page,
line_from, line_from,
line_to, line_to,
0 /*embedding_id*/, &chunk_id
QString() /*embedding_path*/
)) { )) {
qWarning() << "ERROR: Could not insert chunk into db" << q.lastError(); qWarning() << "ERROR: Could not insert chunk into db" << q.lastError();
} }
const std::vector<float> result = m_embLLM->generateEmbeddings(chunk);
if (!m_embeddings->add(result, chunk_id))
qWarning() << "ERROR: Cannot add point to embeddings index";
++chunks;
words.clear(); words.clear();
charCount = 0; charCount = 0;
if (maxChunks > 0 && chunks == maxChunks)
return stream.pos();
} }
} }
return stream.pos();
}
void Database::removeEmbeddingsByDocumentId(int document_id)
{
QSqlQuery q;
if (!q.prepare(SELECT_CHUNKS_BY_DOCUMENT_SQL)) {
qWarning() << "ERROR: Cannot prepare sql for select chunks by document" << q.lastError();
return;
}
q.addBindValue(document_id);
if (!q.exec()) {
qWarning() << "ERROR: Cannot exec sql for select chunks by document" << q.lastError();
return;
}
while (q.next()) {
const int chunk_id = q.value(0).toInt();
m_embeddings->remove(chunk_id);
}
m_embeddings->save();
}
size_t Database::countOfDocuments(int folder_id) const
{
if (!m_docsToScan.contains(folder_id))
return 0;
return m_docsToScan.value(folder_id).size();
}
size_t Database::countOfBytes(int folder_id) const
{
if (!m_docsToScan.contains(folder_id))
return 0;
size_t totalBytes = 0;
const QQueue<DocumentInfo> &docs = m_docsToScan.value(folder_id);
for (const DocumentInfo &f : docs)
totalBytes += f.doc.size();
return totalBytes;
}
DocumentInfo Database::dequeueDocument()
{
Q_ASSERT(!m_docsToScan.isEmpty());
const int firstKey = m_docsToScan.firstKey();
QQueue<DocumentInfo> &queue = m_docsToScan[firstKey];
Q_ASSERT(!queue.isEmpty());
DocumentInfo result = queue.dequeue();
if (queue.isEmpty())
m_docsToScan.remove(firstKey);
return result;
}
void Database::removeFolderFromDocumentQueue(int folder_id)
{
if (!m_docsToScan.contains(folder_id))
return;
m_docsToScan.remove(folder_id);
emit removeFolderById(folder_id);
emit docsToScanChanged();
}
void Database::enqueueDocumentInternal(const DocumentInfo &info, bool prepend)
{
const int key = info.folder;
if (!m_docsToScan.contains(key))
m_docsToScan[key] = QQueue<DocumentInfo>();
if (prepend)
m_docsToScan[key].prepend(info);
else
m_docsToScan[key].enqueue(info);
}
void Database::enqueueDocuments(int folder_id, const QVector<DocumentInfo> &infos)
{
for (int i = 0; i < infos.size(); ++i)
enqueueDocumentInternal(infos[i]);
const size_t count = countOfDocuments(folder_id);
emit updateCurrentDocsToIndex(folder_id, count);
emit updateTotalDocsToIndex(folder_id, count);
const size_t bytes = countOfBytes(folder_id);
emit updateCurrentBytesToIndex(folder_id, bytes);
emit updateTotalBytesToIndex(folder_id, bytes);
emit docsToScanChanged();
} }
void Database::scanQueue() void Database::scanQueue()
@ -564,7 +712,9 @@ void Database::scanQueue()
if (m_docsToScan.isEmpty()) if (m_docsToScan.isEmpty())
return; return;
DocumentInfo info = m_docsToScan.dequeue(); DocumentInfo info = dequeueDocument();
const size_t countForFolder = countOfDocuments(info.folder);
const int folder_id = info.folder;
// Update info // Update info
info.doc.stat(); info.doc.stat();
@ -572,71 +722,74 @@ void Database::scanQueue()
// If the doc has since been deleted or no longer readable, then we schedule more work and return // If the doc has since been deleted or no longer readable, then we schedule more work and return
// leaving the cleanup for the cleanup handler // leaving the cleanup for the cleanup handler
if (!info.doc.exists() || !info.doc.isReadable()) { if (!info.doc.exists() || !info.doc.isReadable()) {
if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue); return scheduleNext(folder_id, countForFolder);
return;
} }
const int folder_id = info.folder;
const qint64 document_time = info.doc.fileTime(QFile::FileModificationTime).toMSecsSinceEpoch(); const qint64 document_time = info.doc.fileTime(QFile::FileModificationTime).toMSecsSinceEpoch();
const QString document_path = info.doc.canonicalFilePath(); const QString document_path = info.doc.canonicalFilePath();
const bool currentlyProcessing = info.currentlyProcessing;
#if defined(DEBUG)
qDebug() << "scanning document" << document_path;
#endif
// Check and see if we already have this document // Check and see if we already have this document
QSqlQuery q; QSqlQuery q;
int existing_id = -1; int existing_id = -1;
qint64 existing_time = -1; qint64 existing_time = -1;
if (!selectDocument(q, document_path, &existing_id, &existing_time)) { if (!selectDocument(q, document_path, &existing_id, &existing_time)) {
return handleDocumentErrorAndScheduleNext("ERROR: Cannot select document", handleDocumentError("ERROR: Cannot select document",
existing_id, document_path, q.lastError()); existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
} }
// If we have the document, we need to compare the last modification time and if it is newer // If we have the document, we need to compare the last modification time and if it is newer
// we must rescan the document, otherwise return // we must rescan the document, otherwise return
if (existing_id != -1) { if (existing_id != -1 && !currentlyProcessing) {
Q_ASSERT(existing_time != -1); Q_ASSERT(existing_time != -1);
if (document_time == existing_time) { if (document_time == existing_time) {
// No need to rescan, but we do have to schedule next // No need to rescan, but we do have to schedule next
if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue); return scheduleNext(folder_id, countForFolder);
return;
} else { } else {
removeEmbeddingsByDocumentId(existing_id);
if (!removeChunksByDocumentId(q, existing_id)) { if (!removeChunksByDocumentId(q, existing_id)) {
return handleDocumentErrorAndScheduleNext("ERROR: Cannot remove chunks of document", handleDocumentError("ERROR: Cannot remove chunks of document",
existing_id, document_path, q.lastError()); existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
} }
} }
} }
// Update the document_time for an existing document, or add it for the first time now // Update the document_time for an existing document, or add it for the first time now
int document_id = existing_id; int document_id = existing_id;
if (!currentlyProcessing) {
if (document_id != -1) { if (document_id != -1) {
if (!updateDocument(q, document_id, document_time)) { if (!updateDocument(q, document_id, document_time)) {
return handleDocumentErrorAndScheduleNext("ERROR: Could not update document_time", handleDocumentError("ERROR: Could not update document_time",
document_id, document_path, q.lastError()); document_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
} }
} else { } else {
if (!addDocument(q, folder_id, document_time, document_path, &document_id)) { if (!addDocument(q, folder_id, document_time, document_path, &document_id)) {
return handleDocumentErrorAndScheduleNext("ERROR: Could not add document", handleDocumentError("ERROR: Could not add document",
document_id, document_path, q.lastError()); document_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
} }
} }
QElapsedTimer timer;
timer.start();
QSqlDatabase::database().transaction(); QSqlDatabase::database().transaction();
Q_ASSERT(document_id != -1); Q_ASSERT(document_id != -1);
if (info.doc.suffix() == QLatin1String("pdf")) { if (info.isPdf()) {
QPdfDocument doc; QPdfDocument doc;
if (QPdfDocument::Error::None != doc.load(info.doc.canonicalFilePath())) { if (QPdfDocument::Error::None != doc.load(info.doc.canonicalFilePath())) {
return handleDocumentErrorAndScheduleNext("ERROR: Could not load pdf", handleDocumentError("ERROR: Could not load pdf",
document_id, document_path, q.lastError()); document_id, document_path, q.lastError());
return; return scheduleNext(folder_id, countForFolder);
} }
for (int i = 0; i < doc.pageCount(); ++i) { const size_t bytes = info.doc.size();
const QPdfSelection selection = doc.getAllText(i); const size_t bytesPerPage = std::floor(bytes / doc.pageCount());
const int pageIndex = info.currentPage;
#if defined(DEBUG)
qDebug() << "scanning page" << pageIndex << "of" << doc.pageCount() << document_path;
#endif
const QPdfSelection selection = doc.getAllText(pageIndex);
QString text = selection.text(); QString text = selection.text();
QTextStream stream(&text); QTextStream stream(&text);
chunkStream(stream, document_id, info.doc.fileName(), chunkStream(stream, document_id, info.doc.fileName(),
@ -644,27 +797,52 @@ void Database::scanQueue()
doc.metaData(QPdfDocument::MetaDataField::Author).toString(), doc.metaData(QPdfDocument::MetaDataField::Author).toString(),
doc.metaData(QPdfDocument::MetaDataField::Subject).toString(), doc.metaData(QPdfDocument::MetaDataField::Subject).toString(),
doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(), doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(),
i + 1 pageIndex + 1
); );
m_embeddings->save();
emit subtractCurrentBytesToIndex(info.folder, bytesPerPage);
if (info.currentPage < doc.pageCount()) {
info.currentPage += 1;
info.currentlyProcessing = true;
enqueueDocumentInternal(info, true /*prepend*/);
return scheduleNext(folder_id, countForFolder + 1);
} else {
emit subtractCurrentBytesToIndex(info.folder, bytes - (bytesPerPage * doc.pageCount()));
} }
} else { } else {
QFile file(document_path); QFile file(document_path);
if (!file.open( QIODevice::ReadOnly)) { if (!file.open(QIODevice::ReadOnly)) {
return handleDocumentErrorAndScheduleNext("ERROR: Cannot open file for scanning", handleDocumentError("ERROR: Cannot open file for scanning",
existing_id, document_path, q.lastError()); existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
} }
const size_t bytes = info.doc.size();
QTextStream stream(&file); QTextStream stream(&file);
chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/, const size_t byteIndex = info.currentPosition;
QString() /*subject*/, QString() /*keywords*/, -1 /*page*/); if (!stream.seek(byteIndex)) {
handleDocumentError("ERROR: Cannot seek to pos for scanning",
existing_id, document_path, q.lastError());
return scheduleNext(folder_id, countForFolder);
}
#if defined(DEBUG)
qDebug() << "scanning byteIndex" << byteIndex << "of" << bytes << document_path;
#endif
int pos = chunkStream(stream, document_id, info.doc.fileName(), QString() /*title*/, QString() /*author*/,
QString() /*subject*/, QString() /*keywords*/, -1 /*page*/, 5 /*maxChunks*/);
m_embeddings->save();
file.close(); file.close();
const size_t bytesChunked = pos - byteIndex;
emit subtractCurrentBytesToIndex(info.folder, bytesChunked);
if (info.currentPosition < bytes) {
info.currentPosition = pos;
info.currentlyProcessing = true;
enqueueDocumentInternal(info, true /*prepend*/);
return scheduleNext(folder_id, countForFolder + 1);
}
} }
QSqlDatabase::database().commit(); QSqlDatabase::database().commit();
return scheduleNext(folder_id, countForFolder);
#if defined(DEBUG)
qDebug() << "chunking" << document_path << "took" << timer.elapsed() << "ms";
#endif
if (!m_docsToScan.isEmpty()) QTimer::singleShot(0, this, &Database::scanQueue);
} }
void Database::scanDocuments(int folder_id, const QString &folder_path) void Database::scanDocuments(int folder_id, const QString &folder_path)
@ -687,6 +865,7 @@ void Database::scanDocuments(int folder_id, const QString &folder_path)
Q_ASSERT(dir.exists()); Q_ASSERT(dir.exists());
Q_ASSERT(dir.isReadable()); Q_ASSERT(dir.isReadable());
QDirIterator it(folder_path, QDir::Readable | QDir::Files, QDirIterator::Subdirectories); QDirIterator it(folder_path, QDir::Readable | QDir::Files, QDirIterator::Subdirectories);
QVector<DocumentInfo> infos;
while (it.hasNext()) { while (it.hasNext()) {
it.next(); it.next();
QFileInfo fileInfo = it.fileInfo(); QFileInfo fileInfo = it.fileInfo();
@ -701,9 +880,13 @@ void Database::scanDocuments(int folder_id, const QString &folder_path)
DocumentInfo info; DocumentInfo info;
info.folder = folder_id; info.folder = folder_id;
info.doc = fileInfo; info.doc = fileInfo;
m_docsToScan.enqueue(info); infos.append(info);
}
if (!infos.isEmpty()) {
emit updateIndexing(folder_id, true);
enqueueDocuments(folder_id, infos);
} }
emit docsToScanChanged();
} }
void Database::start() void Database::start()
@ -717,6 +900,10 @@ void Database::start()
if (err.type() != QSqlError::NoError) if (err.type() != QSqlError::NoError)
qWarning() << "ERROR: initializing db" << err.text(); qWarning() << "ERROR: initializing db" << err.text();
} }
if (m_embeddings->fileExists() && !m_embeddings->load())
qWarning() << "ERROR: Could not load embeddings";
addCurrentFolders(); addCurrentFolders();
} }
@ -733,25 +920,12 @@ void Database::addCurrentFolders()
return; return;
} }
emit collectionListUpdated(collections);
for (const auto &i : collections) for (const auto &i : collections)
addFolder(i.collection, i.folder_path); addFolder(i.collection, i.folder_path);
} }
void Database::updateCollectionList()
{
#if defined(DEBUG)
qDebug() << "updateCollectionList";
#endif
QSqlQuery q;
QList<CollectionItem> collections;
if (!selectAllFromCollections(q, &collections)) {
qWarning() << "ERROR: Cannot select collections" << q.lastError();
return;
}
emit collectionListUpdated(collections);
}
void Database::addFolder(const QString &collection, const QString &path) void Database::addFolder(const QString &collection, const QString &path)
{ {
QFileInfo info(path); QFileInfo info(path);
@ -784,14 +958,21 @@ void Database::addFolder(const QString &collection, const QString &path)
return; return;
} }
if (!folders.contains(folder_id) && !addCollection(q, collection, folder_id)) { if (!folders.contains(folder_id)) {
if (!addCollection(q, collection, folder_id)) {
qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError(); qWarning() << "ERROR: Cannot add folder to collection" << collection << path << q.lastError();
return; return;
} }
CollectionItem i;
i.collection = collection;
i.folder_path = path;
i.folder_id = folder_id;
emit addCollectionItem(i);
}
addFolderToWatch(path); addFolderToWatch(path);
scanDocuments(folder_id, path); scanDocuments(folder_id, path);
updateCollectionList();
} }
void Database::removeFolder(const QString &collection, const QString &path) void Database::removeFolder(const QString &collection, const QString &path)
@ -840,15 +1021,8 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co
if (collections.count() > 1) if (collections.count() > 1)
return; return;
// First remove all upcoming jobs associated with this folder by performing an opt-in filter // First remove all upcoming jobs associated with this folder
QQueue<DocumentInfo> docsToScan; removeFolderFromDocumentQueue(folder_id);
for (const DocumentInfo &info : m_docsToScan) {
if (info.folder == folder_id)
continue;
docsToScan.append(info);
}
m_docsToScan = docsToScan;
emit docsToScanChanged();
// Get a list of all documents associated with folder // Get a list of all documents associated with folder
QList<int> documentIds; QList<int> documentIds;
@ -859,6 +1033,7 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co
// Remove all chunks and documents associated with this folder // Remove all chunks and documents associated with this folder
for (int document_id : documentIds) { for (int document_id : documentIds) {
removeEmbeddingsByDocumentId(document_id);
if (!removeChunksByDocumentId(q, document_id)) { if (!removeChunksByDocumentId(q, document_id)) {
qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << q.lastError(); qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << q.lastError();
return; return;
@ -875,8 +1050,9 @@ void Database::removeFolderInternal(const QString &collection, int folder_id, co
return; return;
} }
emit removeFolderById(folder_id);
removeFolderFromWatch(path); removeFolderFromWatch(path);
updateCollectionList();
} }
bool Database::addFolderToWatch(const QString &path) bool Database::addFolderToWatch(const QString &path)
@ -903,10 +1079,19 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
#endif #endif
QSqlQuery q; QSqlQuery q;
if (m_embeddings->isLoaded()) {
std::vector<float> result = m_embLLM->generateEmbeddings(text);
std::vector<qint64> embeddings = m_embeddings->search(result, retrievalSize);
if (!selectChunk(q, collections, embeddings, retrievalSize)) {
qDebug() << "ERROR: selecting chunks:" << q.lastError().text();
return;
}
} else {
if (!selectChunk(q, collections, text, retrievalSize)) { if (!selectChunk(q, collections, text, retrievalSize)) {
qDebug() << "ERROR: selecting chunks:" << q.lastError().text(); qDebug() << "ERROR: selecting chunks:" << q.lastError().text();
return; return;
} }
}
while (q.next()) { while (q.next()) {
#if defined(DEBUG) #if defined(DEBUG)
@ -986,6 +1171,7 @@ void Database::cleanDB()
// Remove all chunks and documents that either don't exist or have become unreadable // Remove all chunks and documents that either don't exist or have become unreadable
QSqlQuery query; QSqlQuery query;
removeEmbeddingsByDocumentId(document_id);
if (!removeChunksByDocumentId(query, document_id)) { if (!removeChunksByDocumentId(query, document_id)) {
qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError();
} }
@ -994,7 +1180,6 @@ void Database::cleanDB()
qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError(); qWarning() << "ERROR: Cannot remove document_id" << document_id << query.lastError();
} }
} }
updateCollectionList();
} }
void Database::changeChunkSize(int chunkSize) void Database::changeChunkSize(int chunkSize)
@ -1024,6 +1209,7 @@ void Database::changeChunkSize(int chunkSize)
int document_id = q.value(0).toInt(); int document_id = q.value(0).toInt();
// Remove all chunks and documents to change the chunk size // Remove all chunks and documents to change the chunk size
QSqlQuery query; QSqlQuery query;
removeEmbeddingsByDocumentId(document_id);
if (!removeChunksByDocumentId(query, document_id)) { if (!removeChunksByDocumentId(query, document_id)) {
qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError(); qWarning() << "ERROR: Cannot remove chunks of document_id" << document_id << query.lastError();
} }

View File

@ -8,10 +8,18 @@
#include <QThread> #include <QThread>
#include <QFileSystemWatcher> #include <QFileSystemWatcher>
class Embeddings;
class EmbeddingLLM;
struct DocumentInfo struct DocumentInfo
{ {
int folder; int folder;
QFileInfo doc; QFileInfo doc;
int currentPage = 0;
size_t currentPosition = 0;
bool currentlyProcessing = false;
bool isPdf() const {
return doc.suffix() == QLatin1String("pdf");
}
}; };
struct ResultInfo { struct ResultInfo {
@ -30,6 +38,11 @@ struct CollectionItem {
QString folder_path; QString folder_path;
int folder_id = -1; int folder_id = -1;
bool installed = false; bool installed = false;
bool indexing = false;
int currentDocsToIndex = 0;
int totalDocsToIndex = 0;
size_t currentBytesToIndex = 0;
size_t totalBytesToIndex = 0;
}; };
Q_DECLARE_METATYPE(CollectionItem) Q_DECLARE_METATYPE(CollectionItem)
@ -38,6 +51,7 @@ class Database : public QObject
Q_OBJECT Q_OBJECT
public: public:
Database(int chunkSize); Database(int chunkSize);
virtual ~Database();
public Q_SLOTS: public Q_SLOTS:
void scanQueue(); void scanQueue();
@ -50,6 +64,16 @@ public Q_SLOTS:
Q_SIGNALS: Q_SIGNALS:
void docsToScanChanged(); void docsToScanChanged();
void updateInstalled(int folder_id, bool b);
void updateIndexing(int folder_id, bool b);
void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex);
void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex);
void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes);
void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex);
void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex);
void addCollectionItem(const CollectionItem &item);
void removeFolderById(int folder_id);
void removeCollectionItem(const QString &collectionName);
void collectionListUpdated(const QList<CollectionItem> &collectionList); void collectionListUpdated(const QList<CollectionItem> &collectionList);
private Q_SLOTS: private Q_SLOTS:
@ -58,21 +82,31 @@ private Q_SLOTS:
bool addFolderToWatch(const QString &path); bool addFolderToWatch(const QString &path);
bool removeFolderFromWatch(const QString &path); bool removeFolderFromWatch(const QString &path);
void addCurrentFolders(); void addCurrentFolders();
void updateCollectionList();
private: private:
void removeFolderInternal(const QString &collection, int folder_id, const QString &path); void removeFolderInternal(const QString &collection, int folder_id, const QString &path);
void chunkStream(QTextStream &stream, int document_id, const QString &file, size_t chunkStream(QTextStream &stream, int document_id, const QString &file,
const QString &title, const QString &author, const QString &subject, const QString &keywords, int page); const QString &title, const QString &author, const QString &subject, const QString &keywords, int page,
void handleDocumentErrorAndScheduleNext(const QString &errorMessage, int maxChunks = -1);
void removeEmbeddingsByDocumentId(int document_id);
void scheduleNext(int folder_id, size_t countForFolder);
void handleDocumentError(const QString &errorMessage,
int document_id, const QString &document_path, const QSqlError &error); int document_id, const QString &document_path, const QSqlError &error);
size_t countOfDocuments(int folder_id) const;
size_t countOfBytes(int folder_id) const;
DocumentInfo dequeueDocument();
void removeFolderFromDocumentQueue(int folder_id);
void enqueueDocumentInternal(const DocumentInfo &info, bool prepend = false);
void enqueueDocuments(int folder_id, const QVector<DocumentInfo> &infos);
private: private:
int m_chunkSize; int m_chunkSize;
QQueue<DocumentInfo> m_docsToScan; QMap<int, QQueue<DocumentInfo>> m_docsToScan;
QList<ResultInfo> m_retrieve; QList<ResultInfo> m_retrieve;
QThread m_dbThread; QThread m_dbThread;
QFileSystemWatcher *m_watcher; QFileSystemWatcher *m_watcher;
EmbeddingLLM *m_embLLM;
Embeddings *m_embeddings;
}; };
#endif // DATABASE_H #endif // DATABASE_H

190
gpt4all-chat/embeddings.cpp Normal file
View File

@ -0,0 +1,190 @@
#include "embeddings.h"
#include <QFile>
#include <QFileInfo>
#include <QDebug>
#include "mysettings.h"
#include "hnswlib/hnswlib.h"
#define EMBEDDINGS_VERSION 0
const int s_dim = 384; // Dimension of the elements
const int s_ef_construction = 200; // Controls index search speed/build speed tradeoff
const int s_M = 16; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption
Embeddings::Embeddings(QObject *parent)
: QObject(parent)
, m_space(nullptr)
, m_hnsw(nullptr)
{
m_filePath = MySettings::globalInstance()->modelPath()
+ QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION);
}
Embeddings::~Embeddings()
{
delete m_hnsw;
m_hnsw = nullptr;
delete m_space;
m_space = nullptr;
}
bool Embeddings::load()
{
QFileInfo info(m_filePath);
if (!info.exists()) {
qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath;
return false;
}
if (!info.isReadable()) {
qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath;
return false;
}
if (!info.isWritable()) {
qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath;
return false;
}
try {
m_space = new hnswlib::InnerProductSpace(s_dim);
m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, m_filePath.toStdString(), s_M, s_ef_construction);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not load hnswlib index:" << e.what();
return false;
}
return isLoaded();
}
bool Embeddings::load(qint64 maxElements)
{
try {
m_space = new hnswlib::InnerProductSpace(s_dim);
m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, maxElements, s_M, s_ef_construction);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not create hnswlib index:" << e.what();
return false;
}
return isLoaded();
}
bool Embeddings::save()
{
if (!isLoaded())
return false;
try {
m_hnsw->saveIndex(m_filePath.toStdString());
} catch (const std::exception &e) {
qWarning() << "ERROR: could not save hnswlib index:" << e.what();
return false;
}
return true;
}
bool Embeddings::isLoaded() const
{
return m_hnsw != nullptr;
}
bool Embeddings::fileExists() const
{
QFileInfo info(m_filePath);
return info.exists();
}
bool Embeddings::resize(qint64 size)
{
if (!isLoaded()) {
qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!";
return false;
}
Q_ASSERT(m_hnsw);
try {
m_hnsw->resizeIndex(size);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not resize hnswlib index:" << e.what();
return false;
}
return true;
}
bool Embeddings::add(const std::vector<float> &embedding, qint64 label)
{
if (!isLoaded()) {
bool success = load(500);
if (!success) {
qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!";
return false;
}
}
Q_ASSERT(m_hnsw);
if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) {
if (!resize(m_hnsw->max_elements_ + 500)) {
return false;
}
}
try {
m_hnsw->addPoint(embedding.data(), label, false);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what();
return false;
}
return true;
}
void Embeddings::remove(qint64 label)
{
if (!isLoaded()) {
qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!";
return;
}
Q_ASSERT(m_hnsw);
try {
m_hnsw->markDelete(label);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not add remove embedding from hnswlib index:" << e.what();
}
}
void Embeddings::clear()
{
delete m_hnsw;
m_hnsw = nullptr;
delete m_space;
m_space = nullptr;
}
std::vector<qint64> Embeddings::search(const std::vector<float> &embedding, int K)
{
if (!isLoaded())
return std::vector<qint64>();
Q_ASSERT(m_hnsw);
std::priority_queue<std::pair<float, hnswlib::labeltype>> result;
try {
result = m_hnsw->searchKnn(embedding.data(), K);
} catch (const std::exception &e) {
qWarning() << "ERROR: could not search hnswlib index:" << e.what();
return std::vector<qint64>();
}
std::vector<qint64> neighbors;
neighbors.reserve(K);
while(!result.empty()) {
neighbors.push_back(result.top().second);
result.pop();
}
// Reverse the neighbors, as the top of the priority queue is the farthest neighbor.
std::reverse(neighbors.begin(), neighbors.end());
return neighbors;
}

45
gpt4all-chat/embeddings.h Normal file
View File

@ -0,0 +1,45 @@
#ifndef EMBEDDINGS_H
#define EMBEDDINGS_H
#include <QObject>
namespace hnswlib {
template <typename T>
class HierarchicalNSW;
class InnerProductSpace;
}
class Embeddings : public QObject
{
Q_OBJECT
public:
Embeddings(QObject *parent);
virtual ~Embeddings();
bool load();
bool load(qint64 maxElements);
bool save();
bool isLoaded() const;
bool fileExists() const;
bool resize(qint64 size);
// Adds the embedding and returns the label used
bool add(const std::vector<float> &embedding, qint64 label);
// Removes the embedding at label by marking it as unused
void remove(qint64 label);
// Clears the embeddings
void clear();
// Performs a nearest neighbor search of the embeddings and returns a vector of labels
// for the K nearest neighbors of the given embedding
std::vector<qint64> search(const std::vector<float> &embedding, int K);
private:
QString m_filePath;
hnswlib::InnerProductSpace *m_space;
hnswlib::HierarchicalNSW<float> *m_hnsw;
};
#endif // EMBEDDINGS_H

64
gpt4all-chat/embllm.cpp Normal file
View File

@ -0,0 +1,64 @@
#include "embllm.h"
#include "modellist.h"
EmbeddingLLM::EmbeddingLLM()
: QObject{nullptr}
, m_model{nullptr}
{
}
EmbeddingLLM::~EmbeddingLLM()
{
delete m_model;
m_model = nullptr;
}
bool EmbeddingLLM::loadModel()
{
const EmbeddingModels *embeddingModels = ModelList::globalInstance()->embeddingModels();
if (!embeddingModels->count())
return false;
const ModelInfo defaultModel = embeddingModels->defaultModelInfo();
QString filePath = defaultModel.dirpath + defaultModel.filename();
QFileInfo fileInfo(filePath);
if (!fileInfo.exists()) {
qWarning() << "WARNING: Could not load sbert because file does not exist";
m_model = nullptr;
return false;
}
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
bool success = m_model->loadModel(filePath.toStdString());
if (!success) {
qWarning() << "WARNING: Could not load sbert";
delete m_model;
m_model = nullptr;
return false;
}
if (m_model->implementation().modelType()[0] != 'B') {
qWarning() << "WARNING: Model type is not sbert";
delete m_model;
m_model = nullptr;
return false;
}
return true;
}
bool EmbeddingLLM::hasModel() const
{
return m_model;
}
std::vector<float> EmbeddingLLM::generateEmbeddings(const QString &text)
{
if (!hasModel() && !loadModel()) {
qWarning() << "WARNING: Could not load sbert model for embeddings";
return std::vector<float>();
}
Q_ASSERT(hasModel());
return m_model->embedding(text.toStdString());
}

27
gpt4all-chat/embllm.h Normal file
View File

@ -0,0 +1,27 @@
#ifndef EMBLLM_H
#define EMBLLM_H
#include <QObject>
#include <QThread>
#include "../gpt4all-backend/llmodel.h"
class EmbeddingLLM : public QObject
{
Q_OBJECT
public:
EmbeddingLLM();
virtual ~EmbeddingLLM();
bool hasModel() const;
public Q_SLOTS:
std::vector<float> generateEmbeddings(const QString &text);
private:
bool loadModel();
private:
LLModel *m_model = nullptr;
};
#endif // EMBLLM_H

View File

@ -0,0 +1,167 @@
#pragma once
#include <unordered_map>
#include <fstream>
#include <mutex>
#include <algorithm>
#include <assert.h>
namespace hnswlib {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
public:
char *data_;
size_t maxelements_;
size_t cur_element_count;
size_t size_per_element_;
size_t data_size_;
DISTFUNC <dist_t> fstdistfunc_;
void *dist_func_param_;
std::mutex index_lock;
std::unordered_map<labeltype, size_t > dict_external_to_internal;
BruteforceSearch(SpaceInterface <dist_t> *s)
: data_(nullptr),
maxelements_(0),
cur_element_count(0),
size_per_element_(0),
data_size_(0),
dist_func_param_(nullptr) {
}
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
: data_(nullptr),
maxelements_(0),
cur_element_count(0),
size_per_element_(0),
data_size_(0),
dist_func_param_(nullptr) {
loadIndex(location, s);
}
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
maxelements_ = maxElements;
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxElements * size_per_element_);
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
cur_element_count = 0;
}
~BruteforceSearch() {
free(data_);
}
void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
int idx;
{
std::unique_lock<std::mutex> lock(index_lock);
auto search = dict_external_to_internal.find(label);
if (search != dict_external_to_internal.end()) {
idx = search->second;
} else {
if (cur_element_count >= maxelements_) {
throw std::runtime_error("The number of elements exceeds the specified limit\n");
}
idx = cur_element_count;
dict_external_to_internal[label] = idx;
cur_element_count++;
}
}
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
}
void removePoint(labeltype cur_external) {
size_t cur_c = dict_external_to_internal[cur_external];
dict_external_to_internal.erase(cur_external);
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
dict_external_to_internal[label] = cur_c;
memcpy(data_ + size_per_element_ * cur_c,
data_ + size_per_element_ * (cur_element_count-1),
data_size_+sizeof(labeltype));
cur_element_count--;
}
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
topResults.pop();
if (!topResults.empty()) {
lastdist = topResults.top().first;
}
}
}
return topResults;
}
void saveIndex(const std::string &location) {
std::ofstream output(location, std::ios::binary);
std::streampos position;
writeBinaryPOD(output, maxelements_);
writeBinaryPOD(output, size_per_element_);
writeBinaryPOD(output, cur_element_count);
output.write(data_, maxelements_ * size_per_element_);
output.close();
}
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
std::ifstream input(location, std::ios::binary);
std::streampos position;
readBinaryPOD(input, maxelements_);
readBinaryPOD(input, size_per_element_);
readBinaryPOD(input, cur_element_count);
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxelements_ * size_per_element_);
if (data_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
input.read(data_, maxelements_ * size_per_element_);
input.close();
}
};
} // namespace hnswlib

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,199 @@
#pragma once
#ifndef NO_MANUAL_VECTORIZATION
#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
#define USE_SSE
#ifdef __AVX__
#define USE_AVX
#ifdef __AVX512F__
#define USE_AVX512
#endif
#endif
#endif
#endif
#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>
#include <stdexcept>
void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
static __int64 xgetbv(unsigned int x) {
return _xgetbv(x);
}
#else
#include <x86intrin.h>
#include <cpuid.h>
#include <stdint.h>
static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
static uint64_t xgetbv(unsigned int index) {
uint32_t eax, edx;
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
return ((uint64_t)edx << 32) | eax;
}
#endif
#if defined(USE_AVX512)
#include <immintrin.h>
#endif
#if defined(__GNUC__)
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
#define PORTABLE_ALIGN64 __attribute__((aligned(64)))
#else
#define PORTABLE_ALIGN32 __declspec(align(32))
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif
// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0
static bool AVXCapable() {
int cpuInfo[4];
// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];
bool HW_AVX = false;
if (nIds >= 0x00000001) {
cpuid(cpuInfo, 0x00000001, 0);
HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
}
// OS support
cpuid(cpuInfo, 1, 0);
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
bool avxSupported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avxSupported = (xcrFeatureMask & 0x6) == 0x6;
}
return HW_AVX && avxSupported;
}
static bool AVX512Capable() {
if (!AVXCapable()) return false;
int cpuInfo[4];
// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];
bool HW_AVX512F = false;
if (nIds >= 0x00000007) { // AVX512 Foundation
cpuid(cpuInfo, 0x00000007, 0);
HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
}
// OS support
cpuid(cpuInfo, 1, 0);
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
bool avx512Supported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
}
return HW_AVX512F && avx512Supported;
}
#endif
#include <queue>
#include <vector>
#include <iostream>
#include <string.h>
namespace hnswlib {
typedef size_t labeltype;
// This can be extended to store state for filtering (e.g. from a std::set)
class BaseFilterFunctor {
public:
virtual bool operator()(hnswlib::labeltype id) { return true; }
};
template <typename T>
class pairGreater {
public:
bool operator()(const T& p1, const T& p2) {
return p1.first > p2.first;
}
};
template<typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
out.write((char *) &podRef, sizeof(T));
}
template<typename T>
static void readBinaryPOD(std::istream &in, T &podRef) {
in.read((char *) &podRef, sizeof(T));
}
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
template<typename MTYPE>
class SpaceInterface {
public:
// virtual void search(void *);
virtual size_t get_data_size() = 0;
virtual DISTFUNC<MTYPE> get_dist_func() = 0;
virtual void *get_dist_func_param() = 0;
virtual ~SpaceInterface() {}
};
template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;
virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;
// Return k nearest neighbor in the order of closer fist
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;
virtual void saveIndex(const std::string &location) = 0;
virtual ~AlgorithmInterface(){
}
};
template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
BaseFilterFunctor* isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;
// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k, isIdAllowed);
{
size_t sz = ret.size();
result.resize(sz);
while (!ret.empty()) {
result[--sz] = ret.top();
ret.pop();
}
}
return result;
}
} // namespace hnswlib
#include "space_l2.h"
#include "space_ip.h"
#include "bruteforce.h"
#include "hnswalg.h"

View File

@ -0,0 +1,375 @@
#pragma once
#include "hnswlib.h"
namespace hnswlib {
static float
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
float res = 0;
for (unsigned i = 0; i < qty; i++) {
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
}
return res;
}
static float
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
}
#if defined(USE_AVX)
// Favor using AVX if available.
static float
InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
size_t qty4 = qty / 4;
const float *pEnd1 = pVect1 + 16 * qty16;
const float *pEnd2 = pVect1 + 4 * qty4;
__m256 sum256 = _mm256_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m256 v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
__m256 v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
}
__m128 v1, v2;
__m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
while (pVect1 < pEnd2) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return sum;
}
static float
InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_SSE)
static float
InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
size_t qty4 = qty / 4;
const float *pEnd1 = pVect1 + 16 * qty16;
const float *pEnd2 = pVect1 + 4 * qty4;
__m128 v1, v2;
__m128 sum_prod = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
while (pVect1 < pEnd2) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return sum;
}
static float
InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_AVX512)
static float
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN64 TmpRes[16];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;
__m512 sum512 = _mm512_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m512 v1 = _mm512_loadu_ps(pVect1);
pVect1 += 16;
__m512 v2 = _mm512_loadu_ps(pVect2);
pVect2 += 16;
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
}
_mm512_store_ps(TmpRes, sum512);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
return sum;
}
static float
InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_AVX)
static float
InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;
__m256 sum256 = _mm256_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
__m256 v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
__m256 v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
}
_mm256_store_ps(TmpRes, sum256);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
return sum;
}
static float
InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_SSE)
static float
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty / 16;
const float *pEnd1 = pVect1 + 16 * qty16;
__m128 v1, v2;
__m128 sum_prod = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
}
_mm_store_ps(TmpRes, sum_prod);
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
return sum;
}
static float
InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
}
#endif
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
static DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
static DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
static DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
static DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
static float
InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
float *pVect1 = (float *) pVect1v + qty16;
float *pVect2 = (float *) pVect2v + qty16;
size_t qty_left = qty - qty16;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
return 1.0f - (res + res_tail);
}
static float
InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;
float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
size_t qty_left = qty - qty4;
float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
return 1.0f - (res + res_tail);
}
#endif
class InnerProductSpace : public SpaceInterface<float> {
DISTFUNC<float> fstdistfunc_;
size_t data_size_;
size_t dim_;
public:
InnerProductSpace(size_t dim) {
fstdistfunc_ = InnerProductDistance;
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
} else if (AVXCapable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
}
#elif defined(USE_AVX)
if (AVXCapable()) {
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
}
#endif
#if defined(USE_AVX)
if (AVXCapable()) {
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
}
#endif
if (dim % 16 == 0)
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
else if (dim % 4 == 0)
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
}
size_t get_data_size() {
return data_size_;
}
DISTFUNC<float> get_dist_func() {
return fstdistfunc_;
}
void *get_dist_func_param() {
return &dim_;
}
~InnerProductSpace() {}
};
} // namespace hnswlib

View File

@ -0,0 +1,324 @@
#pragma once
#include "hnswlib.h"
namespace hnswlib {
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float res = 0;
for (size_t i = 0; i < qty; i++) {
float t = *pVect1 - *pVect2;
pVect1++;
pVect2++;
res += t * t;
}
return (res);
}
#if defined(USE_AVX512)
// Favor using AVX512 if available.
static float
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN64 TmpRes[16];
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);
__m512 diff, v1, v2;
__m512 sum = _mm512_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm512_loadu_ps(pVect1);
pVect1 += 16;
v2 = _mm512_loadu_ps(pVect2);
pVect2 += 16;
diff = _mm512_sub_ps(v1, v2);
// sum = _mm512_fmadd_ps(diff, diff, sum);
sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
}
_mm512_store_ps(TmpRes, sum);
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
TmpRes[13] + TmpRes[14] + TmpRes[15];
return (res);
}
#endif
#if defined(USE_AVX)
// Favor using AVX if available.
static float
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8];
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);
__m256 diff, v1, v2;
__m256 sum = _mm256_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
v1 = _mm256_loadu_ps(pVect1);
pVect1 += 8;
v2 = _mm256_loadu_ps(pVect2);
pVect2 += 8;
diff = _mm256_sub_ps(v1, v2);
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
}
_mm256_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
}
#endif
#if defined(USE_SSE)
static float
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
float PORTABLE_ALIGN32 TmpRes[8];
size_t qty16 = qty >> 4;
const float *pEnd1 = pVect1 + (qty16 << 4);
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
_mm_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
#endif
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty16 = qty >> 4 << 4;
float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
float *pVect1 = (float *) pVect1v + qty16;
float *pVect2 = (float *) pVect2v + qty16;
size_t qty_left = qty - qty16;
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif
#if defined(USE_SSE)
static float
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
float PORTABLE_ALIGN32 TmpRes[8];
float *pVect1 = (float *) pVect1v;
float *pVect2 = (float *) pVect2v;
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2;
const float *pEnd1 = pVect1 + (qty4 << 2);
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
while (pVect1 < pEnd1) {
v1 = _mm_loadu_ps(pVect1);
pVect1 += 4;
v2 = _mm_loadu_ps(pVect2);
pVect2 += 4;
diff = _mm_sub_ps(v1, v2);
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
}
_mm_store_ps(TmpRes, sum);
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
static float
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
size_t qty4 = qty >> 2 << 2;
float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
size_t qty_left = qty - qty4;
float *pVect1 = (float *) pVect1v + qty4;
float *pVect2 = (float *) pVect2v + qty4;
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
return (res + res_tail);
}
#endif
class L2Space : public SpaceInterface<float> {
DISTFUNC<float> fstdistfunc_;
size_t data_size_;
size_t dim_;
public:
L2Space(size_t dim) {
fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
#if defined(USE_AVX512)
if (AVX512Capable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
else if (AVXCapable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#elif defined(USE_AVX)
if (AVXCapable())
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
#endif
if (dim % 16 == 0)
fstdistfunc_ = L2SqrSIMD16Ext;
else if (dim % 4 == 0)
fstdistfunc_ = L2SqrSIMD4Ext;
else if (dim > 16)
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
else if (dim > 4)
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
dim_ = dim;
data_size_ = dim * sizeof(float);
}
size_t get_data_size() {
return data_size_;
}
DISTFUNC<float> get_dist_func() {
return fstdistfunc_;
}
void *get_dist_func_param() {
return &dim_;
}
~L2Space() {}
};
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
size_t qty = *((size_t *) qty_ptr);
int res = 0;
unsigned char *a = (unsigned char *) pVect1;
unsigned char *b = (unsigned char *) pVect2;
qty = qty >> 2;
for (size_t i = 0; i < qty; i++) {
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
}
return (res);
}
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
size_t qty = *((size_t*)qty_ptr);
int res = 0;
unsigned char* a = (unsigned char*)pVect1;
unsigned char* b = (unsigned char*)pVect2;
for (size_t i = 0; i < qty; i++) {
res += ((*a) - (*b)) * ((*a) - (*b));
a++;
b++;
}
return (res);
}
class L2SpaceI : public SpaceInterface<int> {
DISTFUNC<int> fstdistfunc_;
size_t data_size_;
size_t dim_;
public:
L2SpaceI(size_t dim) {
if (dim % 4 == 0) {
fstdistfunc_ = L2SqrI4x;
} else {
fstdistfunc_ = L2SqrI;
}
dim_ = dim;
data_size_ = dim * sizeof(unsigned char);
}
size_t get_data_size() {
return data_size_;
}
DISTFUNC<int> get_dist_func() {
return fstdistfunc_;
}
void *get_dist_func_param() {
return &dim_;
}
~L2SpaceI() {}
};
} // namespace hnswlib

View File

@ -0,0 +1,78 @@
#pragma once
#include <mutex>
#include <string.h>
#include <deque>
namespace hnswlib {
typedef unsigned short int vl_type;
class VisitedList {
public:
vl_type curV;
vl_type *mass;
unsigned int numelements;
VisitedList(int numelements1) {
curV = -1;
numelements = numelements1;
mass = new vl_type[numelements];
}
void reset() {
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
curV++;
}
}
~VisitedList() { delete[] mass; }
};
///////////////////////////////////////////////////////////
//
// Class for multi-threaded pool-management of VisitedLists
//
/////////////////////////////////////////////////////////
class VisitedListPool {
std::deque<VisitedList *> pool;
std::mutex poolguard;
int numelements;
public:
VisitedListPool(int initmaxpools, int numelements1) {
numelements = numelements1;
for (int i = 0; i < initmaxpools; i++)
pool.push_front(new VisitedList(numelements));
}
VisitedList *getFreeVisitedList() {
VisitedList *rez;
{
std::unique_lock <std::mutex> lock(poolguard);
if (pool.size() > 0) {
rez = pool.front();
pool.pop_front();
} else {
rez = new VisitedList(numelements);
}
}
rez->reset();
return rez;
}
void releaseVisitedList(VisitedList *vl) {
std::unique_lock <std::mutex> lock(poolguard);
pool.push_front(vl);
}
~VisitedListPool() {
while (pool.size()) {
VisitedList *rez = pool.front();
pool.pop_front();
delete rez;
}
}
};
} // namespace hnswlib

View File

@ -24,24 +24,50 @@ LocalDocs::LocalDocs()
&Database::removeFolder, Qt::QueuedConnection); &Database::removeFolder, Qt::QueuedConnection);
connect(this, &LocalDocs::requestChunkSizeChange, m_database, connect(this, &LocalDocs::requestChunkSizeChange, m_database,
&Database::changeChunkSize, Qt::QueuedConnection); &Database::changeChunkSize, Qt::QueuedConnection);
// Connections for modifying the model and keeping it updated with the database
connect(m_database, &Database::updateInstalled,
m_localDocsModel, &LocalDocsModel::updateInstalled, Qt::QueuedConnection);
connect(m_database, &Database::updateIndexing,
m_localDocsModel, &LocalDocsModel::updateIndexing, Qt::QueuedConnection);
connect(m_database, &Database::updateCurrentDocsToIndex,
m_localDocsModel, &LocalDocsModel::updateCurrentDocsToIndex, Qt::QueuedConnection);
connect(m_database, &Database::updateTotalDocsToIndex,
m_localDocsModel, &LocalDocsModel::updateTotalDocsToIndex, Qt::QueuedConnection);
connect(m_database, &Database::subtractCurrentBytesToIndex,
m_localDocsModel, &LocalDocsModel::subtractCurrentBytesToIndex, Qt::QueuedConnection);
connect(m_database, &Database::updateCurrentBytesToIndex,
m_localDocsModel, &LocalDocsModel::updateCurrentBytesToIndex, Qt::QueuedConnection);
connect(m_database, &Database::updateTotalBytesToIndex,
m_localDocsModel, &LocalDocsModel::updateTotalBytesToIndex, Qt::QueuedConnection);
connect(m_database, &Database::addCollectionItem,
m_localDocsModel, &LocalDocsModel::addCollectionItem, Qt::QueuedConnection);
connect(m_database, &Database::removeFolderById,
m_localDocsModel, &LocalDocsModel::removeFolderById, Qt::QueuedConnection);
connect(m_database, &Database::removeCollectionItem,
m_localDocsModel, &LocalDocsModel::removeCollectionItem, Qt::QueuedConnection);
connect(m_database, &Database::collectionListUpdated, connect(m_database, &Database::collectionListUpdated,
m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection); m_localDocsModel, &LocalDocsModel::collectionListUpdated, Qt::QueuedConnection);
connect(qApp, &QCoreApplication::aboutToQuit, this, &LocalDocs::aboutToQuit);
}
void LocalDocs::aboutToQuit()
{
delete m_database;
m_database = nullptr;
} }
void LocalDocs::addFolder(const QString &collection, const QString &path) void LocalDocs::addFolder(const QString &collection, const QString &path)
{ {
const QUrl url(path); const QUrl url(path);
const QString localPath = url.isLocalFile() ? url.toLocalFile() : path; const QString localPath = url.isLocalFile() ? url.toLocalFile() : path;
// Add a placeholder collection that is not installed yet
CollectionItem i;
i.collection = collection;
i.folder_path = localPath;
m_localDocsModel->addCollectionItem(i);
emit requestAddFolder(collection, localPath); emit requestAddFolder(collection, localPath);
} }
void LocalDocs::removeFolder(const QString &collection, const QString &path) void LocalDocs::removeFolder(const QString &collection, const QString &path)
{ {
m_localDocsModel->removeCollectionPath(collection, path);
emit requestRemoveFolder(collection, path); emit requestRemoveFolder(collection, path);
} }

View File

@ -23,6 +23,7 @@ public:
public Q_SLOTS: public Q_SLOTS:
void handleChunkSizeChanged(); void handleChunkSizeChanged();
void aboutToQuit();
Q_SIGNALS: Q_SIGNALS:
void requestAddFolder(const QString &collection, const QString &path); void requestAddFolder(const QString &collection, const QString &path);
@ -36,7 +37,6 @@ private:
private: private:
explicit LocalDocs(); explicit LocalDocs();
~LocalDocs() {}
friend class MyLocalDocs; friend class MyLocalDocs;
}; };

View File

@ -1,5 +1,27 @@
#include "localdocsmodel.h" #include "localdocsmodel.h"
#include "localdocs.h"
LocalDocsCollectionsModel::LocalDocsCollectionsModel(QObject *parent)
: QSortFilterProxyModel(parent)
{
setSourceModel(LocalDocs::globalInstance()->localDocsModel());
}
bool LocalDocsCollectionsModel::filterAcceptsRow(int sourceRow,
const QModelIndex &sourceParent) const
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
const QString collection = sourceModel()->data(index, LocalDocsModel::CollectionRole).toString();
return m_collections.contains(collection);
}
void LocalDocsCollectionsModel::setCollections(const QList<QString> &collections)
{
m_collections = collections;
invalidateFilter();
}
LocalDocsModel::LocalDocsModel(QObject *parent) LocalDocsModel::LocalDocsModel(QObject *parent)
: QAbstractListModel(parent) : QAbstractListModel(parent)
{ {
@ -24,6 +46,16 @@ QVariant LocalDocsModel::data(const QModelIndex &index, int role) const
return item.folder_path; return item.folder_path;
case InstalledRole: case InstalledRole:
return item.installed; return item.installed;
case IndexingRole:
return item.indexing;
case CurrentDocsToIndexRole:
return item.currentDocsToIndex;
case TotalDocsToIndexRole:
return item.totalDocsToIndex;
case CurrentBytesToIndexRole:
return quint64(item.currentBytesToIndex);
case TotalBytesToIndexRole:
return quint64(item.totalBytesToIndex);
} }
return QVariant(); return QVariant();
@ -35,9 +67,98 @@ QHash<int, QByteArray> LocalDocsModel::roleNames() const
roles[CollectionRole] = "collection"; roles[CollectionRole] = "collection";
roles[FolderPathRole] = "folder_path"; roles[FolderPathRole] = "folder_path";
roles[InstalledRole] = "installed"; roles[InstalledRole] = "installed";
roles[IndexingRole] = "indexing";
roles[CurrentDocsToIndexRole] = "currentDocsToIndex";
roles[TotalDocsToIndexRole] = "totalDocsToIndex";
roles[CurrentBytesToIndexRole] = "currentBytesToIndex";
roles[TotalBytesToIndexRole] = "totalBytesToIndex";
return roles; return roles;
} }
void LocalDocsModel::updateInstalled(int folder_id, bool b)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].installed = b;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {InstalledRole});
}
}
void LocalDocsModel::updateIndexing(int folder_id, bool b)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].indexing = b;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {IndexingRole});
}
}
void LocalDocsModel::updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].currentDocsToIndex = currentDocsToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {CurrentDocsToIndexRole});
}
}
void LocalDocsModel::updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].totalDocsToIndex = totalDocsToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {TotalDocsToIndexRole});
}
}
void LocalDocsModel::subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].currentBytesToIndex -= subtractedBytes;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {CurrentBytesToIndexRole});
}
}
void LocalDocsModel::updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].currentBytesToIndex = currentBytesToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {CurrentBytesToIndexRole});
}
}
void LocalDocsModel::updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex)
{
for (int i = 0; i < m_collectionList.size(); ++i) {
if (m_collectionList.at(i).folder_id != folder_id)
continue;
m_collectionList[i].totalBytesToIndex = totalBytesToIndex;
emit collectionItemUpdated(i, m_collectionList[i]);
emit dataChanged(this->index(i), this->index(i), {TotalBytesToIndexRole});
}
}
void LocalDocsModel::addCollectionItem(const CollectionItem &item) void LocalDocsModel::addCollectionItem(const CollectionItem &item)
{ {
beginInsertRows(QModelIndex(), m_collectionList.size(), m_collectionList.size()); beginInsertRows(QModelIndex(), m_collectionList.size(), m_collectionList.size());
@ -45,7 +166,46 @@ void LocalDocsModel::addCollectionItem(const CollectionItem &item)
endInsertRows(); endInsertRows();
} }
void LocalDocsModel::handleCollectionListUpdated(const QList<CollectionItem> &collectionList) void LocalDocsModel::removeFolderById(int folder_id)
{
for (int i = 0; i < m_collectionList.size();) {
if (m_collectionList.at(i).folder_id == folder_id) {
beginRemoveRows(QModelIndex(), i, i);
m_collectionList.removeAt(i);
endRemoveRows();
} else {
++i;
}
}
}
void LocalDocsModel::removeCollectionPath(const QString &name, const QString &path)
{
for (int i = 0; i < m_collectionList.size();) {
if (m_collectionList.at(i).collection == name && m_collectionList.at(i).folder_path == path) {
beginRemoveRows(QModelIndex(), i, i);
m_collectionList.removeAt(i);
endRemoveRows();
} else {
++i;
}
}
}
void LocalDocsModel::removeCollectionItem(const QString &collectionName)
{
for (int i = 0; i < m_collectionList.size();) {
if (m_collectionList.at(i).collection == collectionName) {
beginRemoveRows(QModelIndex(), i, i);
m_collectionList.removeAt(i);
endRemoveRows();
} else {
++i;
}
}
}
void LocalDocsModel::collectionListUpdated(const QList<CollectionItem> &collectionList)
{ {
beginResetModel(); beginResetModel();
m_collectionList = collectionList; m_collectionList = collectionList;

View File

@ -4,6 +4,22 @@
#include <QAbstractListModel> #include <QAbstractListModel>
#include "database.h" #include "database.h"
class LocalDocsCollectionsModel : public QSortFilterProxyModel
{
Q_OBJECT
public:
explicit LocalDocsCollectionsModel(QObject *parent);
public Q_SLOTS:
void setCollections(const QList<QString> &collections);
protected:
bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;
private:
QList<QString> m_collections;
};
class LocalDocsModel : public QAbstractListModel class LocalDocsModel : public QAbstractListModel
{ {
Q_OBJECT Q_OBJECT
@ -12,7 +28,13 @@ public:
enum Roles { enum Roles {
CollectionRole = Qt::UserRole + 1, CollectionRole = Qt::UserRole + 1,
FolderPathRole, FolderPathRole,
InstalledRole InstalledRole,
IndexingRole,
EmbeddingRole,
CurrentDocsToIndexRole,
TotalDocsToIndexRole,
CurrentBytesToIndexRole,
TotalBytesToIndexRole
}; };
explicit LocalDocsModel(QObject *parent = nullptr); explicit LocalDocsModel(QObject *parent = nullptr);
@ -20,9 +42,25 @@ public:
QVariant data(const QModelIndex &index, int role) const override; QVariant data(const QModelIndex &index, int role) const override;
QHash<int, QByteArray> roleNames() const override; QHash<int, QByteArray> roleNames() const override;
Q_SIGNALS:
void collectionItemUpdated(int index, const CollectionItem& item);
public Q_SLOTS: public Q_SLOTS:
void updateInstalled(int folder_id, bool b);
void updateIndexing(int folder_id, bool b);
void updateCurrentDocsToIndex(int folder_id, size_t currentDocsToIndex);
void updateTotalDocsToIndex(int folder_id, size_t totalDocsToIndex);
void subtractCurrentBytesToIndex(int folder_id, size_t subtractedBytes);
void updateCurrentBytesToIndex(int folder_id, size_t currentBytesToIndex);
void updateTotalBytesToIndex(int folder_id, size_t totalBytesToIndex);
void addCollectionItem(const CollectionItem &item); void addCollectionItem(const CollectionItem &item);
void handleCollectionListUpdated(const QList<CollectionItem> &collectionList); void removeFolderById(int folder_id);
void removeCollectionPath(const QString &name, const QString &path);
void removeCollectionItem(const QString &collectionName);
void collectionListUpdated(const QList<CollectionItem> &collectionList);
private:
void updateItem(int index, const CollectionItem& item);
private: private:
QList<CollectionItem> m_collectionList; QList<CollectionItem> m_collectionList;

View File

@ -325,6 +325,10 @@ Window {
anchors.centerIn: parent anchors.centerIn: parent
width: Math.min(1280, window.width - (window.width * .1)) width: Math.min(1280, window.width - (window.width * .1))
height: window.height - (window.height * .1) height: window.height - (window.height * .1)
onDownloadClicked: {
downloadNewModels.showEmbeddingModels = true
downloadNewModels.open()
}
} }
Button { Button {
@ -652,6 +656,7 @@ Window {
width: Math.min(600, 0.3 * window.width) width: Math.min(600, 0.3 * window.width)
height: window.height - y height: window.height - y
onDownloadClicked: { onDownloadClicked: {
downloadNewModels.showEmbeddingModels = false
downloadNewModels.open() downloadNewModels.open()
} }
onAboutClicked: { onAboutClicked: {
@ -818,11 +823,11 @@ Window {
color: theme.textAccent color: theme.textAccent
text: { text: {
switch (currentChat.responseState) { switch (currentChat.responseState) {
case Chat.ResponseStopped: return "response stopped ..."; case Chat.ResponseStopped: return qsTr("response stopped ...");
case Chat.LocalDocsRetrieval: return "retrieving " + currentChat.collectionList.join(", ") + " ..."; case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: ") + currentChat.collectionList.join(", ") + " ...";
case Chat.LocalDocsProcessing: return "processing " + currentChat.collectionList.join(", ") + " ..."; case Chat.LocalDocsProcessing: return qsTr("searching localdocs: ") + currentChat.collectionList.join(", ") + " ...";
case Chat.PromptProcessing: return "processing ..." case Chat.PromptProcessing: return qsTr("processing ...")
case Chat.ResponseGeneration: return "generating response ..."; case Chat.ResponseGeneration: return qsTr("generating response ...");
default: return ""; // handle unexpected values default: return ""; // handle unexpected values
} }
} }

View File

@ -138,7 +138,7 @@
"type": "Replit", "type": "Replit",
"systemPrompt": " ", "systemPrompt": " ",
"promptTemplate": "%1", "promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use</ul>", "description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-q4_0.gguf"
}, },
{ {
@ -155,7 +155,7 @@
"type": "Starcoder", "type": "Starcoder",
"systemPrompt": " ", "systemPrompt": " ",
"promptTemplate": "%1", "promptTemplate": "%1",
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based</ul>", "description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>",
"url": "https://gpt4all.io/models/gguf/starcoder-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/starcoder-q4_0.gguf"
}, },
{ {
@ -172,7 +172,7 @@
"type": "LLaMA", "type": "LLaMA",
"systemPrompt": " ", "systemPrompt": " ",
"promptTemplate": "%1", "promptTemplate": "%1",
"description": "Code completion model", "description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>",
"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf" "url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf"
}, },
{ {
@ -184,11 +184,11 @@
"filesize": "45887744", "filesize": "45887744",
"requires": "2.5.0", "requires": "2.5.0",
"ramrequired": "1", "ramrequired": "1",
"parameters": "1 million", "parameters": "40 million",
"quant": "f16", "quant": "f16",
"type": "Bert", "type": "Bert",
"systemPrompt": " ", "systemPrompt": " ",
"description": "<strong>Sbert</strong><br><ul><li>For embeddings", "description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>Necessary for LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf" "url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
}, },
{ {

View File

@ -139,6 +139,50 @@ void ModelInfo::setSystemPrompt(const QString &p)
m_systemPrompt = p; m_systemPrompt = p;
} }
EmbeddingModels::EmbeddingModels(QObject *parent)
: QSortFilterProxyModel(parent)
{
connect(this, &EmbeddingModels::rowsInserted, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::rowsRemoved, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::modelReset, this, &EmbeddingModels::countChanged);
connect(this, &EmbeddingModels::layoutChanged, this, &EmbeddingModels::countChanged);
}
bool EmbeddingModels::filterAcceptsRow(int sourceRow,
const QModelIndex &sourceParent) const
{
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool();
bool isEmbedding = sourceModel()->data(index, ModelList::FilenameRole).toString() == "all-MiniLM-L6-v2-f16.gguf";
return isInstalled && isEmbedding;
}
int EmbeddingModels::count() const
{
return rowCount();
}
ModelInfo EmbeddingModels::defaultModelInfo() const
{
if (!sourceModel())
return ModelInfo();
const ModelList *sourceListModel = qobject_cast<const ModelList*>(sourceModel());
if (!sourceListModel)
return ModelInfo();
const int rows = sourceListModel->rowCount();
for (int i = 0; i < rows; ++i) {
QModelIndex sourceIndex = sourceListModel->index(i, 0);
if (filterAcceptsRow(i, sourceIndex.parent())) {
const QString id = sourceListModel->data(sourceIndex, ModelList::IdRole).toString();
return sourceListModel->modelInfo(id);
}
}
return ModelInfo();
}
InstalledModels::InstalledModels(QObject *parent) InstalledModels::InstalledModels(QObject *parent)
: QSortFilterProxyModel(parent) : QSortFilterProxyModel(parent)
{ {
@ -153,7 +197,8 @@ bool InstalledModels::filterAcceptsRow(int sourceRow,
{ {
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool(); bool isInstalled = sourceModel()->data(index, ModelList::InstalledRole).toBool();
return isInstalled; bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool();
return isInstalled && showInGUI;
} }
int InstalledModels::count() const int InstalledModels::count() const
@ -178,8 +223,7 @@ bool DownloadableModels::filterAcceptsRow(int sourceRow,
bool withinLimit = sourceRow < (m_expanded ? sourceModel()->rowCount() : m_limit); bool withinLimit = sourceRow < (m_expanded ? sourceModel()->rowCount() : m_limit);
QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent); QModelIndex index = sourceModel()->index(sourceRow, 0, sourceParent);
bool isDownloadable = !sourceModel()->data(index, ModelList::DescriptionRole).toString().isEmpty(); bool isDownloadable = !sourceModel()->data(index, ModelList::DescriptionRole).toString().isEmpty();
bool showInGUI = !sourceModel()->data(index, ModelList::DisableGUIRole).toBool(); return withinLimit && isDownloadable;
return withinLimit && isDownloadable && showInGUI;
} }
int DownloadableModels::count() const int DownloadableModels::count() const
@ -210,10 +254,12 @@ ModelList *ModelList::globalInstance()
ModelList::ModelList() ModelList::ModelList()
: QAbstractListModel(nullptr) : QAbstractListModel(nullptr)
, m_embeddingModels(new EmbeddingModels(this))
, m_installedModels(new InstalledModels(this)) , m_installedModels(new InstalledModels(this))
, m_downloadableModels(new DownloadableModels(this)) , m_downloadableModels(new DownloadableModels(this))
, m_asyncModelRequestOngoing(false) , m_asyncModelRequestOngoing(false)
{ {
m_embeddingModels->setSourceModel(this);
m_installedModels->setSourceModel(this); m_installedModels->setSourceModel(this);
m_downloadableModels->setSourceModel(this); m_downloadableModels->setSourceModel(this);
m_watcher = new QFileSystemWatcher(this); m_watcher = new QFileSystemWatcher(this);
@ -280,6 +326,17 @@ const QList<QString> ModelList::userDefaultModelList() const
return models; return models;
} }
int ModelList::defaultEmbeddingModelIndex() const
{
QMutexLocker locker(&m_mutex);
for (int i = 0; i < m_models.size(); ++i) {
const ModelInfo *info = m_models.at(i);
const bool isEmbedding = info->filename() == "all-MiniLM-L6-v2-f16.gguf";
if (isEmbedding) return i;
}
return -1;
}
ModelInfo ModelList::defaultModelInfo() const ModelInfo ModelList::defaultModelInfo() const
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);

View File

@ -120,6 +120,24 @@ private:
}; };
Q_DECLARE_METATYPE(ModelInfo) Q_DECLARE_METATYPE(ModelInfo)
class EmbeddingModels : public QSortFilterProxyModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
public:
explicit EmbeddingModels(QObject *parent);
int count() const;
ModelInfo defaultModelInfo() const;
Q_SIGNALS:
void countChanged();
void defaultModelIndexChanged();
protected:
bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;
};
class InstalledModels : public QSortFilterProxyModel class InstalledModels : public QSortFilterProxyModel
{ {
Q_OBJECT Q_OBJECT
@ -165,6 +183,8 @@ class ModelList : public QAbstractListModel
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged) Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(int defaultEmbeddingModelIndex READ defaultEmbeddingModelIndex NOTIFY defaultEmbeddingModelIndexChanged)
Q_PROPERTY(EmbeddingModels* embeddingModels READ embeddingModels NOTIFY embeddingModelsChanged)
Q_PROPERTY(InstalledModels* installedModels READ installedModels NOTIFY installedModelsChanged) Q_PROPERTY(InstalledModels* installedModels READ installedModels NOTIFY installedModelsChanged)
Q_PROPERTY(DownloadableModels* downloadableModels READ downloadableModels NOTIFY downloadableModelsChanged) Q_PROPERTY(DownloadableModels* downloadableModels READ downloadableModels NOTIFY downloadableModelsChanged)
Q_PROPERTY(QList<QString> userDefaultModelList READ userDefaultModelList NOTIFY userDefaultModelListChanged) Q_PROPERTY(QList<QString> userDefaultModelList READ userDefaultModelList NOTIFY userDefaultModelListChanged)
@ -273,6 +293,7 @@ public:
Q_INVOKABLE QString clone(const ModelInfo &model); Q_INVOKABLE QString clone(const ModelInfo &model);
Q_INVOKABLE void remove(const ModelInfo &model); Q_INVOKABLE void remove(const ModelInfo &model);
ModelInfo defaultModelInfo() const; ModelInfo defaultModelInfo() const;
int defaultEmbeddingModelIndex() const;
void addModel(const QString &id); void addModel(const QString &id);
void changeId(const QString &oldId, const QString &newId); void changeId(const QString &oldId, const QString &newId);
@ -280,6 +301,7 @@ public:
const QList<ModelInfo> exportModelList() const; const QList<ModelInfo> exportModelList() const;
const QList<QString> userDefaultModelList() const; const QList<QString> userDefaultModelList() const;
EmbeddingModels *embeddingModels() const { return m_embeddingModels; }
InstalledModels *installedModels() const { return m_installedModels; } InstalledModels *installedModels() const { return m_installedModels; }
DownloadableModels *downloadableModels() const { return m_downloadableModels; } DownloadableModels *downloadableModels() const { return m_downloadableModels; }
@ -300,10 +322,12 @@ public:
Q_SIGNALS: Q_SIGNALS:
void countChanged(); void countChanged();
void embeddingModelsChanged();
void installedModelsChanged(); void installedModelsChanged();
void downloadableModelsChanged(); void downloadableModelsChanged();
void userDefaultModelListChanged(); void userDefaultModelListChanged();
void asyncModelRequestOngoingChanged(); void asyncModelRequestOngoingChanged();
void defaultEmbeddingModelIndexChanged();
private Q_SLOTS: private Q_SLOTS:
void updateModelsFromJson(); void updateModelsFromJson();
@ -326,6 +350,7 @@ private:
private: private:
mutable QMutex m_mutex; mutable QMutex m_mutex;
QNetworkAccessManager m_networkManager; QNetworkAccessManager m_networkManager;
EmbeddingModels *m_embeddingModels;
InstalledModels *m_installedModels; InstalledModels *m_installedModels;
DownloadableModels *m_downloadableModels; DownloadableModels *m_downloadableModels;
QList<ModelInfo*> m_models; QList<ModelInfo*> m_models;

View File

@ -21,7 +21,7 @@ MyDialog {
id: listLabel id: listLabel
anchors.top: parent.top anchors.top: parent.top
anchors.left: parent.left anchors.left: parent.left
text: "Available LocalDocs Collections:" text: qsTr("Local Documents:")
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
color: theme.textColor color: theme.textColor
} }
@ -63,17 +63,60 @@ MyDialog {
currentChat.removeCollection(collection) currentChat.removeCollection(collection)
} }
} }
ToolTip.text: qsTr("Warning: searching collections while indexing can return incomplete results")
ToolTip.visible: hovered && model.indexing
} }
Text { Text {
id: collectionId id: collectionId
anchors.verticalCenter: parent.verticalCenter anchors.verticalCenter: parent.verticalCenter
anchors.left: checkBox.right anchors.left: checkBox.right
anchors.margins: 20 anchors.margins: 20
anchors.leftMargin: 10
text: collection text: collection
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
elide: Text.ElideRight elide: Text.ElideRight
color: theme.textColor color: theme.textColor
} }
ProgressBar {
id: itemProgressBar
anchors.verticalCenter: parent.verticalCenter
anchors.left: collectionId.right
anchors.right: parent.right
anchors.margins: 20
anchors.leftMargin: 40
visible: model.indexing
value: (model.totalBytesToIndex - model.currentBytesToIndex) / model.totalBytesToIndex
background: Rectangle {
implicitHeight: 45
color: theme.backgroundDarkest
radius: 3
}
contentItem: Item {
implicitHeight: 40
Rectangle {
width: itemProgressBar.visualPosition * parent.width
height: parent.height
radius: 2
color: theme.assistantColor
}
}
Accessible.role: Accessible.ProgressBar
Accessible.name: qsTr("Indexing progressBar")
Accessible.description: qsTr("Shows the progress made in the indexing")
}
Label {
id: speedLabel
color: theme.textColor
visible: model.indexing
anchors.verticalCenter: itemProgressBar.verticalCenter
anchors.left: itemProgressBar.left
anchors.right: itemProgressBar.right
horizontalAlignment: Text.AlignHCenter
text: qsTr("indexing...")
elide: Text.ElideRight
font.pixelSize: theme.fontSizeLarge
}
} }
} }
} }

View File

@ -5,6 +5,7 @@ import QtQuick.Controls.Basic
import QtQuick.Layouts import QtQuick.Layouts
import QtQuick.Dialogs import QtQuick.Dialogs
import localdocs import localdocs
import modellist
import mysettings import mysettings
import network import network
@ -13,7 +14,11 @@ MySettingsTab {
MySettings.restoreLocalDocsDefaults(); MySettings.restoreLocalDocsDefaults();
} }
title: qsTr("LocalDocs Plugin (BETA)") property bool hasEmbeddingModel: ModelList.embeddingModels.count !== 0
showAdvancedSettingsButton: hasEmbeddingModel
showRestoreDefaultsButton: hasEmbeddingModel
title: qsTr("LocalDocs")
contentItem: ColumnLayout { contentItem: ColumnLayout {
id: root id: root
spacing: 10 spacing: 10
@ -21,7 +26,30 @@ MySettingsTab {
property alias collection: collection.text property alias collection: collection.text
property alias folder_path: folderEdit.text property alias folder_path: folderEdit.text
Label {
id: downloadLabel
Layout.fillWidth: true
Layout.maximumWidth: parent.width
wrapMode: Text.Wrap
visible: !hasEmbeddingModel
Layout.alignment: Qt.AlignLeft
text: qsTr("This feature requires the download of a text embedding model in order to index documents for later search. Please download the <b>SBert</a> text embedding model from the download dialog to proceed.")
font.pixelSize: theme.fontSizeLarger
}
MyButton {
visible: !hasEmbeddingModel
Layout.topMargin: 20
Layout.alignment: Qt.AlignLeft
text: qsTr("Download")
font.pixelSize: theme.fontSizeLarger
onClicked: {
downloadClicked()
}
}
Item { Item {
visible: hasEmbeddingModel
Layout.fillWidth: true Layout.fillWidth: true
height: row.height height: row.height
RowLayout { RowLayout {
@ -106,6 +134,7 @@ MySettingsTab {
} }
ColumnLayout { ColumnLayout {
visible: hasEmbeddingModel
spacing: 0 spacing: 0
Repeater { Repeater {
model: LocalDocs.localDocsModel model: LocalDocs.localDocsModel
@ -145,29 +174,25 @@ MySettingsTab {
anchors.right: parent.right anchors.right: parent.right
anchors.verticalCenter: parent.verticalCenter anchors.verticalCenter: parent.verticalCenter
anchors.margins: 20 anchors.margins: 20
width: Math.max(removeButton.width, busyIndicator.width) width: removeButton.width
height: Math.max(removeButton.height, busyIndicator.height) height:removeButton.height
MyButton { MyButton {
id: removeButton id: removeButton
anchors.centerIn: parent anchors.centerIn: parent
text: qsTr("Remove") text: qsTr("Remove")
visible: !item.removing && installed visible: !item.removing
onClicked: { onClicked: {
item.removing = true item.removing = true
LocalDocs.removeFolder(collection, folder_path) LocalDocs.removeFolder(collection, folder_path)
} }
} }
MyBusyIndicator {
id: busyIndicator
anchors.centerIn: parent
visible: item.removing || !installed
}
} }
} }
} }
} }
RowLayout { RowLayout {
visible: hasEmbeddingModel
Label { Label {
id: showReferencesLabel id: showReferencesLabel
text: qsTr("Show references:") text: qsTr("Show references:")
@ -186,6 +211,7 @@ MySettingsTab {
} }
Rectangle { Rectangle {
visible: hasEmbeddingModel
Layout.fillWidth: true Layout.fillWidth: true
height: 1 height: 1
color: theme.tabBorder color: theme.tabBorder
@ -196,6 +222,7 @@ MySettingsTab {
columns: 3 columns: 3
rowSpacing: 10 rowSpacing: 10
columnSpacing: 10 columnSpacing: 10
visible: hasEmbeddingModel
Rectangle { Rectangle {
Layout.row: 3 Layout.row: 3

View File

@ -16,9 +16,17 @@ MyDialog {
modal: true modal: true
closePolicy: Popup.CloseOnEscape | Popup.CloseOnPressOutside closePolicy: Popup.CloseOnEscape | Popup.CloseOnPressOutside
padding: 10 padding: 10
property bool showEmbeddingModels: false
onOpened: { onOpened: {
Network.sendModelDownloaderDialog(); Network.sendModelDownloaderDialog();
if (showEmbeddingModels) {
ModelList.downloadableModels.expanded = true
var targetModelIndex = ModelList.defaultEmbeddingModelIndex
console.log("targetModelIndex " + targetModelIndex)
modelListView.positionViewAtIndex(targetModelIndex, ListView.Contain);
}
} }
PopupDialog { PopupDialog {

View File

@ -9,8 +9,11 @@ Item {
property string title: "" property string title: ""
property Item contentItem: null property Item contentItem: null
property Item advancedSettings: null property Item advancedSettings: null
property bool showAdvancedSettingsButton: true
property bool showRestoreDefaultsButton: true
property var openFolderDialog property var openFolderDialog
signal restoreDefaultsClicked signal restoreDefaultsClicked
signal downloadClicked
onContentItemChanged: function() { onContentItemChanged: function() {
if (contentItem) { if (contentItem) {
@ -64,6 +67,7 @@ Item {
MyButton { MyButton {
id: restoreDefaultsButton id: restoreDefaultsButton
anchors.left: parent.left anchors.left: parent.left
visible: showRestoreDefaultsButton
width: implicitWidth width: implicitWidth
text: qsTr("Restore Defaults") text: qsTr("Restore Defaults")
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
@ -77,7 +81,7 @@ Item {
MyButton { MyButton {
id: advancedSettingsButton id: advancedSettingsButton
anchors.right: parent.right anchors.right: parent.right
visible: root.advancedSettings visible: root.advancedSettings && showAdvancedSettingsButton
width: implicitWidth width: implicitWidth
text: !advancedInner.visible ? qsTr("Advanced Settings") : qsTr("Hide Advanced Settings") text: !advancedInner.visible ? qsTr("Advanced Settings") : qsTr("Hide Advanced Settings")
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge

View File

@ -19,6 +19,8 @@ MyDialog {
Network.sendSettingsDialog(); Network.sendSettingsDialog();
} }
signal downloadClicked
Item { Item {
Accessible.role: Accessible.Dialog Accessible.role: Accessible.Dialog
Accessible.name: qsTr("Settings") Accessible.name: qsTr("Settings")
@ -28,13 +30,13 @@ MyDialog {
ListModel { ListModel {
id: stacksModel id: stacksModel
ListElement { ListElement {
title: "Models" title: qsTr("Models")
} }
ListElement { ListElement {
title: "Application" title: qsTr("Application")
} }
ListElement { ListElement {
title: "Plugins" title: qsTr("LocalDocs")
} }
} }
@ -107,9 +109,16 @@ MyDialog {
} }
MySettingsStack { MySettingsStack {
title: qsTr("LocalDocs Plugin (BETA) Settings") title: qsTr("Local Document Collections")
tabs: [ tabs: [
Component { LocalDocsSettings { } } Component {
LocalDocsSettings {
id: localDocsSettings
Component.onCompleted: {
localDocsSettings.downloadClicked.connect(settingsDialog.downloadClicked);
}
}
}
] ]
} }
} }