Begin converting the localdocs to a tool.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-07-31 22:46:30 -04:00
parent dfe3e951d4
commit 01f67c74ea
11 changed files with 163 additions and 84 deletions

View File

@ -116,7 +116,7 @@ qt_add_executable(chat
database.h database.cpp
download.h download.cpp
embllm.cpp embllm.h
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp
localdocs.h localdocs.cpp localdocsmodel.h localdocsmodel.cpp localdocssearch.h localdocssearch.cpp
llm.h llm.cpp
modellist.h modellist.cpp
mysettings.h mysettings.cpp

View File

@ -35,10 +35,8 @@ QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
return worker.response();
}
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK)
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int count)
{
m_topK = topK;
// Documentation on the brave web search:
// https://api.search.brave.com/app/documentation/web-search/get-started
QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");
@ -47,7 +45,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
//https://api.search.brave.com/app/documentation/web-search/query
QUrlQuery urlQuery;
urlQuery.addQueryItem("q", query);
urlQuery.addQueryItem("count", QString::number(topK));
urlQuery.addQueryItem("count", QString::number(count));
urlQuery.addQueryItem("result_filter", "web");
urlQuery.addQueryItem("extra_snippets", "true");
jsonUrl.setQuery(urlQuery);
@ -64,7 +62,7 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
}
static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
static QString cleanBraveResponse(const QByteArray& jsonResponse)
{
// This parses the response from brave and formats it in json that conforms to the de facto
// standard in SourceExcerpts::fromJson(...)
@ -77,7 +75,6 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
QString query;
QJsonObject searchResponse = document.object();
QJsonObject cleanResponse;
QJsonArray cleanArray;
if (searchResponse.contains("query")) {
@ -99,7 +96,7 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
const int idx = m["index"].toInt();
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description"};
QStringList selectedKeys = {"type", "title", "url"};
QJsonObject result;
for (const auto& key : selectedKeys)
if (resultObj.contains(key))
@ -107,6 +104,8 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
if (resultObj.contains("page_age"))
result.insert("date", resultObj["page_age"]);
else
result.insert("date", QDate::currentDate().toString());
QJsonArray excerpts;
if (resultObj.contains("extra_snippets")) {
@ -117,12 +116,18 @@ static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK
excerpt.insert("text", snippet);
excerpts.append(excerpt);
}
if (resultObj.contains("description"))
result.insert("description", resultObj["description"]);
} else {
QJsonObject excerpt;
excerpt.insert("text", resultObj["description"]);
}
result.insert("excerpts", excerpts);
cleanArray.append(QJsonValue(result));
}
}
QJsonObject cleanResponse;
cleanResponse.insert("query", query);
cleanResponse.insert("results", cleanArray);
QJsonDocument cleanedDoc(cleanResponse);
@ -139,12 +144,13 @@ void BraveAPIWorker::handleFinished()
if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) {
QByteArray jsonData = jsonReply->readAll();
jsonReply->deleteLater();
m_response = cleanBraveResponse(jsonData, m_topK);
m_response = cleanBraveResponse(jsonData);
} else {
QByteArray jsonData = jsonReply->readAll();
qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData;
jsonReply->deleteLater();
}
emit finished();
}
void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)

View File

@ -1,7 +1,6 @@
#ifndef BRAVESEARCH_H
#define BRAVESEARCH_H
#include "sourceexcerpt.h"
#include "tool.h"
#include <QObject>
@ -14,14 +13,13 @@ class BraveAPIWorker : public QObject {
public:
BraveAPIWorker()
: QObject(nullptr)
, m_networkManager(nullptr)
, m_topK(1) {}
, m_networkManager(nullptr) {}
virtual ~BraveAPIWorker() {}
QString response() const { return m_response; }
public Q_SLOTS:
void request(const QString &apiKey, const QString &query, int topK);
void request(const QString &apiKey, const QString &query, int count);
Q_SIGNALS:
void finished();
@ -33,7 +31,6 @@ private Q_SLOTS:
private:
QNetworkAccessManager *m_networkManager;
QString m_response;
int m_topK;
};
class BraveSearch : public Tool {

View File

@ -3,7 +3,7 @@
#include "bravesearch.h"
#include "chat.h"
#include "chatapi.h"
#include "localdocs.h"
#include "localdocssearch.h"
#include "mysettings.h"
#include "network.h"
@ -13,6 +13,7 @@
#include <QGlobalStatic>
#include <QGuiApplication>
#include <QIODevice>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QMutex>
@ -128,11 +129,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
Qt::BlockingQueuedConnection);
m_llmThread.setObjectName(parent->id());
m_llmThread.start();
}
@ -767,21 +763,33 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
if (!isModelLoaded())
return false;
QList<SourceExcerpt> databaseResults;
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
QList<SourceExcerpt> localDocsExcerpts;
if (!collectionList.isEmpty() && !isToolCallResponse) {
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
emit sourceExcerptsChanged(databaseResults);
LocalDocsSearch localdocs;
QJsonObject parameters;
parameters.insert("text", prompt);
parameters.insert("count", MySettings::globalInstance()->localDocsRetrievalSize());
parameters.insert("collections", QJsonArray::fromStringList(collectionList));
// FIXME: This has to handle errors of the tool call
const QString localDocsResponse = localdocs.run(parameters, 2000 /*msecs to timeout*/);
QString parseError;
localDocsExcerpts = SourceExcerpt::fromJson(localDocsResponse, parseError);
if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for localdocs response:" << parseError;
} else if (!localDocsExcerpts.isEmpty()) {
emit sourceExcerptsChanged(localDocsExcerpts);
}
}
// Augment the prompt template with the results if any
QString docsContext;
if (!databaseResults.isEmpty()) {
if (!localDocsExcerpts.isEmpty()) {
// FIXME(adam): we should be using the new tool template if available otherwise this I guess
QStringList results;
for (const SourceExcerpt &info : databaseResults)
for (const SourceExcerpt &info : localDocsExcerpts)
results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text);
// FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template
docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n"));
}
@ -887,7 +895,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
QString parseError;
QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError);
if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError;
qWarning() << "ERROR: Could not parse source excerpts for brave response:" << parseError;
} else if (!sourceExcerpts.isEmpty()) {
emit sourceExcerptsChanged(sourceExcerpts);
}
@ -912,7 +920,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
}
SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!databaseResults.isEmpty() || isToolCallResponse)))
if (mode == SuggestionMode::On || (mode == SuggestionMode::SourceExcerptsOnly && (!localDocsExcerpts.isEmpty() || isToolCallResponse)))
generateQuestions(elapsed);
else
emit responseStopped(elapsed);

View File

@ -189,7 +189,6 @@ Q_SIGNALS:
void shouldBeLoadedChanged();
void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<SourceExcerpt> *results);
void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
void reportFallbackReason(const QString &fallbackReason);

View File

@ -1938,7 +1938,7 @@ QList<int> Database::searchEmbeddings(const std::vector<float> &query, const QLi
}
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
QList<SourceExcerpt> *results)
QString &jsonResult)
{
#if defined(DEBUG)
qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
@ -1960,37 +1960,49 @@ void Database::retrieveFromDB(const QList<QString> &collections, const QString &
return;
}
QMap<QString, QJsonObject> results;
while (q.next()) {
#if defined(DEBUG)
const int rowid = q.value(0).toInt();
#endif
const QString document_path = q.value(2).toString();
const QString chunk_text = q.value(3).toString();
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
const QString file = q.value(4).toString();
const QString title = q.value(5).toString();
const QString author = q.value(6).toString();
const int page = q.value(7).toInt();
const int from = q.value(8).toInt();
const int to = q.value(9).toInt();
const QString collectionName = q.value(10).toString();
SourceExcerpt info;
info.collection = collectionName;
info.path = document_path;
info.file = file;
info.title = title;
info.author = author;
info.date = date;
info.text = chunk_text;
info.page = page;
info.from = from;
info.to = to;
results->append(info);
QJsonObject resultObject = results.value(file);
resultObject.insert("file", file);
resultObject.insert("path", q.value(2).toString());
resultObject.insert("date", QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd"));
resultObject.insert("title", q.value(5).toString());
resultObject.insert("author", q.value(6).toString());
resultObject.insert("collection", q.value(10).toString());
QJsonArray excerpts;
if (resultObject.contains("excerpts"))
excerpts = resultObject["excerpts"].toArray();
QJsonObject excerptObject;
excerptObject.insert("text", q.value(3).toString());
excerptObject.insert("page", q.value(7).toInt());
excerptObject.insert("from", q.value(8).toInt());
excerptObject.insert("to", q.value(9).toInt());
excerpts.append(excerptObject);
resultObject.insert("excerpts", excerpts);
results.insert(file, resultObject);
#if defined(DEBUG)
qDebug() << "retrieve rowid:" << rowid
<< "chunk_text:" << chunk_text;
#endif
}
QJsonArray resultsArray;
QList<QJsonObject> resultsList = results.values();
for (const QJsonObject &result : resultsList)
resultsArray.append(QJsonValue(result));
QJsonObject response;
response.insert("results", resultsArray);
QJsonDocument document(response);
// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
jsonResult = document.toJson(QJsonDocument::Compact);
}
// FIXME This is very slow and non-interruptible and when we close the application and we're

View File

@ -101,7 +101,7 @@ public Q_SLOTS:
void forceRebuildFolder(const QString &path);
bool addFolder(const QString &collection, const QString &path, const QString &embedding_model);
void removeFolder(const QString &collection, const QString &path);
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<SourceExcerpt> *results);
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QString &jsonResult);
void changeChunkSize(int chunkSize);
void changeFileExtensions(const QStringList &extensions);
@ -168,7 +168,6 @@ private:
QStringList m_scannedFileExtensions;
QTimer *m_scanTimer;
QMap<int, QQueue<DocumentInfo>> m_docsToScan;
QList<SourceExcerpt> m_retrieve;
QThread m_dbThread;
QFileSystemWatcher *m_watcher;
QSet<QString> m_watchedPaths;

View File

@ -0,0 +1,50 @@
#include "localdocssearch.h"
#include "database.h"
#include "localdocs.h"
#include <QCoreApplication>
#include <QDebug>
#include <QGuiApplication>
#include <QJsonArray>
#include <QJsonObject>
#include <QThread>
using namespace Qt::Literals::StringLiterals;
QString LocalDocsSearch::run(const QJsonObject &parameters, qint64 timeout)
{
QList<QString> collections;
QJsonArray collectionsArray = parameters["collections"].toArray();
for (int i = 0; i < collectionsArray.size(); ++i)
collections.append(collectionsArray[i].toString());
const QString text = parameters["text"].toString();
const int count = parameters["count"].toInt();
QThread workerThread;
LocalDocsWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &LocalDocsWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(&workerThread, &QThread::started, [&worker, collections, text, count]() {
worker.request(collections, text, count);
});
workerThread.start();
workerThread.wait(timeout);
workerThread.quit();
workerThread.wait();
return worker.response();
}
LocalDocsWorker::LocalDocsWorker()
: QObject(nullptr)
{
// The following are blocking operations and will block the calling thread
connect(this, &LocalDocsWorker::requestRetrieveFromDB, LocalDocs::globalInstance()->database(),
&Database::retrieveFromDB, Qt::BlockingQueuedConnection);
}
void LocalDocsWorker::request(const QList<QString> &collections, const QString &text, int count)
{
QString jsonResult;
emit requestRetrieveFromDB(collections, text, count, jsonResult); // blocks
m_response = jsonResult;
emit finished();
}

View File

@ -0,0 +1,36 @@
#ifndef LOCALDOCSSEARCH_H
#define LOCALDOCSSEARCH_H
#include "tool.h"
#include <QObject>
#include <QString>
class LocalDocsWorker : public QObject {
Q_OBJECT
public:
LocalDocsWorker();
virtual ~LocalDocsWorker() {}
QString response() const { return m_response; }
void request(const QList<QString> &collections, const QString &text, int count);
Q_SIGNALS:
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int count, QString &jsonResponse);
void finished();
private:
QString m_response;
};
class LocalDocsSearch : public Tool {
Q_OBJECT
public:
LocalDocsSearch() : Tool() {}
virtual ~LocalDocsSearch() {}
QString run(const QJsonObject &parameters, qint64 timeout = 2000) override;
};
#endif // LOCALDOCSSEARCH_H

View File

@ -77,19 +77,7 @@ public:
static QList<SourceExcerpt> fromJson(const QString &json, QString &errorString);
bool operator==(const SourceExcerpt &other) const {
return date == other.date &&
text == other.text &&
collection == other.collection &&
path == other.path &&
file == other.file &&
url == other.url &&
favicon == other.favicon &&
title == other.title &&
author == other.author &&
description == other.description &&
page == other.page &&
from == other.from &&
to == other.to;
return file == other.file || url == other.url;
}
bool operator!=(const SourceExcerpt &other) const {
return !(*this == other);

View File

@ -1,8 +1,6 @@
#ifndef TOOL_H
#define TOOL_H
#include "sourceexcerpt.h"
#include <QObject>
#include <QJsonObject>
@ -70,18 +68,4 @@ public:
virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000) = 0;
};
//class BuiltinTool : public Tool {
// Q_OBJECT
//public:
// BuiltinTool() : Tool() {}
// virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000);
//};
//class LocalTool : public Tool {
// Q_OBJECT
//public:
// LocalTool() : Tool() {}
// virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000);
//};
#endif // TOOL_H