mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-30 14:53:41 +00:00
Make localdocs work with server mode.
This commit is contained in:
parent
8e89ceb54b
commit
f62e439a2d
@ -45,7 +45,6 @@ void Chat::connectLLM()
|
||||
// Should be in same thread
|
||||
connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
|
||||
connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
|
||||
connect(LocalDocs::globalInstance(), &LocalDocs::receivedResult, this, &Chat::handleLocalDocsRetrieved, Qt::DirectConnection);
|
||||
|
||||
// Should be in different threads
|
||||
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
|
||||
@ -101,52 +100,17 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
|
||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens)
|
||||
{
|
||||
Q_ASSERT(m_results.isEmpty());
|
||||
m_results.clear(); // just in case, but the assert above is important
|
||||
m_responseInProgress = true;
|
||||
m_responseState = Chat::LocalDocsRetrieval;
|
||||
emit responseInProgressChanged();
|
||||
emit responseStateChanged();
|
||||
m_queuedPrompt.prompt = prompt;
|
||||
m_queuedPrompt.prompt_template = prompt_template;
|
||||
m_queuedPrompt.n_predict = n_predict;
|
||||
m_queuedPrompt.top_k = top_k;
|
||||
m_queuedPrompt.temp = temp;
|
||||
m_queuedPrompt.n_batch = n_batch;
|
||||
m_queuedPrompt.repeat_penalty = repeat_penalty;
|
||||
m_queuedPrompt.repeat_penalty_tokens = repeat_penalty_tokens;
|
||||
LocalDocs::globalInstance()->requestRetrieve(m_id, m_collections, prompt);
|
||||
}
|
||||
|
||||
void Chat::handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results)
|
||||
{
|
||||
// If the uid doesn't match, then these are not our results
|
||||
if (uid != m_id)
|
||||
return;
|
||||
|
||||
// Store our results locally
|
||||
m_results = results;
|
||||
|
||||
// Augment the prompt template with the results if any
|
||||
QList<QString> augmentedTemplate;
|
||||
if (!m_results.isEmpty())
|
||||
augmentedTemplate.append("### Context:");
|
||||
for (const ResultInfo &info : m_results)
|
||||
augmentedTemplate.append(info.text);
|
||||
|
||||
augmentedTemplate.append(m_queuedPrompt.prompt_template);
|
||||
emit promptRequested(
|
||||
m_queuedPrompt.prompt,
|
||||
augmentedTemplate.join("\n"),
|
||||
m_queuedPrompt.n_predict,
|
||||
m_queuedPrompt.top_k,
|
||||
m_queuedPrompt.top_p,
|
||||
m_queuedPrompt.temp,
|
||||
m_queuedPrompt.n_batch,
|
||||
m_queuedPrompt.repeat_penalty,
|
||||
m_queuedPrompt.repeat_penalty_tokens,
|
||||
prompt,
|
||||
prompt_template,
|
||||
n_predict,
|
||||
top_k,
|
||||
top_p,
|
||||
temp,
|
||||
n_batch,
|
||||
repeat_penalty,
|
||||
repeat_penalty_tokens,
|
||||
LLM::globalInstance()->threadCount());
|
||||
m_queuedPrompt = Prompt();
|
||||
}
|
||||
|
||||
void Chat::regenerateResponse()
|
||||
@ -195,9 +159,14 @@ void Chat::handleModelLoadedChanged()
|
||||
deleteLater();
|
||||
}
|
||||
|
||||
QList<ResultInfo> Chat::results() const
|
||||
{
|
||||
return m_llmodel->results();
|
||||
}
|
||||
|
||||
void Chat::promptProcessing()
|
||||
{
|
||||
m_responseState = !m_results.isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
|
||||
m_responseState = !results().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
|
||||
emit responseStateChanged();
|
||||
}
|
||||
|
||||
@ -207,7 +176,7 @@ void Chat::responseStopped()
|
||||
QList<QString> references;
|
||||
QList<QString> referencesContext;
|
||||
int validReferenceNumber = 1;
|
||||
for (const ResultInfo &info : m_results) {
|
||||
for (const ResultInfo &info : results()) {
|
||||
if (info.file.isEmpty())
|
||||
continue;
|
||||
if (validReferenceNumber == 1)
|
||||
@ -241,7 +210,6 @@ void Chat::responseStopped()
|
||||
m_chatModel->updateReferences(index, references.join("\n"), referencesContext);
|
||||
emit responseChanged();
|
||||
|
||||
m_results.clear();
|
||||
m_responseInProgress = false;
|
||||
m_responseState = Chat::ResponseStopped;
|
||||
emit responseInProgressChanged();
|
||||
@ -266,6 +234,10 @@ void Chat::setModelName(const QString &modelName)
|
||||
|
||||
void Chat::newPromptResponsePair(const QString &prompt)
|
||||
{
|
||||
m_responseInProgress = true;
|
||||
m_responseState = Chat::LocalDocsRetrieval;
|
||||
emit responseInProgressChanged();
|
||||
emit responseStateChanged();
|
||||
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
|
||||
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
|
||||
m_chatModel->appendResponse(tr("Response: "), prompt);
|
||||
@ -274,7 +246,11 @@ void Chat::newPromptResponsePair(const QString &prompt)
|
||||
|
||||
void Chat::serverNewPromptResponsePair(const QString &prompt)
|
||||
{
|
||||
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
|
||||
m_responseInProgress = true;
|
||||
m_responseState = Chat::LocalDocsRetrieval;
|
||||
emit responseInProgressChanged();
|
||||
emit responseStateChanged();
|
||||
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
|
||||
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
|
||||
m_chatModel->appendResponse(tr("Response: "), prompt);
|
||||
}
|
||||
|
@ -60,6 +60,8 @@ public:
|
||||
Q_INVOKABLE void stopGenerating();
|
||||
Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
|
||||
|
||||
QList<ResultInfo> results() const;
|
||||
|
||||
QString response() const;
|
||||
bool responseInProgress() const { return m_responseInProgress; }
|
||||
QString responseState() const;
|
||||
@ -115,7 +117,6 @@ Q_SIGNALS:
|
||||
void collectionListChanged();
|
||||
|
||||
private Q_SLOTS:
|
||||
void handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results);
|
||||
void handleResponseChanged();
|
||||
void handleModelLoadedChanged();
|
||||
void promptProcessing();
|
||||
@ -125,24 +126,11 @@ private Q_SLOTS:
|
||||
void handleModelNameChanged();
|
||||
|
||||
private:
|
||||
struct Prompt {
|
||||
QString prompt;
|
||||
QString prompt_template;
|
||||
int32_t n_predict;
|
||||
int32_t top_k;
|
||||
float top_p;
|
||||
float temp;
|
||||
int32_t n_batch;
|
||||
float repeat_penalty;
|
||||
int32_t repeat_penalty_tokens;
|
||||
};
|
||||
|
||||
QString m_id;
|
||||
QString m_name;
|
||||
QString m_userName;
|
||||
QString m_savedModelName;
|
||||
QList<QString> m_collections;
|
||||
QList<ResultInfo> m_results;
|
||||
ChatModel *m_chatModel;
|
||||
bool m_responseInProgress;
|
||||
ResponseState m_responseState;
|
||||
@ -150,7 +138,6 @@ private:
|
||||
ChatLLM *m_llmodel;
|
||||
bool m_isServer;
|
||||
bool m_shouldDeleteLater;
|
||||
Prompt m_queuedPrompt;
|
||||
};
|
||||
|
||||
#endif // CHAT_H
|
||||
|
@ -91,9 +91,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
||||
moveToThread(&m_llmThread);
|
||||
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
|
||||
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
|
||||
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection);
|
||||
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
|
||||
Qt::QueuedConnection); // explicitly queued
|
||||
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
||||
connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
|
||||
|
||||
// The following are blocking operations and will block the llm thread
|
||||
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
|
||||
Qt::BlockingQueuedConnection);
|
||||
|
||||
m_llmThread.setObjectName(m_chat->id());
|
||||
m_llmThread.start();
|
||||
}
|
||||
@ -386,7 +392,19 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
||||
if (!isModelLoaded())
|
||||
return false;
|
||||
|
||||
QString instructPrompt = prompt_template.arg(prompt);
|
||||
m_results.clear();
|
||||
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
|
||||
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &m_results); // blocks
|
||||
|
||||
// Augment the prompt template with the results if any
|
||||
QList<QString> augmentedTemplate;
|
||||
if (!m_results.isEmpty())
|
||||
augmentedTemplate.append("### Context:");
|
||||
for (const ResultInfo &info : m_results)
|
||||
augmentedTemplate.append(info.text);
|
||||
augmentedTemplate.append(prompt_template);
|
||||
|
||||
QString instructPrompt = augmentedTemplate.join("\n").arg(prompt);
|
||||
|
||||
m_stopGenerating = false;
|
||||
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <QThread>
|
||||
#include <QFileInfo>
|
||||
|
||||
#include "localdocs.h"
|
||||
#include "../gpt4all-backend/llmodel.h"
|
||||
|
||||
enum LLModelType {
|
||||
@ -39,6 +40,7 @@ public:
|
||||
void regenerateResponse();
|
||||
void resetResponse();
|
||||
void resetContext();
|
||||
QList<ResultInfo> results() const { return m_results; }
|
||||
|
||||
void stopGenerating() { m_stopGenerating = true; }
|
||||
|
||||
@ -85,6 +87,8 @@ Q_SIGNALS:
|
||||
void stateChanged();
|
||||
void threadStarted();
|
||||
void shouldBeLoadedChanged();
|
||||
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
|
||||
|
||||
|
||||
protected:
|
||||
bool handlePrompt(int32_t token);
|
||||
@ -111,6 +115,7 @@ protected:
|
||||
QThread m_llmThread;
|
||||
std::atomic<bool> m_stopGenerating;
|
||||
std::atomic<bool> m_shouldBeLoaded;
|
||||
QList<ResultInfo> m_results;
|
||||
bool m_isRecalc;
|
||||
bool m_isServer;
|
||||
bool m_isChatGPT;
|
||||
|
@ -892,7 +892,8 @@ bool Database::removeFolderFromWatch(const QString &path)
|
||||
return m_watcher->removePath(path);
|
||||
}
|
||||
|
||||
void Database::retrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize)
|
||||
void Database::retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize,
|
||||
QList<ResultInfo> *results)
|
||||
{
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "retrieveFromDB" << collections << text << retrievalSize;
|
||||
@ -904,7 +905,6 @@ void Database::retrieveFromDB(const QString &uid, const QList<QString> &collecti
|
||||
return;
|
||||
}
|
||||
|
||||
QList<ResultInfo> results;
|
||||
while (q.next()) {
|
||||
const int rowid = q.value(0).toInt();
|
||||
const QString date = QDateTime::fromMSecsSinceEpoch(q.value(1).toLongLong()).toString("yyyy, MMMM dd");
|
||||
@ -924,13 +924,12 @@ void Database::retrieveFromDB(const QString &uid, const QList<QString> &collecti
|
||||
info.page = page;
|
||||
info.from = from;
|
||||
info.to = to;
|
||||
results.append(info);
|
||||
results->append(info);
|
||||
#if defined(DEBUG)
|
||||
qDebug() << "retrieve rowid:" << rowid
|
||||
<< "chunk_text:" << chunk_text;
|
||||
#endif
|
||||
}
|
||||
emit retrieveResult(uid, results);
|
||||
}
|
||||
|
||||
void Database::cleanDB()
|
||||
|
@ -43,13 +43,12 @@ public Q_SLOTS:
|
||||
void scanDocuments(int folder_id, const QString &folder_path);
|
||||
void addFolder(const QString &collection, const QString &path);
|
||||
void removeFolder(const QString &collection, const QString &path);
|
||||
void retrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize);
|
||||
void retrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
|
||||
void cleanDB();
|
||||
void changeChunkSize(int chunkSize);
|
||||
|
||||
Q_SIGNALS:
|
||||
void docsToScanChanged();
|
||||
void retrieveResult(const QString &uid, const QList<ResultInfo> &result);
|
||||
void collectionListUpdated(const QList<CollectionItem> &collectionList);
|
||||
|
||||
private Q_SLOTS:
|
||||
|
@ -24,12 +24,8 @@ LocalDocs::LocalDocs()
|
||||
&Database::addFolder, Qt::QueuedConnection);
|
||||
connect(this, &LocalDocs::requestRemoveFolder, m_database,
|
||||
&Database::removeFolder, Qt::QueuedConnection);
|
||||
connect(this, &LocalDocs::requestRetrieveFromDB, m_database,
|
||||
&Database::retrieveFromDB, Qt::QueuedConnection);
|
||||
connect(this, &LocalDocs::requestChunkSizeChange, m_database,
|
||||
&Database::changeChunkSize, Qt::QueuedConnection);
|
||||
connect(m_database, &Database::retrieveResult, this,
|
||||
&LocalDocs::receivedResult, Qt::QueuedConnection);
|
||||
connect(m_database, &Database::collectionListUpdated,
|
||||
m_localDocsModel, &LocalDocsModel::handleCollectionListUpdated, Qt::QueuedConnection);
|
||||
}
|
||||
@ -49,11 +45,6 @@ void LocalDocs::removeFolder(const QString &collection, const QString &path)
|
||||
emit requestRemoveFolder(collection, path);
|
||||
}
|
||||
|
||||
void LocalDocs::requestRetrieve(const QString &uid, const QList<QString> &collections, const QString &text)
|
||||
{
|
||||
emit requestRetrieveFromDB(uid, collections, text, m_retrievalSize);
|
||||
}
|
||||
|
||||
int LocalDocs::chunkSize() const
|
||||
{
|
||||
return m_chunkSize;
|
||||
|
@ -20,7 +20,8 @@ public:
|
||||
|
||||
Q_INVOKABLE void addFolder(const QString &collection, const QString &path);
|
||||
Q_INVOKABLE void removeFolder(const QString &collection, const QString &path);
|
||||
void requestRetrieve(const QString &uid, const QList<QString> &collections, const QString &text);
|
||||
|
||||
Database *database() const { return m_database; }
|
||||
|
||||
int chunkSize() const;
|
||||
void setChunkSize(int chunkSize);
|
||||
@ -31,9 +32,7 @@ public:
|
||||
Q_SIGNALS:
|
||||
void requestAddFolder(const QString &collection, const QString &path);
|
||||
void requestRemoveFolder(const QString &collection, const QString &path);
|
||||
void requestRetrieveFromDB(const QString &uid, const QList<QString> &collections, const QString &text, int retrievalSize);
|
||||
void requestChunkSizeChange(int chunkSize);
|
||||
void receivedResult(const QString &uid, const QList<ResultInfo> &result);
|
||||
void localDocsModelChanged();
|
||||
void chunkSizeChanged();
|
||||
void retrievalSizeChanged();
|
||||
|
@ -51,6 +51,20 @@ static inline QJsonObject modelToJson(const ModelInfo &info)
|
||||
return model;
|
||||
}
|
||||
|
||||
static inline QJsonObject resultToJson(const ResultInfo &info)
|
||||
{
|
||||
QJsonObject result;
|
||||
result.insert("file", info.file);
|
||||
result.insert("title", info.title);
|
||||
result.insert("author", info.author);
|
||||
result.insert("date", info.date);
|
||||
result.insert("text", info.text);
|
||||
result.insert("page", info.page);
|
||||
result.insert("from", info.from);
|
||||
result.insert("to", info.to);
|
||||
return result;
|
||||
}
|
||||
|
||||
Server::Server(Chat *chat)
|
||||
: ChatLLM(chat, true /*isServer*/)
|
||||
, m_chat(chat)
|
||||
@ -298,7 +312,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
|
||||
|
||||
int promptTokens = 0;
|
||||
int responseTokens = 0;
|
||||
QList<QString> responses;
|
||||
QList<QPair<QString, QList<ResultInfo>>> responses;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (!prompt(actualPrompt,
|
||||
promptTemplate,
|
||||
@ -317,7 +331,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
|
||||
QString echoedPrompt = actualPrompt;
|
||||
if (!echoedPrompt.endsWith("\n"))
|
||||
echoedPrompt += "\n";
|
||||
responses.append((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response());
|
||||
responses.append(qMakePair((echo ? QString("%1\n").arg(actualPrompt) : QString()) + response(), m_results));
|
||||
if (!promptTokens)
|
||||
promptTokens += m_promptTokens;
|
||||
responseTokens += m_promptResponseTokens - m_promptTokens;
|
||||
@ -335,24 +349,36 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
|
||||
|
||||
if (isChat) {
|
||||
int index = 0;
|
||||
for (QString r : responses) {
|
||||
for (const auto &r : responses) {
|
||||
QString result = r.first;
|
||||
QList<ResultInfo> infos = r.second;
|
||||
QJsonObject choice;
|
||||
choice.insert("index", index++);
|
||||
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
|
||||
QJsonObject message;
|
||||
message.insert("role", "assistant");
|
||||
message.insert("content", r);
|
||||
message.insert("content", result);
|
||||
choice.insert("message", message);
|
||||
QJsonArray references;
|
||||
for (const auto &ref : infos)
|
||||
references.append(resultToJson(ref));
|
||||
choice.insert("references", references);
|
||||
choices.append(choice);
|
||||
}
|
||||
} else {
|
||||
int index = 0;
|
||||
for (QString r : responses) {
|
||||
for (const auto &r : responses) {
|
||||
QString result = r.first;
|
||||
QList<ResultInfo> infos = r.second;
|
||||
QJsonObject choice;
|
||||
choice.insert("text", r);
|
||||
choice.insert("text", result);
|
||||
choice.insert("index", index++);
|
||||
choice.insert("logprobs", QJsonValue::Null); // We don't support
|
||||
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
|
||||
QJsonArray references;
|
||||
for (const auto &ref : infos)
|
||||
references.append(resultToJson(ref));
|
||||
choice.insert("references", references);
|
||||
choices.append(choice);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user