From b5380c9b7f6fe4eed0efa79c6dab44a4da697a1d Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 24 May 2023 14:49:43 -0400 Subject: [PATCH] Adds the collections to serialize and implement references for localdocs. --- gpt4all-chat/chat.cpp | 71 +++++++++++++++---- gpt4all-chat/chat.h | 5 +- gpt4all-chat/chatlistmodel.cpp | 2 +- gpt4all-chat/database.cpp | 125 +++++++++++++++++++++++---------- gpt4all-chat/database.h | 20 ++++-- gpt4all-chat/localdocs.cpp | 13 +--- gpt4all-chat/localdocs.h | 12 +--- 7 files changed, 171 insertions(+), 77 deletions(-) diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 1f80fe8e..f9ef8b5c 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -15,7 +15,6 @@ Chat::Chat(QObject *parent) , m_llmodel(new ChatLLM(this)) , m_isServer(false) , m_shouldDeleteLater(false) - , m_contextContainsLocalDocs(false) { connectLLM(); } @@ -31,7 +30,6 @@ Chat::Chat(bool isServer, QObject *parent) , m_llmodel(new Server(this)) , m_isServer(true) , m_shouldDeleteLater(false) - , m_contextContainsLocalDocs(false) { connectLLM(); } @@ -103,7 +101,8 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) { - m_contextContainsLocalDocs = false; + Q_ASSERT(m_results.isEmpty()); + m_results.clear(); // just in case, but the assert above is important m_responseInProgress = true; m_responseState = Chat::LocalDocsRetrieval; emit responseInProgressChanged(); @@ -116,18 +115,25 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t m_queuedPrompt.n_batch = n_batch; m_queuedPrompt.repeat_penalty = repeat_penalty; m_queuedPrompt.repeat_penalty_tokens = repeat_penalty_tokens; - LocalDocs::globalInstance()->requestRetrieve(m_collections, prompt); + LocalDocs::globalInstance()->requestRetrieve(m_id, m_collections, prompt); } -void Chat::handleLocalDocsRetrieved() +void Chat::handleLocalDocsRetrieved(const QString &uid, const QList &results) { + // If the uid doesn't match, then these are not our results + if (uid != m_id) + return; + + // Store our results locally + m_results = results; + + // Augment the prompt template with the results if any QList augmentedTemplate; - QList results = LocalDocs::globalInstance()->result(); - if (!results.isEmpty()) { + if (!m_results.isEmpty()) augmentedTemplate.append("### Context:"); - augmentedTemplate.append(results.join("\n\n")); - } - m_contextContainsLocalDocs = !results.isEmpty(); + for (const ResultInfo &info : m_results) + augmentedTemplate.append(info.text); + augmentedTemplate.append(m_queuedPrompt.prompt_template); emit promptRequested( m_queuedPrompt.prompt, @@ -191,13 +197,48 @@ void Chat::handleModelLoadedChanged() void Chat::promptProcessing() { - m_responseState = m_contextContainsLocalDocs ? Chat::LocalDocsProcessing : Chat::PromptProcessing; + m_responseState = !m_results.isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing; emit responseStateChanged(); } void Chat::responseStopped() { - m_contextContainsLocalDocs = false; + const QString chatResponse = response(); + QList finalResponse { chatResponse }; + int validReferenceNumber = 1; + for (const ResultInfo &info : m_results) { + if (info.file.isEmpty()) + continue; + if (validReferenceNumber == 1) + finalResponse.append(QStringLiteral("---")); + QString reference; + { + QTextStream stream(&reference); + stream << (validReferenceNumber++) << ". "; + if (!info.title.isEmpty()) + stream << "\"" << info.title << "\". "; + if (!info.author.isEmpty()) + stream << "By " << info.author << ". "; + if (!info.date.isEmpty()) + stream << "Date: " << info.date << ". "; + stream << "In " << info.file << ". "; + if (info.page != -1) + stream << "Page " << info.page << ". "; + if (info.from != -1) { + stream << "Lines " << info.from; + if (info.to != -1) + stream << "-" << info.to; + stream << ". "; + } + } + finalResponse.append(reference); + } + + const int index = m_chatModel->count() - 1; + m_chatModel->updateValue(index, finalResponse.join("\n")); + emit responseChanged(); + + m_results.clear(); m_responseInProgress = false; m_responseState = Chat::ResponseStopped; emit responseInProgressChanged(); @@ -301,6 +342,8 @@ bool Chat::serialize(QDataStream &stream, int version) const stream << m_name; stream << m_userName; stream << m_savedModelName; + if (version > 2) + stream << m_collections; if (!m_llmodel->serialize(stream, version)) return false; if (!m_chatModel->serialize(stream, version)) @@ -321,6 +364,10 @@ bool Chat::deserialize(QDataStream &stream, int version) // unfortunately, we cannot deserialize these if (version < 2 && m_savedModelName.contains("gpt4all-j")) return false; + if (version > 2) { + stream >> m_collections; + emit collectionListChanged(); + } m_llmodel->setModelName(m_savedModelName); if (!m_llmodel->deserialize(stream, version)) return false; diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index da8ae7b4..0e9e6d47 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -7,6 +7,7 @@ #include "chatllm.h" #include "chatmodel.h" +#include "database.h" #include "server.h" class Chat : public QObject @@ -114,7 +115,7 @@ Q_SIGNALS: void collectionListChanged(); private Q_SLOTS: - void handleLocalDocsRetrieved(); + void handleLocalDocsRetrieved(const QString &uid, const QList &results); void handleResponseChanged(); void handleModelLoadedChanged(); void promptProcessing(); @@ -141,6 +142,7 @@ private: QString m_userName; QString m_savedModelName; QList m_collections; + QList m_results; ChatModel *m_chatModel; bool m_responseInProgress; ResponseState m_responseState; @@ -148,7 +150,6 @@ private: ChatLLM *m_llmodel; bool m_isServer; bool m_shouldDeleteLater; - bool m_contextContainsLocalDocs; Prompt m_queuedPrompt; }; diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index f2cb0ad2..97cd018c 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -6,7 +6,7 @@ #include #define CHAT_FORMAT_MAGIC 0xF5D553CC -#define CHAT_FORMAT_VERSION 2 +#define CHAT_FORMAT_VERSION 3 ChatListModel::ChatListModel(QObject *parent) : QAbstractListModel(parent) diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index a9c6e001..6cd5657a 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -11,12 +11,14 @@ const auto INSERT_CHUNK_SQL = QLatin1String(R"( insert into chunks(document_id, chunk_id, chunk_text, - embedding_id, embedding_path) values(?, ?, ?, ?, ?); + file, title, author, subject, keywords, page, line_from, line_to, + embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); )"); const auto INSERT_CHUNK_FTS_SQL = QLatin1String(R"( insert into chunks_fts(document_id, chunk_id, chunk_text, - embedding_id, embedding_path) values(?, ?, ?, ?, ?); + file, title, author, subject, keywords, page, line_from, line_to, + embedding_id, embedding_path) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); )"); const auto DELETE_CHUNKS_SQL = QLatin1String(R"( @@ -29,16 +31,21 @@ const auto DELETE_CHUNKS_FTS_SQL = QLatin1String(R"( const auto CHUNKS_SQL = QLatin1String(R"( create table chunks(document_id integer, chunk_id integer, chunk_text varchar, + file varchar, title varchar, author varchar, subject varchar, keywords varchar, + page integer, line_from integer, line_to integer, embedding_id integer, embedding_path varchar); )"); const auto FTS_CHUNKS_SQL = QLatin1String(R"( create virtual table chunks_fts using fts5(document_id unindexed, chunk_id unindexed, chunk_text, + file, title, author, subject, keywords, page, line_from, line_to, embedding_id unindexed, embedding_path unindexed, tokenize="trigram"); )"); const auto SELECT_SQL = QLatin1String(R"( - select chunks_fts.rowid, chunks_fts.document_id, chunks_fts.chunk_text + select chunks_fts.rowid, documents.document_time, + chunks_fts.chunk_text, chunks_fts.file, chunks_fts.title, chunks_fts.author, chunks_fts.page, + chunks_fts.line_from, chunks_fts.line_to from chunks_fts join documents ON chunks_fts.document_id = documents.id join folders ON documents.folder_id = folders.id @@ -48,8 +55,10 @@ const auto SELECT_SQL = QLatin1String(R"( limit %2; )"); -bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_text, int embedding_id, - const QString &embedding_path) +bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_text, + const QString &file, const QString &title, const QString &author, const QString &subject, const QString &keywords, + int page, int from, int to, + int embedding_id, const QString &embedding_path) { { if (!q.prepare(INSERT_CHUNK_SQL)) @@ -57,6 +66,14 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_ q.addBindValue(document_id); q.addBindValue(chunk_id); q.addBindValue(chunk_text); + q.addBindValue(file); + q.addBindValue(title); + q.addBindValue(author); + q.addBindValue(subject); + q.addBindValue(keywords); + q.addBindValue(page); + q.addBindValue(from); + q.addBindValue(to); q.addBindValue(embedding_id); q.addBindValue(embedding_path); if (!q.exec()) @@ -68,6 +85,14 @@ bool addChunk(QSqlQuery &q, int document_id, int chunk_id, const QString &chunk_ q.addBindValue(document_id); q.addBindValue(chunk_id); q.addBindValue(chunk_text); + q.addBindValue(file); + q.addBindValue(title); + q.addBindValue(author); + q.addBindValue(subject); + q.addBindValue(keywords); + q.addBindValue(page); + q.addBindValue(from); + q.addBindValue(to); q.addBindValue(embedding_id); q.addBindValue(embedding_path); if (!q.exec()) @@ -145,19 +170,6 @@ bool selectChunk(QSqlQuery &q, const QList &collection_names, const QSt return true; } -void printResults(QSqlQuery &q) -{ - while (q.next()) { - int rowid = q.value(0).toInt(); - QString collection_name = q.value(1).toString(); - QString chunk_text = q.value(2).toString(); - - qDebug() << "rowid:" << rowid - << "collection_name:" << collection_name - << "chunk_text:" << chunk_text; - } -} - const auto INSERT_COLLECTION_SQL = QLatin1String(R"( insert into collections(collection_name, folder_id) values(?, ?); )"); @@ -457,10 +469,18 @@ QSqlError initDb() QString chunk_text1 = "This is an example chunk."; QString chunk_text2 = "Another example chunk."; QString embedding_path = "/example/embeddings/embedding1.bin"; + QString file = "document1.txt"; + QString title; + QString author; + QString subject; + QString keywords; + int page = -1; + int from = -1; + int to = -1;; int embedding_id = 1; - if (!addChunk(q, document_id, 1, chunk_text1, embedding_id, embedding_path) || - !addChunk(q, document_id, 2, chunk_text2, embedding_id, embedding_path)) { + if (!addChunk(q, document_id, 1, chunk_text1, file, title, author, subject, keywords, page, from, to, embedding_id, embedding_path) || + !addChunk(q, document_id, 2, chunk_text2, file, title, author, subject, keywords, page, from, to, embedding_id, embedding_path)) { qDebug() << "Error adding chunks:" << q.lastError().text(); return q.lastError(); } @@ -468,13 +488,10 @@ QSqlError initDb() // Perform a search QList collection_names = {collection_name}; QString search_text = "example"; - if (!selectChunk(q, collection_names, search_text)) { + if (!selectChunk(q, collection_names, search_text, 3)) { qDebug() << "Error selecting chunks:" << q.lastError().text(); return q.lastError(); } - - // Print the results - printResults(q); #endif return QSqlError(); @@ -499,10 +516,13 @@ void Database::handleDocumentErrorAndScheduleNext(const QString &errorMessage, QTimer::singleShot(0, this, &Database::scanQueue); } -void Database::chunkStream(QTextStream &stream, int document_id) +void Database::chunkStream(QTextStream &stream, int document_id, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, int page) { int chunk_id = 0; int charCount = 0; + int line_from = -1; + int line_to = -1; QList words; while (!stream.atEnd()) { @@ -517,6 +537,14 @@ void Database::chunkStream(QTextStream &stream, int document_id) document_id, ++chunk_id, chunk, + file, + title, + author, + subject, + keywords, + page, + line_from, + line_to, 0 /*embedding_id*/, QString() /*embedding_path*/ )) { @@ -604,13 +632,18 @@ void Database::scanQueue() document_id, document_path, q.lastError()); return; } - QString text; for (int i = 0; i < doc.pageCount(); ++i) { const QPdfSelection selection = doc.getAllText(i); - text.append(selection.text()); + QString text = selection.text(); + QTextStream stream(&text); + chunkStream(stream, document_id, info.doc.fileName(), + doc.metaData(QPdfDocument::MetaDataField::Title).toString(), + doc.metaData(QPdfDocument::MetaDataField::Author).toString(), + doc.metaData(QPdfDocument::MetaDataField::Subject).toString(), + doc.metaData(QPdfDocument::MetaDataField::Keywords).toString(), + i + 1 + ); } - QTextStream stream(&text); - chunkStream(stream, document_id); } else { QFile file(document_path); if (!file.open( QIODevice::ReadOnly)) { @@ -618,7 +651,8 @@ void Database::scanQueue() existing_id, document_path, q.lastError()); } QTextStream stream(&file); - chunkStream(stream, document_id); + chunkStream(stream, document_id, file.fileName(), QString() /*title*/, QString() /*author*/, + QString() /*subject*/, QString() /*keywords*/, -1 /*page*/); file.close(); } QSqlDatabase::database().commit(); @@ -867,7 +901,7 @@ bool Database::removeFolderFromWatch(const QString &path) return true; } -void Database::retrieveFromDB(const QList &collections, const QString &text, int retrievalSize) +void Database::retrieveFromDB(const QString &uid, const QList &collections, const QString &text, int retrievalSize) { #if defined(DEBUG) qDebug() << "retrieveFromDB" << collections << text << retrievalSize; @@ -879,20 +913,33 @@ void Database::retrieveFromDB(const QList &collections, const QString & return; } - QList results; + QList results; while (q.next()) { - int rowid = q.value(0).toInt(); - QString collection_name = q.value(1).toString(); - QString chunk_text = q.value(2).toString(); - results.append(chunk_text); + const int rowid = q.value(0).toInt(); + const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"); + const QString chunk_text = q.value(2).toString(); + const QString file = q.value(3).toString(); + const QString title = q.value(4).toString(); + const QString author = q.value(5).toString(); + const int page = q.value(6).toInt(); + const int from =q.value(7).toInt(); + const int to =q.value(8).toInt(); + ResultInfo info; + info.file = file; + info.title = title; + info.author = author; + info.date = date; + info.text = chunk_text; + info.page = page; + info.from = from; + info.to = to; + results.append(info); #if defined(DEBUG) qDebug() << "retrieve rowid:" << rowid - << "collection_name:" << collection_name << "chunk_text:" << chunk_text; #endif } - - emit retrieveResult(results); + emit retrieveResult(uid, results); } void Database::cleanDB() diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index 76d8bf57..2f25ff0d 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -14,6 +14,17 @@ struct DocumentInfo QFileInfo doc; }; +struct ResultInfo { + QString file; // [Required] The name of the file, but not the full path + QString title; // [Optional] The title of the document + QString author; // [Optional] The author of the document + QString date; // [Required] The creation or the last modification date whichever is latest + QString text; // [Required] The text actually used in the augmented context + int page = -1; // [Optional] The page where the text was found + int from = -1; // [Optional] The line number where the text begins + int to = -1; // [Optional] The line number where the text ends +}; + struct CollectionItem { QString collection; QString folder_path; @@ -32,13 +43,13 @@ public Q_SLOTS: void scanDocuments(int folder_id, const QString &folder_path); void addFolder(const QString &collection, const QString &path); void removeFolder(const QString &collection, const QString &path); - void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize); + void retrieveFromDB(const QString &uid, const QList &collections, const QString &text, int retrievalSize); void cleanDB(); void changeChunkSize(int chunkSize); Q_SIGNALS: void docsToScanChanged(); - void retrieveResult(const QList &result); + void retrieveResult(const QString &uid, const QList &result); void collectionListUpdated(const QList &collectionList); private Q_SLOTS: @@ -51,14 +62,15 @@ private Q_SLOTS: private: void removeFolderInternal(const QString &collection, int folder_id, const QString &path); - void chunkStream(QTextStream &stream, int document_id); + void chunkStream(QTextStream &stream, int document_id, const QString &file, + const QString &title, const QString &author, const QString &subject, const QString &keywords, int page); void handleDocumentErrorAndScheduleNext(const QString &errorMessage, int document_id, const QString &document_path, const QSqlError &error); private: int m_chunkSize; QQueue m_docsToScan; - QList m_retrieve; + QList m_retrieve; QThread m_dbThread; QFileSystemWatcher *m_watcher; }; diff --git a/gpt4all-chat/localdocs.cpp b/gpt4all-chat/localdocs.cpp index a7905306..6e62c0a2 100644 --- a/gpt4all-chat/localdocs.cpp +++ b/gpt4all-chat/localdocs.cpp @@ -29,7 +29,7 @@ LocalDocs::LocalDocs() connect(this, &LocalDocs::requestChunkSizeChange, m_database, &Database::changeChunkSize, Qt::QueuedConnection); connect(m_database, &Database::retrieveResult, this, - &LocalDocs::handleRetrieveResult, Qt::QueuedConnection); + &LocalDocs::receivedResult, Qt::QueuedConnection); connect(m_database, &Database::collectionListUpdated, m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection); } @@ -49,10 +49,9 @@ void LocalDocs::removeFolder(const QString &collection, const QString &path) emit requestRemoveFolder(collection, path); } -void LocalDocs::requestRetrieve(const QList &collections, const QString &text) +void LocalDocs::requestRetrieve(const QString &uid, const QList &collections, const QString &text) { - m_retrieveResult = QList(); - emit requestRetrieveFromDB(collections, text, m_retrievalSize); + emit requestRetrieveFromDB(uid, collections, text, m_retrievalSize); } int LocalDocs::chunkSize() const @@ -83,9 +82,3 @@ void LocalDocs::setRetrievalSize(int retrievalSize) m_retrievalSize = retrievalSize; emit retrievalSizeChanged(); } - -void LocalDocs::handleRetrieveResult(const QList &result) -{ - m_retrieveResult = result; - emit receivedResult(); -} diff --git a/gpt4all-chat/localdocs.h b/gpt4all-chat/localdocs.h index 9395655a..7011602c 100644 --- a/gpt4all-chat/localdocs.h +++ b/gpt4all-chat/localdocs.h @@ -20,9 +20,7 @@ public: Q_INVOKABLE void addFolder(const QString &collection, const QString &path); Q_INVOKABLE void removeFolder(const QString &collection, const QString &path); - - QList result() const { return m_retrieveResult; } - void requestRetrieve(const QList &collections, const QString &text); + void requestRetrieve(const QString &uid, const QList &collections, const QString &text); int chunkSize() const; void setChunkSize(int chunkSize); @@ -33,22 +31,18 @@ public: Q_SIGNALS: void requestAddFolder(const QString &collection, const QString &path); void requestRemoveFolder(const QString &collection, const QString &path); - void requestRetrieveFromDB(const QList &collections, const QString &text, int N); + void requestRetrieveFromDB(const QString &uid, const QList &collections, const QString &text, int retrievalSize); void requestChunkSizeChange(int chunkSize); - void receivedResult(); + void receivedResult(const QString &uid, const QList &result); void localDocsModelChanged(); void chunkSizeChanged(); void retrievalSizeChanged(); -private Q_SLOTS: - void handleRetrieveResult(const QList &result); - private: int m_chunkSize; int m_retrievalSize; LocalDocsModel *m_localDocsModel; Database *m_database; - QList m_retrieveResult; private: explicit LocalDocs();