Make ChatModel threadsafe to support direct access by ChatLLM (#3018)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT
2024-10-01 18:15:02 -04:00
committed by GitHub
parent ee67cca885
commit c11b67dfcb
6 changed files with 140 additions and 73 deletions

View File

@@ -2,6 +2,7 @@
#include "chat.h"
#include "chatapi.h"
#include "chatmodel.h"
#include "localdocs.h"
#include "mysettings.h"
#include "network.h"
@@ -13,10 +14,14 @@
#include <QIODevice>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QMap>
#include <QMutex>
#include <QMutexLocker>
#include <QRegularExpression>
#include <QSet>
#include <QStringList>
#include <QUrl>
#include <QWaitCondition>
#include <Qt>
#include <QtLogging>
@@ -113,6 +118,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
, m_chatModel(parent->chatModel())
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
@@ -1313,31 +1319,32 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
auto it = m_stateFromText.begin();
while (it < m_stateFromText.end()) {
Q_ASSERT(m_chatModel);
m_chatModel->lock();
auto it = m_chatModel->begin();
while (it < m_chatModel->end()) {
auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end());
Q_ASSERT(prompt.name == "Prompt: ");
Q_ASSERT(it < m_chatModel->end());
auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
Q_ASSERT(response.name == "Response: ");
// FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing
// m_promptTokens or m_promptResponseTokens
m_llModelInfo.model->prompt(
prompt.second.toStdString(), promptTemplate.toStdString(),
prompt.value.toStdString(), promptTemplate.toStdString(),
promptFunc, /*responseFunc*/ [](auto &&...) { return true; },
/*allowContextShift*/ true,
m_ctx,
/*special*/ false,
response.second.toUtf8().constData()
response.value.toUtf8().constData()
);
}
m_chatModel->unlock();
if (!m_stopGenerating) {
if (!m_stopGenerating)
m_restoreStateFromText = false;
m_stateFromText.clear();
}
m_restoringFromText = false;
emit restoringFromTextChanged();