Adds the collections to serialize and implement references for localdocs.

This commit is contained in:
Adam Treat
2023-05-24 14:49:43 -04:00
committed by AT
parent d81302950e
commit b5380c9b7f
7 changed files with 171 additions and 77 deletions

View File

@@ -15,7 +15,6 @@ Chat::Chat(QObject *parent)
, m_llmodel(new ChatLLM(this))
, m_isServer(false)
, m_shouldDeleteLater(false)
, m_contextContainsLocalDocs(false)
{
connectLLM();
}
@@ -31,7 +30,6 @@ Chat::Chat(bool isServer, QObject *parent)
, m_llmodel(new Server(this))
, m_isServer(true)
, m_shouldDeleteLater(false)
, m_contextContainsLocalDocs(false)
{
connectLLM();
}
@@ -103,7 +101,8 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
m_contextContainsLocalDocs = false;
Q_ASSERT(m_results.isEmpty());
m_results.clear(); // just in case, but the assert above is important
m_responseInProgress = true;
m_responseState = Chat::LocalDocsRetrieval;
emit responseInProgressChanged();
@@ -116,18 +115,25 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
m_queuedPrompt.n_batch = n_batch;
m_queuedPrompt.repeat_penalty = repeat_penalty;
m_queuedPrompt.repeat_penalty_tokens = repeat_penalty_tokens;
LocalDocs::globalInstance()->requestRetrieve(m_collections, prompt);
LocalDocs::globalInstance()->requestRetrieve(m_id, m_collections, prompt);
}
void Chat::handleLocalDocsRetrieved()
void Chat::handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results)
{
// If the uid doesn't match, then these are not our results
if (uid != m_id)
return;
// Store our results locally
m_results = results;
// Augment the prompt template with the results if any
QList<QString> augmentedTemplate;
QList<QString> results = LocalDocs::globalInstance()->result();
if (!results.isEmpty()) {
if (!m_results.isEmpty())
augmentedTemplate.append("### Context:");
augmentedTemplate.append(results.join("\n\n"));
}
m_contextContainsLocalDocs = !results.isEmpty();
for (const ResultInfo &info : m_results)
augmentedTemplate.append(info.text);
augmentedTemplate.append(m_queuedPrompt.prompt_template);
emit promptRequested(
m_queuedPrompt.prompt,
@@ -191,13 +197,48 @@ void Chat::handleModelLoadedChanged()
void Chat::promptProcessing()
{
m_responseState = m_contextContainsLocalDocs ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
m_responseState = !m_results.isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
emit responseStateChanged();
}
void Chat::responseStopped()
{
m_contextContainsLocalDocs = false;
const QString chatResponse = response();
QList<QString> finalResponse { chatResponse };
int validReferenceNumber = 1;
for (const ResultInfo &info : m_results) {
if (info.file.isEmpty())
continue;
if (validReferenceNumber == 1)
finalResponse.append(QStringLiteral("---"));
QString reference;
{
QTextStream stream(&reference);
stream << (validReferenceNumber++) << ". ";
if (!info.title.isEmpty())
stream << "\"" << info.title << "\". ";
if (!info.author.isEmpty())
stream << "By " << info.author << ". ";
if (!info.date.isEmpty())
stream << "Date: " << info.date << ". ";
stream << "In " << info.file << ". ";
if (info.page != -1)
stream << "Page " << info.page << ". ";
if (info.from != -1) {
stream << "Lines " << info.from;
if (info.to != -1)
stream << "-" << info.to;
stream << ". ";
}
}
finalResponse.append(reference);
}
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, finalResponse.join("\n"));
emit responseChanged();
m_results.clear();
m_responseInProgress = false;
m_responseState = Chat::ResponseStopped;
emit responseInProgressChanged();
@@ -301,6 +342,8 @@ bool Chat::serialize(QDataStream &stream, int version) const
stream << m_name;
stream << m_userName;
stream << m_savedModelName;
if (version > 2)
stream << m_collections;
if (!m_llmodel->serialize(stream, version))
return false;
if (!m_chatModel->serialize(stream, version))
@@ -321,6 +364,10 @@ bool Chat::deserialize(QDataStream &stream, int version)
// unfortunately, we cannot deserialize these
if (version < 2 && m_savedModelName.contains("gpt4all-j"))
return false;
if (version > 2) {
stream >> m_collections;
emit collectionListChanged();
}
m_llmodel->setModelName(m_savedModelName);
if (!m_llmodel->deserialize(stream, version))
return false;