From 01f67c74ea16150ad11ade1354bfbedbb9defe71 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 31 Jul 2024 22:46:30 -0400 Subject: [PATCH] Begin converting the localdocs to a tool. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 2 +- gpt4all-chat/bravesearch.cpp | 22 ++++++++----- gpt4all-chat/bravesearch.h | 7 ++-- gpt4all-chat/chatllm.cpp | 40 ++++++++++++++--------- gpt4all-chat/chatllm.h | 1 - gpt4all-chat/database.cpp | 56 +++++++++++++++++++------------- gpt4all-chat/database.h | 3 +- gpt4all-chat/localdocssearch.cpp | 50 ++++++++++++++++++++++++++++ gpt4all-chat/localdocssearch.h | 36 ++++++++++++++++++++ gpt4all-chat/sourceexcerpt.h | 14 +------- gpt4all-chat/tool.h | 16 --------- 11 files changed, 163 insertions(+), 84 deletions(-) create mode 100644 gpt4all-chat/localdocssearch.cpp create mode 100644 gpt4all-chat/localdocssearch.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 7f0c5143..1ab4db20 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -116,7 +116,7 @@ qt_add_executable(chat database.h database.cpp download.h download.cpp embllm.cpp embllm.h - localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp + localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocssearch.h localdocssearch.cpp llm.h llm.cpp modellist.h modellist.cpp mysettings.h mysettings.cpp diff --git a/gpt4all-chat/bravesearch.cpp b/gpt4all-chat/bravesearch.cpp index 85e84971..e691810b 100644 --- a/gpt4all-chat/bravesearch.cpp +++ b/gpt4all-chat/bravesearch.cpp @@ -35,10 +35,8 @@ QString BraveSearch::run(const QJsonObject ¶meters, qint64 timeout) return worker.response(); } -void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK) +void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count) { - m_topK = topK; - // Documentation on the brave web search: // https://api.search.brave.com/app/documentation/web-search/get-started QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search"); @@ -47,7 +45,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to //https://api.search.brave.com/app/documentation/web-search/query QUrlQuery urlQuery; urlQuery.addQueryItem("q", query); - urlQuery.addQueryItem("count", QString::number(topK)); + urlQuery.addQueryItem("count", QString::number(count)); urlQuery.addQueryItem("result_filter", "web"); urlQuery.addQueryItem("extra_snippets", "true"); jsonUrl.setQuery(urlQuery); @@ -64,7 +62,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred); } -static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1) +static QString cleanBraveResponse(const QByteArray& jsonResponse) { // This parses the response from brave and formats it in json that conforms to the de facto // standard in SourceExcerpts::fromJson(...) @@ -77,7 +75,6 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK QString query; QJsonObject searchResponse = document.object(); - QJsonObject cleanResponse; QJsonArray cleanArray; if (searchResponse.contains("query")) { @@ -99,7 +96,7 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK const int idx = m["index"].toInt(); QJsonObject resultObj = resultsArray[idx].toObject(); - QStringList selectedKeys = {"type", "title", "url", "description"}; + QStringList selectedKeys = {"type", "title", "url"}; QJsonObject result; for (const auto& key : selectedKeys) if (resultObj.contains(key)) @@ -107,6 +104,8 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK if (resultObj.contains("page_age")) result.insert("date", resultObj["page_age"]); + else + result.insert("date", QDate::currentDate().toString()); QJsonArray excerpts; if (resultObj.contains("extra_snippets")) { @@ -117,12 +116,18 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK excerpt.insert("text", snippet); excerpts.append(excerpt); } + if (resultObj.contains("description")) + result.insert("description", resultObj["description"]); + } else { + QJsonObject excerpt; + excerpt.insert("text", resultObj["description"]); } result.insert("excerpts", excerpts); cleanArray.append(QJsonValue(result)); } } + QJsonObject cleanResponse; cleanResponse.insert("query", query); cleanResponse.insert("results", cleanArray); QJsonDocument cleanedDoc(cleanResponse); @@ -139,12 +144,13 @@ void BraveAPIWorker::handleFinished() if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) { QByteArray jsonData = jsonReply->readAll(); jsonReply->deleteLater(); - m_response = cleanBraveResponse(jsonData, m_topK); + m_response = cleanBraveResponse(jsonData); } else { QByteArray jsonData = jsonReply->readAll(); qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData; jsonReply->deleteLater(); } + emit finished(); } void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code) diff --git a/gpt4all-chat/bravesearch.h b/gpt4all-chat/bravesearch.h index 6f617c4a..28f84b15 100644 --- a/gpt4all-chat/bravesearch.h +++ b/gpt4all-chat/bravesearch.h @@ -1,7 +1,6 @@ #ifndef BRAVESEARCH_H #define BRAVESEARCH_H -#include "sourceexcerpt.h" #include "tool.h" #include @@ -14,14 +13,13 @@ class BraveAPIWorker : public QObject { public: BraveAPIWorker() : QObject(nullptr) - , m_networkManager(nullptr) - , m_topK(1) {} + , m_networkManager(nullptr) {} virtual ~BraveAPIWorker() {} QString response() const { return m_response; } public Q_SLOTS: - void request(const QString &apiKey, const QString &query, int topK); + void request(const QString &apiKey, const QString &query, int count); Q_SIGNALS: void finished(); @@ -33,7 +31,6 @@ private Q_SLOTS: private: QNetworkAccessManager *m_networkManager; QString m_response; - int m_topK; }; class BraveSearch : public Tool { diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 5c4f168c..31b6e92f 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -3,7 +3,7 @@ #include "bravesearch.h" #include "chat.h" #include "chatapi.h" -#include "localdocs.h" +#include "localdocssearch.h" #include "mysettings.h" #include "network.h" @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -128,11 +129,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged); - - // The following are blocking operations and will block the llm thread - connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, - Qt::BlockingQueuedConnection); - m_llmThread.setObjectName(parent->id()); m_llmThread.start(); } @@ -767,21 +763,33 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString if (!isModelLoaded()) return false; - QList databaseResults; - const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); + QList localDocsExcerpts; if (!collectionList.isEmpty() && !isToolCallResponse) { - emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks - emit sourceExcerptsChanged(databaseResults); + LocalDocsSearch localdocs; + QJsonObject parameters; + parameters.insert("text", prompt); + parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize()); + parameters.insert("collections", QJsonArray::fromStringList(collectionList)); + + // FIXME: This has to handle errors of the tool call + const QString localDocsResponse = localdocs.run(parameters, 2000 /*msecs to timeout*/); + + QString parseError; + localDocsExcerpts = SourceExcerpt::fromJson(localDocsResponse, parseError); + if (!parseError.isEmpty()) { + qWarning() << "ERROR: Could not parse source excerpts for localdocs response:" << parseError; + } else if (!localDocsExcerpts.isEmpty()) { + emit sourceExcerptsChanged(localDocsExcerpts); + } } // Augment the prompt template with the results if any QString docsContext; - if (!databaseResults.isEmpty()) { + if (!localDocsExcerpts.isEmpty()) { + // FIXME(adam): we should be using the new tool template if available otherwise this I guess QStringList results; - for (const SourceExcerpt &info : databaseResults) + for (const SourceExcerpt &info : localDocsExcerpts) results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); - - // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); } @@ -887,7 +895,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString QString parseError; QList sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError); if (!parseError.isEmpty()) { - qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError; + qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError; } else if (!sourceExcerpts.isEmpty()) { emit sourceExcerptsChanged(sourceExcerpts); } @@ -912,7 +920,7 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString } SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); - if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!databaseResults.isEmpty() || isToolCallResponse))) + if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse))) generateQuestions(elapsed); else emit responseStopped(elapsed); diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 050c0b7d..feacd744 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -189,7 +189,6 @@ Q_SIGNALS: void shouldBeLoadedChanged(); void trySwitchContextRequested(const ModelInfo &modelInfo); void trySwitchContextOfLoadedModelCompleted(int value); - void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); void reportSpeed(const QString &speed); void reportDevice(const QString &device); void reportFallbackReason(const QString &fallbackReason); diff --git a/gpt4all-chat/database.cpp b/gpt4all-chat/database.cpp index f2fc5094..0713fc9f 100644 --- a/gpt4all-chat/database.cpp +++ b/gpt4all-chat/database.cpp @@ -1938,7 +1938,7 @@ QList Database::searchEmbeddings(const std::vector &query, const QLi } void Database::retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, - QList *results) + QString &jsonResult) { #if defined(DEBUG) qDebug() << "retrieveFromDB" << collections << text << retrievalSize; @@ -1960,37 +1960,49 @@ void Database::retrieveFromDB(const QList &collections, const QString & return; } + QMap results; while (q.next()) { #if defined(DEBUG) const int rowid = q.value(0).toInt(); #endif - const QString document_path = q.value(2).toString(); - const QString chunk_text = q.value(3).toString(); - const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"); const QString file = q.value(4).toString(); - const QString title = q.value(5).toString(); - const QString author = q.value(6).toString(); - const int page = q.value(7).toInt(); - const int from = q.value(8).toInt(); - const int to = q.value(9).toInt(); - const QString collectionName = q.value(10).toString(); - SourceExcerpt info; - info.collection = collectionName; - info.path = document_path; - info.file = file; - info.title = title; - info.author = author; - info.date = date; - info.text = chunk_text; - info.page = page; - info.from = from; - info.to = to; - results->append(info); + QJsonObject resultObject = results.value(file); + resultObject.insert("file", file); + resultObject.insert("path", q.value(2).toString()); + resultObject.insert("date", QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd")); + resultObject.insert("title", q.value(5).toString()); + resultObject.insert("author", q.value(6).toString()); + resultObject.insert("collection", q.value(10).toString()); + + QJsonArray excerpts; + if (resultObject.contains("excerpts")) + excerpts = resultObject["excerpts"].toArray(); + + QJsonObject excerptObject; + excerptObject.insert("text", q.value(3).toString()); + excerptObject.insert("page", q.value(7).toInt()); + excerptObject.insert("from", q.value(8).toInt()); + excerptObject.insert("to", q.value(9).toInt()); + excerpts.append(excerptObject); + resultObject.insert("excerpts", excerpts); + results.insert(file, resultObject); + #if defined(DEBUG) qDebug() << "retrieve rowid:" << rowid << "chunk_text:" << chunk_text; #endif } + + QJsonArray resultsArray; + QList resultsList = results.values(); + for (const QJsonObject &result : resultsList) + resultsArray.append(QJsonValue(result)); + + QJsonObject response; + response.insert("results", resultsArray); + QJsonDocument document(response); +// qDebug().noquote() << document.toJson(QJsonDocument::Indented); + jsonResult = document.toJson(QJsonDocument::Compact); } // FIXME This is very slow and non-interruptible and when we close the application and we're diff --git a/gpt4all-chat/database.h b/gpt4all-chat/database.h index fd0a78d9..a2e7cf22 100644 --- a/gpt4all-chat/database.h +++ b/gpt4all-chat/database.h @@ -101,7 +101,7 @@ public Q_SLOTS: void forceRebuildFolder(const QString &path); bool addFolder(const QString &collection, const QString &path, const QString &embedding_model); void removeFolder(const QString &collection, const QString &path); - void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); + void retrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QString &jsonResult); void changeChunkSize(int chunkSize); void changeFileExtensions(const QStringList &extensions); @@ -168,7 +168,6 @@ private: QStringList m_scannedFileExtensions; QTimer *m_scanTimer; QMap> m_docsToScan; - QList m_retrieve; QThread m_dbThread; QFileSystemWatcher *m_watcher; QSet m_watchedPaths; diff --git a/gpt4all-chat/localdocssearch.cpp b/gpt4all-chat/localdocssearch.cpp new file mode 100644 index 00000000..77becf4e --- /dev/null +++ b/gpt4all-chat/localdocssearch.cpp @@ -0,0 +1,50 @@ +#include "localdocssearch.h" +#include "database.h" +#include "localdocs.h" + +#include +#include +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + +QString LocalDocsSearch::run(const QJsonObject ¶meters, qint64 timeout) +{ + QList collections; + QJsonArray collectionsArray = parameters["collections"].toArray(); + for (int i = 0; i < collectionsArray.size(); ++i) + collections.append(collectionsArray[i].toString()); + const QString text = parameters["text"].toString(); + const int count = parameters["count"].toInt(); + QThread workerThread; + LocalDocsWorker worker; + worker.moveToThread(&workerThread); + connect(&worker, &LocalDocsWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); + connect(&workerThread, &QThread::started, [&worker, collections, text, count]() { + worker.request(collections, text, count); + }); + workerThread.start(); + workerThread.wait(timeout); + workerThread.quit(); + workerThread.wait(); + return worker.response(); +} + +LocalDocsWorker::LocalDocsWorker() + : QObject(nullptr) +{ + // The following are blocking operations and will block the calling thread + connect(this, &LocalDocsWorker::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), + &Database::retrieveFromDB, Qt::BlockingQueuedConnection); +} + +void LocalDocsWorker::request(const QList &collections, const QString &text, int count) +{ + QString jsonResult; + emit requestRetrieveFromDB(collections, text, count, jsonResult); // blocks + m_response = jsonResult; + emit finished(); +} diff --git a/gpt4all-chat/localdocssearch.h b/gpt4all-chat/localdocssearch.h new file mode 100644 index 00000000..4f757a52 --- /dev/null +++ b/gpt4all-chat/localdocssearch.h @@ -0,0 +1,36 @@ +#ifndef LOCALDOCSSEARCH_H +#define LOCALDOCSSEARCH_H + +#include "tool.h" + +#include +#include + +class LocalDocsWorker : public QObject { + Q_OBJECT +public: + LocalDocsWorker(); + virtual ~LocalDocsWorker() {} + + QString response() const { return m_response; } + + void request(const QList &collections, const QString &text, int count); + +Q_SIGNALS: + void requestRetrieveFromDB(const QList &collections, const QString &text, int count, QString &jsonResponse); + void finished(); + +private: + QString m_response; +}; + +class LocalDocsSearch : public Tool { + Q_OBJECT +public: + LocalDocsSearch() : Tool() {} + virtual ~LocalDocsSearch() {} + + QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; +}; + +#endif // LOCALDOCSSEARCH_H diff --git a/gpt4all-chat/sourceexcerpt.h b/gpt4all-chat/sourceexcerpt.h index c66007f0..82769239 100644 --- a/gpt4all-chat/sourceexcerpt.h +++ b/gpt4all-chat/sourceexcerpt.h @@ -77,19 +77,7 @@ public: static QList fromJson(const QString &json, QString &errorString); bool operator==(const SourceExcerpt &other) const { - return date == other.date && - text == other.text && - collection == other.collection && - path == other.path && - file == other.file && - url == other.url && - favicon == other.favicon && - title == other.title && - author == other.author && - description == other.description && - page == other.page && - from == other.from && - to == other.to; + return file == other.file || url == other.url; } bool operator!=(const SourceExcerpt &other) const { return !(*this == other); diff --git a/gpt4all-chat/tool.h b/gpt4all-chat/tool.h index 240d2aa2..3dac88e7 100644 --- a/gpt4all-chat/tool.h +++ b/gpt4all-chat/tool.h @@ -1,8 +1,6 @@ #ifndef TOOL_H #define TOOL_H -#include "sourceexcerpt.h" - #include #include @@ -70,18 +68,4 @@ public: virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; }; -//class BuiltinTool : public Tool { -// Q_OBJECT -//public: -// BuiltinTool() : Tool() {} -// virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000); -//}; - -//class LocalTool : public Tool { -// Q_OBJECT -//public: -// LocalTool() : Tool() {} -// virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000); -//}; - #endif // TOOL_H