Make localdocs work with server mode.

commit f62e439a2d
parent 8e89ceb54b
Author:       Adam Treat
Date:         2023-06-01 14:13:12 -04:00
Committed by: AT

9 changed files with 90 additions and 90 deletions


@@ -45,7 +45,6 @@ void Chat::connectLLM()
     // Should be in same thread
     connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
     connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
-    connect(LocalDocs::globalInstance(), &LocalDocs::receivedResult, this, &Chat::handleLocalDocsRetrieved, Qt::DirectConnection);
 
     // Should be in different threads
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
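The connection dropped here pairs with the removal of Chat::handleLocalDocsRetrieved below: Chat no longer receives retrieval results directly from LocalDocs. The two comments state the rule the remaining connections follow: Qt::DirectConnection calls the slot synchronously on the emitter's thread, while Qt::QueuedConnection posts it through the receiver's event loop (and, across threads, runs it on the receiver's thread). A standalone sketch of the difference, not taken from this repository:

```cpp
#include <QCoreApplication>
#include <QDebug>
#include <QTimer>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QTimer timer;
    timer.setSingleShot(true);

    // Direct: runs synchronously, inside the emit of timeout().
    QObject::connect(&timer, &QTimer::timeout, &app,
                     [] { qDebug() << "direct"; }, Qt::DirectConnection);

    // Queued: posted as an event; runs once control returns to the loop,
    // so this prints second even though nothing else intervenes.
    QObject::connect(&timer, &QTimer::timeout, &app,
                     [&app] { qDebug() << "queued"; app.quit(); },
                     Qt::QueuedConnection);

    timer.start(0);
    return app.exec(); // prints "direct" then "queued", then exits
}
```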
@@ -101,52 +100,17 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
                   int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
                   int32_t repeat_penalty_tokens)
 {
-    Q_ASSERT(m_results.isEmpty());
-    m_results.clear(); // just in case, but the assert above is important
-    m_responseInProgress = true;
-    m_responseState = Chat::LocalDocsRetrieval;
-    emit responseInProgressChanged();
-    emit responseStateChanged();
-    m_queuedPrompt.prompt = prompt;
-    m_queuedPrompt.prompt_template = prompt_template;
-    m_queuedPrompt.n_predict = n_predict;
-    m_queuedPrompt.top_k = top_k;
-    m_queuedPrompt.top_p = top_p;
-    m_queuedPrompt.temp = temp;
-    m_queuedPrompt.n_batch = n_batch;
-    m_queuedPrompt.repeat_penalty = repeat_penalty;
-    m_queuedPrompt.repeat_penalty_tokens = repeat_penalty_tokens;
-    LocalDocs::globalInstance()->requestRetrieve(m_id, m_collections, prompt);
-}
-
-void Chat::handleLocalDocsRetrieved(const QString &uid, const QList<ResultInfo> &results)
-{
-    // If the uid doesn't match, then these are not our results
-    if (uid != m_id)
-        return;
-
-    // Store our results locally
-    m_results = results;
-
-    // Augment the prompt template with the results if any
-    QList<QString> augmentedTemplate;
-    if (!m_results.isEmpty())
-        augmentedTemplate.append("### Context:");
-    for (const ResultInfo &info : m_results)
-        augmentedTemplate.append(info.text);
-    augmentedTemplate.append(m_queuedPrompt.prompt_template);
     emit promptRequested(
-        m_queuedPrompt.prompt,
-        augmentedTemplate.join("\n"),
-        m_queuedPrompt.n_predict,
-        m_queuedPrompt.top_k,
-        m_queuedPrompt.top_p,
-        m_queuedPrompt.temp,
-        m_queuedPrompt.n_batch,
-        m_queuedPrompt.repeat_penalty,
-        m_queuedPrompt.repeat_penalty_tokens,
+        prompt,
+        prompt_template,
+        n_predict,
+        top_k,
+        top_p,
+        temp,
+        n_batch,
+        repeat_penalty,
+        repeat_penalty_tokens,
         LLM::globalInstance()->threadCount());
-    m_queuedPrompt = Prompt();
 }
 
 void Chat::regenerateResponse()
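prompt() now forwards the sampling parameters verbatim and no longer queues them while waiting for LocalDocs, so the retrieval and context injection presumably move behind the promptRequested signal into ChatLLM, where the server path can trigger them too. That side is not part of this excerpt; as a hedged sketch, the augmentation the removed handleLocalDocsRetrieved performed, reduced to a free function:

```cpp
#include <QList>
#include <QString>
#include <QStringList>

// ResultInfo reduced to the one field this step reads; the real struct
// has more members (e.g. file, used for references below).
struct ResultInfo {
    QString text;   // excerpt retrieved from a local document
};

// Mirrors the removed code: prepend a "### Context:" section holding the
// retrieved excerpts, then append the caller's own prompt template.
QString augmentTemplate(const QList<ResultInfo> &results, const QString &promptTemplate)
{
    QStringList augmentedTemplate;
    if (!results.isEmpty())
        augmentedTemplate.append("### Context:");
    for (const ResultInfo &info : results)
        augmentedTemplate.append(info.text);
    augmentedTemplate.append(promptTemplate);
    return augmentedTemplate.join("\n");
}
```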
@@ -195,9 +159,14 @@ void Chat::handleModelLoadedChanged()
         deleteLater();
 }
 
+QList<ResultInfo> Chat::results() const
+{
+    return m_llmodel->results();
+}
+
 void Chat::promptProcessing()
 {
-    m_responseState = !m_results.isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
+    m_responseState = !results().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
     emit responseStateChanged();
 }
 
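The new results() accessor is the pivot of the commit: Chat stops keeping its own m_results copy and reads whatever ChatLLM collected, so the GUI chat and the built-in server share one source of truth. The ChatLLM side is not shown in this diff; a minimal assumed shape:

```cpp
#include <QList>
#include <QString>

struct ResultInfo {
    QString file;   // source document; empty for results without one
    QString text;   // retrieved excerpt
};

// Assumption: ChatLLM now owns the retrieval results and exposes them with
// the same accessor name Chat delegates to above.
class ChatLLM {
public:
    QList<ResultInfo> results() const { return m_results; }

private:
    QList<ResultInfo> m_results;   // filled during localdocs retrieval
};
```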
@@ -207,7 +176,7 @@ void Chat::responseStopped()
     QList<QString> references;
     QList<QString> referencesContext;
     int validReferenceNumber = 1;
-    for (const ResultInfo &info : m_results) {
+    for (const ResultInfo &info : results()) {
         if (info.file.isEmpty())
             continue;
         if (validReferenceNumber == 1)
@@ -241,7 +210,6 @@ void Chat::responseStopped()
     m_chatModel->updateReferences(index, references.join("\n"), referencesContext);
     emit responseChanged();
 
-    m_results.clear();
     m_responseInProgress = false;
     m_responseState = Chat::ResponseStopped;
     emit responseInProgressChanged();
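These two hunks swap m_results for the accessor inside responseStopped(), which turns the retrieved excerpts into numbered references. The loop skeleton is visible above (skip results without a file, keep the numbering dense); the exact reference format is not, so the formatting in this sketch is an assumption:

```cpp
#include <QList>
#include <QString>
#include <QStringList>

struct ResultInfo {
    QString file;   // empty when a result has no source document
    QString text;
};

QString buildReferences(const QList<ResultInfo> &results)
{
    QStringList references;
    int validReferenceNumber = 1;
    for (const ResultInfo &info : results) {
        if (info.file.isEmpty())
            continue;   // as in the diff: sourceless results are skipped
        references.append(QString("[%1] %2")
                              .arg(validReferenceNumber++)
                              .arg(info.file));
    }
    return references.join("\n");   // joined with "\n", as updateReferences receives it
}
```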
@@ -266,6 +234,10 @@ void Chat::setModelName(const QString &modelName)
 
 void Chat::newPromptResponsePair(const QString &prompt)
 {
+    m_responseInProgress = true;
+    m_responseState = Chat::LocalDocsRetrieval;
+    emit responseInProgressChanged();
+    emit responseStateChanged();
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
@@ -274,7 +246,11 @@ void Chat::newPromptResponsePair(const QString &prompt)
 
 void Chat::serverNewPromptResponsePair(const QString &prompt)
 {
-    m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
+    m_responseInProgress = true;
+    m_responseState = Chat::LocalDocsRetrieval;
+    emit responseInProgressChanged();
+    emit responseStateChanged();
+    m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
 }
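With these last two hunks, both entry points, GUI (newPromptResponsePair) and server (serverNewPromptResponsePair), perform the same transition into LocalDocsRetrieval, which is what lets the server report retrieval progress the way the GUI does. The response-state cycle implied by the names in this diff (the real enum may declare more values):

```cpp
// Assumed ordering of the states seen in this commit:
// ResponseStopped -> LocalDocsRetrieval -> LocalDocsProcessing or
// PromptProcessing -> ... -> ResponseStopped.
enum ResponseState {
    ResponseStopped,      // idle; restored when a response finishes
    LocalDocsRetrieval,   // querying the enabled document collections
    LocalDocsProcessing,  // prompt evaluation with retrieved context
    PromptProcessing,     // prompt evaluation without localdocs results
};
```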