Begin converting the localdocs to a tool.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-07-31 22:46:30 -04:00
parent dfe3e951d4
commit 01f67c74ea
11 changed files with 163 additions and 84 deletions

View File

@ -116,7 +116,7 @@ qt_add_executable(chat
database.h database.cpp database.h database.cpp
download.h download.cpp download.h download.cpp
embllm.cpp embllm.h embllm.cpp embllm.h
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocssearch.h localdocssearch.cpp
llm.h llm.cpp llm.h llm.cpp
modellist.h modellist.cpp modellist.h modellist.cpp
mysettings.h mysettings.cpp mysettings.h mysettings.cpp

View File

@ -35,10 +35,8 @@ QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
return worker.response(); return worker.response();
} }
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK) void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count)
{ {
m_topK = topK;
// Documentation on the brave web search: // Documentation on the brave web search:
// https://api.search.brave.com/app/documentation/web-search/get-started // https://api.search.brave.com/app/documentation/web-search/get-started
QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search"); QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");
@ -47,7 +45,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
//https://api.search.brave.com/app/documentation/web-search/query //https://api.search.brave.com/app/documentation/web-search/query
QUrlQuery urlQuery; QUrlQuery urlQuery;
urlQuery.addQueryItem("q", query); urlQuery.addQueryItem("q", query);
urlQuery.addQueryItem("count", QString::number(topK)); urlQuery.addQueryItem("count", QString::number(count));
urlQuery.addQueryItem("result_filter", "web"); urlQuery.addQueryItem("result_filter", "web");
urlQuery.addQueryItem("extra_snippets", "true"); urlQuery.addQueryItem("extra_snippets", "true");
jsonUrl.setQuery(urlQuery); jsonUrl.setQuery(urlQuery);
@ -64,7 +62,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred); connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
} }
static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1) static QString cleanBraveResponse(const QByteArray& jsonResponse)
{ {
// This parses the response from brave and formats it in json that conforms to the de facto // This parses the response from brave and formats it in json that conforms to the de facto
// standard in SourceExcerpts::fromJson(...) // standard in SourceExcerpts::fromJson(...)
@ -77,7 +75,6 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
QString query; QString query;
QJsonObject searchResponse = document.object(); QJsonObject searchResponse = document.object();
QJsonObject cleanResponse;
QJsonArray cleanArray; QJsonArray cleanArray;
if (searchResponse.contains("query")) { if (searchResponse.contains("query")) {
@ -99,7 +96,7 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
const int idx = m["index"].toInt(); const int idx = m["index"].toInt();
QJsonObject resultObj = resultsArray[idx].toObject(); QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description"}; QStringList selectedKeys = {"type", "title", "url"};
QJsonObject result; QJsonObject result;
for (const auto& key : selectedKeys) for (const auto& key : selectedKeys)
if (resultObj.contains(key)) if (resultObj.contains(key))
@ -107,6 +104,8 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
if (resultObj.contains("page_age")) if (resultObj.contains("page_age"))
result.insert("date", resultObj["page_age"]); result.insert("date", resultObj["page_age"]);
else
result.insert("date", QDate::currentDate().toString());
QJsonArray excerpts; QJsonArray excerpts;
if (resultObj.contains("extra_snippets")) { if (resultObj.contains("extra_snippets")) {
@ -117,12 +116,18 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
excerpt.insert("text", snippet); excerpt.insert("text", snippet);
excerpts.append(excerpt); excerpts.append(excerpt);
} }
if (resultObj.contains("description"))
result.insert("description", resultObj["description"]);
} else {
QJsonObject excerpt;
excerpt.insert("text", resultObj["description"]);
} }
result.insert("excerpts", excerpts); result.insert("excerpts", excerpts);
cleanArray.append(QJsonValue(result)); cleanArray.append(QJsonValue(result));
} }
} }
QJsonObject cleanResponse;
cleanResponse.insert("query", query); cleanResponse.insert("query", query);
cleanResponse.insert("results", cleanArray); cleanResponse.insert("results", cleanArray);
QJsonDocument cleanedDoc(cleanResponse); QJsonDocument cleanedDoc(cleanResponse);
@ -139,12 +144,13 @@ void BraveAPIWorker::handleFinished()
if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) { if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) {
QByteArray jsonData = jsonReply->readAll(); QByteArray jsonData = jsonReply->readAll();
jsonReply->deleteLater(); jsonReply->deleteLater();
m_response = cleanBraveResponse(jsonData, m_topK); m_response = cleanBraveResponse(jsonData);
} else { } else {
QByteArray jsonData = jsonReply->readAll(); QByteArray jsonData = jsonReply->readAll();
qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData; qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData;
jsonReply->deleteLater(); jsonReply->deleteLater();
} }
emit finished();
} }
void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code) void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)

View File

@ -1,7 +1,6 @@
#ifndef BRAVESEARCH_H #ifndef BRAVESEARCH_H
#define BRAVESEARCH_H #define BRAVESEARCH_H
#include "sourceexcerpt.h"
#include "tool.h" #include "tool.h"
#include <QObject> #include <QObject>
@ -14,14 +13,13 @@ class BraveAPIWorker : public QObject {
public: public:
BraveAPIWorker() BraveAPIWorker()
: QObject(nullptr) : QObject(nullptr)
, m_networkManager(nullptr) , m_networkManager(nullptr) {}
, m_topK(1) {}
virtual ~BraveAPIWorker() {} virtual ~BraveAPIWorker() {}
QString response() const { return m_response; } QString response() const { return m_response; }
public Q_SLOTS: public Q_SLOTS:
void request(const QString &apiKey, const QString &query, int topK); void request(const QString &apiKey, const QString &query, int count);
Q_SIGNALS: Q_SIGNALS:
void finished(); void finished();
@ -33,7 +31,6 @@ private Q_SLOTS:
private: private:
QNetworkAccessManager *m_networkManager; QNetworkAccessManager *m_networkManager;
QString m_response; QString m_response;
int m_topK;
}; };
class BraveSearch : public Tool { class BraveSearch : public Tool {

View File

@ -3,7 +3,7 @@
#include "bravesearch.h" #include "bravesearch.h"
#include "chat.h" #include "chat.h"
#include "chatapi.h" #include "chatapi.h"
#include "localdocs.h" #include "localdocssearch.h"
#include "mysettings.h" #include "mysettings.h"
#include "network.h" #include "network.h"
@ -13,6 +13,7 @@
#include <QGlobalStatic> #include <QGlobalStatic>
#include <QGuiApplication> #include <QGuiApplication>
#include <QIODevice> #include <QIODevice>
#include <QJsonArray>
#include <QJsonDocument> #include <QJsonDocument>
#include <QJsonObject> #include <QJsonObject>
#include <QMutex> #include <QMutex>
@ -128,11 +129,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged); connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
Qt::BlockingQueuedConnection);
m_llmThread.setObjectName(parent->id()); m_llmThread.setObjectName(parent->id());
m_llmThread.start(); m_llmThread.start();
} }
@ -767,21 +763,33 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
if (!isModelLoaded()) if (!isModelLoaded())
return false; return false;
QList<SourceExcerpt> databaseResults; QList<SourceExcerpt> localDocsExcerpts;
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
if (!collectionList.isEmpty() && !isToolCallResponse) { if (!collectionList.isEmpty() && !isToolCallResponse) {
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks LocalDocsSearch localdocs;
emit sourceExcerptsChanged(databaseResults); QJsonObject parameters;
parameters.insert("text", prompt);
parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize());
parameters.insert("collections", QJsonArray::fromStringList(collectionList));
// FIXME: This has to handle errors of the tool call
const QString localDocsResponse = localdocs.run(parameters, 2000 /*msecs to timeout*/);
QString parseError;
localDocsExcerpts = SourceExcerpt::fromJson(localDocsResponse, parseError);
if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for localdocs response:" << parseError;
} else if (!localDocsExcerpts.isEmpty()) {
emit sourceExcerptsChanged(localDocsExcerpts);
}
} }
// Augment the prompt template with the results if any // Augment the prompt template with the results if any
QString docsContext; QString docsContext;
if (!databaseResults.isEmpty()) { if (!localDocsExcerpts.isEmpty()) {
// FIXME(adam): we should be using the new tool template if available otherwise this I guess
QStringList results; QStringList results;
for (const SourceExcerpt &info : databaseResults) for (const SourceExcerpt &info : localDocsExcerpts)
results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text);
// FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template
docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n"));
} }
@ -887,7 +895,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
QString parseError; QString parseError;
QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError); QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError);
if (!parseError.isEmpty()) { if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError; qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
} else if (!sourceExcerpts.isEmpty()) { } else if (!sourceExcerpts.isEmpty()) {
emit sourceExcerptsChanged(sourceExcerpts); emit sourceExcerptsChanged(sourceExcerpts);
} }
@ -912,7 +920,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
} }
SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!databaseResults.isEmpty() || isToolCallResponse))) if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
generateQuestions(elapsed); generateQuestions(elapsed);
else else
emit responseStopped(elapsed); emit responseStopped(elapsed);

View File

@ -189,7 +189,6 @@ Q_SIGNALS:
void shouldBeLoadedChanged(); void shouldBeLoadedChanged();
void trySwitchContextRequested(const ModelInfo &modelInfo); void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(int value); void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<SourceExcerpt> *results);
void reportSpeed(const QString &speed); void reportSpeed(const QString &speed);
void reportDevice(const QString &device); void reportDevice(const QString &device);
void reportFallbackReason(const QString &fallbackReason); void reportFallbackReason(const QString &fallbackReason);

View File

@ -1938,7 +1938,7 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
} }
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
QList<SourceExcerpt> *results) QString &jsonResult)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "retrieveFromDB" << collections << text << retrievalSize; qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
@ -1960,37 +1960,49 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
return; return;
} }
QMap<QString, QJsonObject> results;
while (q.next()) { while (q.next()) {
#if defined(DEBUG) #if defined(DEBUG)
const int rowid = q.value(0).toInt(); const int rowid = q.value(0).toInt();
#endif #endif
const QString document_path = q.value(2).toString();
const QString chunk_text = q.value(3).toString();
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
const QString file = q.value(4).toString(); const QString file = q.value(4).toString();
const QString title = q.value(5).toString(); QJsonObject resultObject = results.value(file);
const QString author = q.value(6).toString(); resultObject.insert("file", file);
const int page = q.value(7).toInt(); resultObject.insert("path", q.value(2).toString());
const int from = q.value(8).toInt(); resultObject.insert("date", QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"));
const int to = q.value(9).toInt(); resultObject.insert("title", q.value(5).toString());
const QString collectionName = q.value(10).toString(); resultObject.insert("author", q.value(6).toString());
SourceExcerpt info; resultObject.insert("collection", q.value(10).toString());
info.collection = collectionName;
info.path = document_path; QJsonArray excerpts;
info.file = file; if (resultObject.contains("excerpts"))
info.title = title; excerpts = resultObject["excerpts"].toArray();
info.author = author;
info.date = date; QJsonObject excerptObject;
info.text = chunk_text; excerptObject.insert("text", q.value(3).toString());
info.page = page; excerptObject.insert("page", q.value(7).toInt());
info.from = from; excerptObject.insert("from", q.value(8).toInt());
info.to = to; excerptObject.insert("to", q.value(9).toInt());
results->append(info); excerpts.append(excerptObject);
resultObject.insert("excerpts", excerpts);
results.insert(file, resultObject);
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "retrieve rowid:" << rowid qDebug() << "retrieve rowid:" << rowid
<< "chunk_text:" << chunk_text; << "chunk_text:" << chunk_text;
#endif #endif
} }
QJsonArray resultsArray;
QList<QJsonObject> resultsList = results.values();
for (const QJsonObject &result : resultsList)
resultsArray.append(QJsonValue(result));
QJsonObject response;
response.insert("results", resultsArray);
QJsonDocument document(response);
// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
jsonResult = document.toJson(QJsonDocument::Compact);
} }
// FIXME This is very slow and non-interruptible and when we close the application and we're // FIXME This is very slow and non-interruptible and when we close the application and we're

View File

@ -101,7 +101,7 @@ public Q_SLOTS:
void forceRebuildFolder(const QString &path); void forceRebuildFolder(const QString &path);
bool addFolder(const QString &collection, const QString &path, const QString &embedding_model); bool addFolder(const QString &collection, const QString &path, const QString &embedding_model);
void removeFolder(const QString &collection, const QString &path); void removeFolder(const QString &collection, const QString &path);
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<SourceExcerpt> *results); void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QString &jsonResult);
void changeChunkSize(int chunkSize); void changeChunkSize(int chunkSize);
void changeFileExtensions(const QStringList &extensions); void changeFileExtensions(const QStringList &extensions);
@ -168,7 +168,6 @@ private:
QStringList m_scannedFileExtensions; QStringList m_scannedFileExtensions;
QTimer *m_scanTimer; QTimer *m_scanTimer;
QMap<int, QQueue<DocumentInfo>> m_docsToScan; QMap<int, QQueue<DocumentInfo>> m_docsToScan;
QList<SourceExcerpt> m_retrieve;
QThread m_dbThread; QThread m_dbThread;
QFileSystemWatcher *m_watcher; QFileSystemWatcher *m_watcher;
QSet<QString> m_watchedPaths; QSet<QString> m_watchedPaths;

View File

@ -0,0 +1,50 @@
#include "localdocssearch.h"
#include "database.h"
#include "localdocs.h"
#include <QCoreApplication>
#include <QDebug>
#include <QGuiApplication>
#include <QJsonArray>
#include <QJsonObject>
#include <QThread>
using namespace Qt::Literals::StringLiterals;
/// Runs a localdocs retrieval as a tool call and returns the worker's response.
///
/// Reads three entries from @p parameters: "collections" (an array whose
/// elements are converted to strings), "text" and "count". The lookup itself
/// is carried out by a LocalDocsWorker that is moved onto a temporary thread.
///
/// @param parameters JSON object carrying the entries listed above.
/// @param timeout    Value handed to the first QThread::wait() call before the
///                   thread is asked to quit.
/// @return The string the worker holds as its response once the thread has ended.
QString LocalDocsSearch::run(const QJsonObject &parameters, qint64 timeout)
{
    // Unpack the tool arguments up front so the lambda below captures plain values.
    const QJsonArray requested = parameters["collections"].toArray();
    QList<QString> collections;
    for (const auto &entry : requested)
        collections.append(entry.toString());

    const QString query = parameters["text"].toString();
    const int limit = parameters["count"].toInt();

    QThread thread;
    LocalDocsWorker worker;
    worker.moveToThread(&thread);

    // The worker announcing completion asks the thread to quit.
    connect(&worker, &LocalDocsWorker::finished, &thread, &QThread::quit, Qt::DirectConnection);
    // Kick off the request as soon as the thread reports that it has started.
    connect(&thread, &QThread::started, [&worker, collections, query, limit]() {
        worker.request(collections, query, limit);
    });

    thread.start();
    thread.wait(timeout);
    thread.quit();
    // NOTE(review): this second wait() is given no deadline, so a request that is
    // still running after `timeout` presumably keeps this call blocked - confirm
    // that this is intended.
    thread.wait();

    return worker.response();
}
/// Creates a parentless worker and routes its retrieval requests to the
/// database object exposed by the global LocalDocs instance.
LocalDocsWorker::LocalDocsWorker()
    : QObject(nullptr)
{
    auto *database = LocalDocs::globalInstance()->database();

    // Connected as blocking-queued: emitting requestRetrieveFromDB holds up the
    // emitting thread until Database::retrieveFromDB has been run.
    connect(this, &LocalDocsWorker::requestRetrieveFromDB,
            database, &Database::retrieveFromDB,
            Qt::BlockingQueuedConnection);
}
void LocalDocsWorker::request(const QList<QString> &collections, const QString &text, int count)
{
QString jsonResult;
emit requestRetrieveFromDB(collections, text, count, jsonResult); // blocks
m_response = jsonResult;
emit finished();
}

View File

@ -0,0 +1,36 @@
#ifndef LOCALDOCSSEARCH_H
#define LOCALDOCSSEARCH_H
#include "tool.h"
#include <QObject>
#include <QString>
class LocalDocsWorker : public QObject {
Q_OBJECT
public:
LocalDocsWorker();
virtual ~LocalDocsWorker() {}
QString response() const { return m_response; }
void request(const QList<QString> &collections, const QString &text, int count);
Q_SIGNALS:
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int count, QString &jsonResponse);
void finished();
private:
QString m_response;
};
class LocalDocsSearch : public Tool {
Q_OBJECT
public:
LocalDocsSearch() : Tool() {}
virtual ~LocalDocsSearch() {}
QString run(const QJsonObject &parameters, qint64 timeout = 2000) override;
};
#endif // LOCALDOCSSEARCH_H

View File

@ -77,19 +77,7 @@ public:
static QList<SourceExcerpt> fromJson(const QString &json, QString &errorString); static QList<SourceExcerpt> fromJson(const QString &json, QString &errorString);
bool operator==(const SourceExcerpt &other) const { bool operator==(const SourceExcerpt &other) const {
return date == other.date && return file == other.file || url == other.url;
text == other.text &&
collection == other.collection &&
path == other.path &&
file == other.file &&
url == other.url &&
favicon == other.favicon &&
title == other.title &&
author == other.author &&
description == other.description &&
page == other.page &&
from == other.from &&
to == other.to;
} }
bool operator!=(const SourceExcerpt &other) const { bool operator!=(const SourceExcerpt &other) const {
return !(*this == other); return !(*this == other);

View File

@ -1,8 +1,6 @@
#ifndef TOOL_H #ifndef TOOL_H
#define TOOL_H #define TOOL_H
#include "sourceexcerpt.h"
#include <QObject> #include <QObject>
#include <QJsonObject> #include <QJsonObject>
@ -70,18 +68,4 @@ public:
virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000) = 0; virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000) = 0;
}; };
//class BuiltinTool : public Tool {
// Q_OBJECT
//public:
// BuiltinTool() : Tool() {}
// virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000);
//};
//class LocalTool : public Tool {
// Q_OBJECT
//public:
// LocalTool() : Tool() {}
// virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000);
//};
#endif // TOOL_H #endif // TOOL_H