chat: fix issues with quickly switching between multiple chats (#2343)

* prevent load progress from getting out of sync with the current chat
* fix memory leak on exit if the LLModelStore contains a model
* do not report cancellation as a failure in console/Mixpanel
* show "waiting for model" separately from "switching context" in UI
* do not show lower "reload" button on error
* skip context switch if unload is pending
* skip unnecessary calls to LLModel::saveState

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-05-15 14:07:03 -04:00 committed by GitHub
parent 7f1c3d4275
commit 7e1e00f331
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 179 additions and 143 deletions

View File

@ -54,7 +54,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::trySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection);
@ -95,16 +95,6 @@ void Chat::processSystemPrompt()
emit processSystemPromptRequested(); emit processSystemPromptRequested();
} }
bool Chat::isModelLoaded() const
{
return m_modelLoadingPercentage == 1.0f;
}
float Chat::modelLoadingPercentage() const
{
return m_modelLoadingPercentage;
}
void Chat::resetResponseState() void Chat::resetResponseState()
{ {
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval) if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@ -167,9 +157,16 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
if (loadingPercentage == m_modelLoadingPercentage) if (loadingPercentage == m_modelLoadingPercentage)
return; return;
bool wasLoading = isCurrentlyLoading();
bool wasLoaded = isModelLoaded();
m_modelLoadingPercentage = loadingPercentage; m_modelLoadingPercentage = loadingPercentage;
emit modelLoadingPercentageChanged(); emit modelLoadingPercentageChanged();
if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)
if (isCurrentlyLoading() != wasLoading)
emit isCurrentlyLoadingChanged();
if (isModelLoaded() != wasLoaded)
emit isModelLoadedChanged(); emit isModelLoadedChanged();
} }
@ -247,10 +244,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
if (m_modelInfo == modelInfo && isModelLoaded()) if (m_modelInfo == modelInfo && isModelLoaded())
return; return;
m_modelLoadingPercentage = std::numeric_limits<float>::min(); // small non-zero positive value
emit isModelLoadedChanged();
m_modelLoadingError = QString();
emit modelLoadingErrorChanged();
m_modelInfo = modelInfo; m_modelInfo = modelInfo;
emit modelInfoChanged(); emit modelInfoChanged();
emit modelChangeRequested(modelInfo); emit modelChangeRequested(modelInfo);
@ -320,8 +313,9 @@ void Chat::forceReloadModel()
void Chat::trySwitchContextOfLoadedModel() void Chat::trySwitchContextOfLoadedModel()
{ {
emit trySwitchContextOfLoadedModelAttempted(); m_trySwitchContextInProgress = 1;
m_llmodel->setShouldTrySwitchContext(true); emit trySwitchContextInProgressChanged();
m_llmodel->requestTrySwitchContext();
} }
void Chat::generatedNameChanged(const QString &name) void Chat::generatedNameChanged(const QString &name)
@ -342,8 +336,10 @@ void Chat::handleRecalculating()
void Chat::handleModelLoadingError(const QString &error) void Chat::handleModelLoadingError(const QString &error)
{ {
auto stream = qWarning().noquote() << "ERROR:" << error << "id"; if (!error.isEmpty()) {
stream.quote() << id(); auto stream = qWarning().noquote() << "ERROR:" << error << "id";
stream.quote() << id();
}
m_modelLoadingError = error; m_modelLoadingError = error;
emit modelLoadingErrorChanged(); emit modelLoadingErrorChanged();
} }
@ -380,6 +376,11 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
emit modelInfoChanged(); emit modelInfoChanged();
} }
void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) {
m_trySwitchContextInProgress = value;
emit trySwitchContextInProgressChanged();
}
bool Chat::serialize(QDataStream &stream, int version) const bool Chat::serialize(QDataStream &stream, int version) const
{ {
stream << m_creationDate; stream << m_creationDate;

View File

@ -17,6 +17,7 @@ class Chat : public QObject
Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged) Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged) Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(bool isCurrentlyLoading READ isCurrentlyLoading NOTIFY isCurrentlyLoadingChanged)
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged) Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged) Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
@ -30,6 +31,8 @@ class Chat : public QObject
Q_PROPERTY(QString device READ device NOTIFY deviceChanged); Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged); Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged) Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
// 0=no, 1=waiting, 2=working
Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
QML_ELEMENT QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!") QML_UNCREATABLE("Only creatable from c++!")
@ -62,8 +65,9 @@ public:
Q_INVOKABLE void reset(); Q_INVOKABLE void reset();
Q_INVOKABLE void processSystemPrompt(); Q_INVOKABLE void processSystemPrompt();
Q_INVOKABLE bool isModelLoaded() const; bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
Q_INVOKABLE float modelLoadingPercentage() const; bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void prompt(const QString &prompt); Q_INVOKABLE void prompt(const QString &prompt);
Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void stopGenerating();
@ -105,6 +109,8 @@ public:
QString device() const { return m_device; } QString device() const { return m_device; }
QString fallbackReason() const { return m_fallbackReason; } QString fallbackReason() const { return m_fallbackReason; }
int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
public Q_SLOTS: public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt); void serverNewPromptResponsePair(const QString &prompt);
@ -113,6 +119,7 @@ Q_SIGNALS:
void nameChanged(); void nameChanged();
void chatModelChanged(); void chatModelChanged();
void isModelLoadedChanged(); void isModelLoadedChanged();
void isCurrentlyLoadingChanged();
void modelLoadingPercentageChanged(); void modelLoadingPercentageChanged();
void modelLoadingWarning(const QString &warning); void modelLoadingWarning(const QString &warning);
void responseChanged(); void responseChanged();
@ -136,8 +143,7 @@ Q_SIGNALS:
void deviceChanged(); void deviceChanged();
void fallbackReasonChanged(); void fallbackReasonChanged();
void collectionModelChanged(); void collectionModelChanged();
void trySwitchContextOfLoadedModelAttempted(); void trySwitchContextInProgressChanged();
void trySwitchContextOfLoadedModelCompleted(bool);
private Q_SLOTS: private Q_SLOTS:
void handleResponseChanged(const QString &response); void handleResponseChanged(const QString &response);
@ -152,6 +158,7 @@ private Q_SLOTS:
void handleFallbackReasonChanged(const QString &device); void handleFallbackReasonChanged(const QString &device);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results); void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo); void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted(int value);
private: private:
QString m_id; QString m_id;
@ -176,6 +183,8 @@ private:
float m_modelLoadingPercentage = 0.0f; float m_modelLoadingPercentage = 0.0f;
LocalDocsCollectionsModel *m_collectionModel; LocalDocsCollectionsModel *m_collectionModel;
bool m_firstResponse = true; bool m_firstResponse = true;
int m_trySwitchContextInProgress = 0;
bool m_isCurrentlyLoading = false;
}; };
#endif // CHAT_H #endif // CHAT_H

View File

@ -195,7 +195,11 @@ public:
int count() const { return m_chats.size(); } int count() const { return m_chats.size(); }
// stop ChatLLM threads for clean shutdown // stop ChatLLM threads for clean shutdown
void destroyChats() { for (auto *chat: m_chats) { chat->destroy(); } } void destroyChats()
{
for (auto *chat: m_chats) { chat->destroy(); }
ChatLLM::destroyStore();
}
void removeChatFile(Chat *chat) const; void removeChatFile(Chat *chat) const;
Q_INVOKABLE void saveChats(); Q_INVOKABLE void saveChats();

View File

@ -30,16 +30,17 @@ public:
static LLModelStore *globalInstance(); static LLModelStore *globalInstance();
LLModelInfo acquireModel(); // will block until llmodel is ready LLModelInfo acquireModel(); // will block until llmodel is ready
void releaseModel(const LLModelInfo &info); // must be called when you are done void releaseModel(LLModelInfo &&info); // must be called when you are done
void destroy();
private: private:
LLModelStore() LLModelStore()
{ {
// seed with empty model // seed with empty model
m_availableModels.append(LLModelInfo()); m_availableModel = LLModelInfo();
} }
~LLModelStore() {} ~LLModelStore() {}
QVector<LLModelInfo> m_availableModels; std::optional<LLModelInfo> m_availableModel;
QMutex m_mutex; QMutex m_mutex;
QWaitCondition m_condition; QWaitCondition m_condition;
friend class MyLLModelStore; friend class MyLLModelStore;
@ -55,19 +56,27 @@ LLModelStore *LLModelStore::globalInstance()
LLModelInfo LLModelStore::acquireModel() LLModelInfo LLModelStore::acquireModel()
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
while (m_availableModels.isEmpty()) while (!m_availableModel)
m_condition.wait(locker.mutex()); m_condition.wait(locker.mutex());
return m_availableModels.takeFirst(); auto first = std::move(*m_availableModel);
m_availableModel.reset();
return first;
} }
void LLModelStore::releaseModel(const LLModelInfo &info) void LLModelStore::releaseModel(LLModelInfo &&info)
{ {
QMutexLocker locker(&m_mutex); QMutexLocker locker(&m_mutex);
m_availableModels.append(info); Q_ASSERT(!m_availableModel);
Q_ASSERT(m_availableModels.count() < 2); m_availableModel = std::move(info);
m_condition.wakeAll(); m_condition.wakeAll();
} }
void LLModelStore::destroy()
{
QMutexLocker locker(&m_mutex);
m_availableModel.reset();
}
ChatLLM::ChatLLM(Chat *parent, bool isServer) ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr} : QObject{nullptr}
, m_promptResponseTokens(0) , m_promptResponseTokens(0)
@ -76,7 +85,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_shouldBeLoaded(false) , m_shouldBeLoaded(false)
, m_forceUnloadModel(false) , m_forceUnloadModel(false)
, m_markedForDeletion(false) , m_markedForDeletion(false)
, m_shouldTrySwitchContext(false)
, m_stopGenerating(false) , m_stopGenerating(false)
, m_timer(nullptr) , m_timer(nullptr)
, m_isServer(isServer) , m_isServer(isServer)
@ -88,7 +96,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
moveToThread(&m_llmThread); moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged, connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel,
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
@ -108,7 +116,8 @@ ChatLLM::~ChatLLM()
destroy(); destroy();
} }
void ChatLLM::destroy() { void ChatLLM::destroy()
{
m_stopGenerating = true; m_stopGenerating = true;
m_llmThread.quit(); m_llmThread.quit();
m_llmThread.wait(); m_llmThread.wait();
@ -116,11 +125,15 @@ void ChatLLM::destroy() {
// The only time we should have a model loaded here is on shutdown // The only time we should have a model loaded here is on shutdown
// as we explicitly unload the model in all other circumstances // as we explicitly unload the model in all other circumstances
if (isModelLoaded()) { if (isModelLoaded()) {
delete m_llModelInfo.model; m_llModelInfo.model.reset();
m_llModelInfo.model = nullptr;
} }
} }
void ChatLLM::destroyStore()
{
LLModelStore::globalInstance()->destroy();
}
void ChatLLM::handleThreadStarted() void ChatLLM::handleThreadStarted()
{ {
m_timer = new TokenTimer(this); m_timer = new TokenTimer(this);
@ -161,7 +174,7 @@ bool ChatLLM::loadDefaultModel()
return loadModel(defaultModel); return loadModel(defaultModel);
} }
bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{ {
// We're trying to see if the store already has the model fully loaded that we wish to use // We're trying to see if the store already has the model fully loaded that we wish to use
// and if so we just acquire it from the store and switch the context and return true. If the // and if so we just acquire it from the store and switch the context and return true. If the
@ -169,10 +182,11 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
// If we're already loaded or a server or we're reloading to change the variant/device or the // If we're already loaded or a server or we're reloading to change the variant/device or the
// modelInfo is empty, then this should fail // modelInfo is empty, then this should fail
if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) { if (
m_shouldTrySwitchContext = false; isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
emit trySwitchContextOfLoadedModelCompleted(false); ) {
return false; emit trySwitchContextOfLoadedModelCompleted(0);
return;
} }
QString filePath = modelInfo.dirpath + modelInfo.filename(); QString filePath = modelInfo.dirpath + modelInfo.filename();
@ -180,33 +194,28 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
// The store gave us no already loaded model, the wrong type of model, then give it back to the // The store gave us no already loaded model, the wrong type of model, then give it back to the
// store and fail // store and fail
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) { if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo || !m_shouldBeLoaded) {
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo(); emit trySwitchContextOfLoadedModelCompleted(0);
m_shouldTrySwitchContext = false; return;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
} }
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
// We should be loaded and now we are emit trySwitchContextOfLoadedModelCompleted(2);
m_shouldBeLoaded = true;
m_shouldTrySwitchContext = false;
// Restore, signal and process // Restore, signal and process
restoreState(); restoreState();
emit modelLoadingPercentageChanged(1.0f); emit modelLoadingPercentageChanged(1.0f);
emit trySwitchContextOfLoadedModelCompleted(true); emit trySwitchContextOfLoadedModelCompleted(0);
processSystemPrompt(); processSystemPrompt();
return true;
} }
bool ChatLLM::loadModel(const ModelInfo &modelInfo) bool ChatLLM::loadModel(const ModelInfo &modelInfo)
@ -223,6 +232,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (isModelLoaded() && this->modelInfo() == modelInfo) if (isModelLoaded() && this->modelInfo() == modelInfo)
return true; return true;
// reset status
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
emit modelLoadingError("");
emit reportFallbackReason("");
emit reportDevice("");
m_pristineLoadedState = false;
QString filePath = modelInfo.dirpath + modelInfo.filename(); QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath); QFileInfo fileInfo(filePath);
@ -231,28 +247,25 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (alreadyAcquired) { if (alreadyAcquired) {
resetContext(); resetContext();
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
delete m_llModelInfo.model; m_llModelInfo.model.reset();
m_llModelInfo.model = nullptr;
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
} else if (!m_isServer) { } else if (!m_isServer) {
// This is a blocking call that tries to retrieve the model we need from the model store. // This is a blocking call that tries to retrieve the model we need from the model store.
// If it succeeds, then we just have to restore state. If the store has never had a model // If it succeeds, then we just have to restore state. If the store has never had a model
// returned to it, then the modelInfo.model pointer should be null which will happen on startup // returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
// At this point it is possible that while we were blocked waiting to acquire the model from the // At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model // store, that our state was changed to not be loaded. If this is the case, release the model
// back into the store and quit loading // back into the store and quit loading
if (!m_shouldBeLoaded) { if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f); emit modelLoadingPercentageChanged(0.0f);
return false; return false;
} }
@ -260,7 +273,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Check if the store just gave us exactly the model we were looking for // Check if the store just gave us exactly the model we were looking for
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) { if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
restoreState(); restoreState();
emit modelLoadingPercentageChanged(1.0f); emit modelLoadingPercentageChanged(1.0f);
@ -274,10 +287,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
} else { } else {
// Release the memory since we have to switch to a different model. // Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
delete m_llModelInfo.model; m_llModelInfo.model.reset();
m_llModelInfo.model = nullptr;
} }
} }
@ -307,7 +319,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
model->setModelName(modelName); model->setModelName(modelName);
model->setRequestURL(modelInfo.url()); model->setRequestURL(modelInfo.url());
model->setAPIKey(apiKey); model->setAPIKey(apiKey);
m_llModelInfo.model = model; m_llModelInfo.model.reset(model);
} else { } else {
QElapsedTimer modelLoadTimer; QElapsedTimer modelLoadTimer;
modelLoadTimer.start(); modelLoadTimer.start();
@ -322,9 +334,10 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
buildVariant = "metal"; buildVariant = "metal";
#endif #endif
QString constructError; QString constructError;
m_llModelInfo.model = nullptr; m_llModelInfo.model.reset();
try { try {
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx); auto *model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
m_llModelInfo.model.reset(model);
} catch (const LLModel::MissingImplementationError &e) { } catch (const LLModel::MissingImplementationError &e) {
modelLoadProps.insert("error", "missing_model_impl"); modelLoadProps.insert("error", "missing_model_impl");
constructError = e.what(); constructError = e.what();
@ -355,8 +368,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
return m_shouldBeLoaded; return m_shouldBeLoaded;
}); });
emit reportFallbackReason(""); // no fallback yet
auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) { auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
float memGB = dev->heapSize / float(1024 * 1024 * 1024); float memGB = dev->heapSize / float(1024 * 1024 * 1024);
return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
@ -407,6 +418,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit reportDevice(actualDevice); emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl); bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
if (!m_shouldBeLoaded) {
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
if (actualDevice == "CPU") { if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU // we asked llama.cpp to use the CPU
} else if (!success) { } else if (!success) {
@ -415,6 +436,15 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)"); emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed"); modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0); success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
if (!m_shouldBeLoaded) {
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
} else if (!m_llModelInfo.model->usingGPUDevice()) { } else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp // ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate // We might have had to fallback to CPU after load if the model is not possible to accelerate
@ -425,10 +455,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
} }
if (!success) { if (!success) {
delete m_llModelInfo.model; m_llModelInfo.model.reset();
m_llModelInfo.model = nullptr;
if (!m_isServer) if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo(); m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename())); emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
modelLoadProps.insert("error", "loadmodel_failed"); modelLoadProps.insert("error", "loadmodel_failed");
@ -438,10 +467,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
case 'G': m_llModelType = LLModelType::GPTJ_; break; case 'G': m_llModelType = LLModelType::GPTJ_; break;
default: default:
{ {
delete m_llModelInfo.model; m_llModelInfo.model.reset();
m_llModelInfo.model = nullptr;
if (!m_isServer) if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo(); m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename())); emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
} }
@ -451,13 +479,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
} }
} else { } else {
if (!m_isServer) if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo(); m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError)); emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError));
} }
} }
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
restoreState(); restoreState();
#if defined(DEBUG) #if defined(DEBUG)
@ -471,7 +499,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
Network::globalInstance()->trackChatEvent("model_load", modelLoadProps); Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
} else { } else {
if (!m_isServer) if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store
m_llModelInfo = LLModelInfo(); m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename())); emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
} }
@ -480,7 +508,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
setModelInfo(modelInfo); setModelInfo(modelInfo);
processSystemPrompt(); processSystemPrompt();
} }
return m_llModelInfo.model; return bool(m_llModelInfo.model);
} }
bool ChatLLM::isModelLoaded() const bool ChatLLM::isModelLoaded() const
@ -700,22 +728,23 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
emit responseChanged(QString::fromStdString(m_response)); emit responseChanged(QString::fromStdString(m_response));
} }
emit responseStopped(elapsed); emit responseStopped(elapsed);
m_pristineLoadedState = false;
return true; return true;
} }
void ChatLLM::setShouldBeLoaded(bool b) void ChatLLM::setShouldBeLoaded(bool b)
{ {
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model; qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
#endif #endif
m_shouldBeLoaded = b; // atomic m_shouldBeLoaded = b; // atomic
emit shouldBeLoadedChanged(); emit shouldBeLoadedChanged();
} }
void ChatLLM::setShouldTrySwitchContext(bool b) void ChatLLM::requestTrySwitchContext()
{ {
m_shouldTrySwitchContext = b; // atomic m_shouldBeLoaded = true; // atomic
emit shouldTrySwitchContextChanged(); emit trySwitchContextRequested(modelInfo());
} }
void ChatLLM::handleShouldBeLoadedChanged() void ChatLLM::handleShouldBeLoadedChanged()
@ -726,12 +755,6 @@ void ChatLLM::handleShouldBeLoadedChanged()
unloadModel(); unloadModel();
} }
void ChatLLM::handleShouldTrySwitchContextChanged()
{
if (m_shouldTrySwitchContext)
trySwitchContextOfLoadedModel(modelInfo());
}
void ChatLLM::unloadModel() void ChatLLM::unloadModel()
{ {
if (!isModelLoaded() || m_isServer) if (!isModelLoaded() || m_isServer)
@ -746,17 +769,16 @@ void ChatLLM::unloadModel()
saveState(); saveState();
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
if (m_forceUnloadModel) { if (m_forceUnloadModel) {
delete m_llModelInfo.model; m_llModelInfo.model.reset();
m_llModelInfo.model = nullptr;
m_forceUnloadModel = false; m_forceUnloadModel = false;
} }
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo(); m_pristineLoadedState = false;
} }
void ChatLLM::reloadModel() void ChatLLM::reloadModel()
@ -768,7 +790,7 @@ void ChatLLM::reloadModel()
return; return;
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif #endif
const ModelInfo m = modelInfo(); const ModelInfo m = modelInfo();
if (m.name().isEmpty()) if (m.name().isEmpty())
@ -795,6 +817,7 @@ void ChatLLM::generateName()
m_nameResponse = trimmed; m_nameResponse = trimmed;
emit generatedNameChanged(QString::fromStdString(m_nameResponse)); emit generatedNameChanged(QString::fromStdString(m_nameResponse));
} }
m_pristineLoadedState = false;
} }
void ChatLLM::handleChatIdChanged(const QString &id) void ChatLLM::handleChatIdChanged(const QString &id)
@ -934,7 +957,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
// If we do not deserialize the KV or it is discarded, then we need to restore the state from the // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive // text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone. // alone.
m_restoreStateFromText = !deserializeKV || discardKV; if (!deserializeKV || discardKV) {
m_restoreStateFromText = true;
m_pristineLoadedState = true;
}
if (!deserializeKV) { if (!deserializeKV) {
#if defined(DEBUG) #if defined(DEBUG)
@ -998,14 +1024,14 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
void ChatLLM::saveState() void ChatLLM::saveState()
{ {
if (!isModelLoaded()) if (!isModelLoaded() || m_pristineLoadedState)
return; return;
if (m_llModelType == LLModelType::API_) { if (m_llModelType == LLModelType::API_) {
m_state.clear(); m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly); QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_4); stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model); ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
stream << chatAPI->context(); stream << chatAPI->context();
return; return;
} }
@ -1026,7 +1052,7 @@ void ChatLLM::restoreState()
if (m_llModelType == LLModelType::API_) { if (m_llModelType == LLModelType::API_) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly); QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
stream.setVersion(QDataStream::Qt_6_4); stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model); ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
QList<QString> context; QList<QString> context;
stream >> context; stream >> context;
chatAPI->setContext(context); chatAPI->setContext(context);
@ -1045,13 +1071,18 @@ void ChatLLM::restoreState()
if (m_llModelInfo.model->stateSize() == m_state.size()) { if (m_llModelInfo.model->stateSize() == m_state.size()) {
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data()))); m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_processedSystemPrompt = true; m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else { } else {
qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size(); qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
m_restoreStateFromText = true; m_restoreStateFromText = true;
} }
m_state.clear(); // free local state copy unless unload is pending
m_state.squeeze(); if (m_shouldBeLoaded) {
m_state.clear();
m_state.squeeze();
m_pristineLoadedState = false;
}
} }
void ChatLLM::processSystemPrompt() void ChatLLM::processSystemPrompt()
@ -1105,6 +1136,7 @@ void ChatLLM::processSystemPrompt()
#endif #endif
m_processedSystemPrompt = m_stopGenerating == false; m_processedSystemPrompt = m_stopGenerating == false;
m_pristineLoadedState = false;
} }
void ChatLLM::processRestoreStateFromText() void ChatLLM::processRestoreStateFromText()
@ -1163,4 +1195,6 @@ void ChatLLM::processRestoreStateFromText()
m_isRecalc = false; m_isRecalc = false;
emit recalcChanged(); emit recalcChanged();
m_pristineLoadedState = false;
} }

View File

@ -5,6 +5,8 @@
#include <QThread> #include <QThread>
#include <QFileInfo> #include <QFileInfo>
#include <memory>
#include "database.h" #include "database.h"
#include "modellist.h" #include "modellist.h"
#include "../gpt4all-backend/llmodel.h" #include "../gpt4all-backend/llmodel.h"
@ -16,7 +18,7 @@ enum LLModelType {
}; };
struct LLModelInfo { struct LLModelInfo {
LLModel *model = nullptr; std::unique_ptr<LLModel> model;
QFileInfo fileInfo; QFileInfo fileInfo;
// NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
// must be able to serialize the information even if it is in the unloaded state // must be able to serialize the information even if it is in the unloaded state
@ -72,6 +74,7 @@ public:
virtual ~ChatLLM(); virtual ~ChatLLM();
void destroy(); void destroy();
static void destroyStore();
bool isModelLoaded() const; bool isModelLoaded() const;
void regenerateResponse(); void regenerateResponse();
void resetResponse(); void resetResponse();
@ -81,7 +84,7 @@ public:
bool shouldBeLoaded() const { return m_shouldBeLoaded; } bool shouldBeLoaded() const { return m_shouldBeLoaded; }
void setShouldBeLoaded(bool b); void setShouldBeLoaded(bool b);
void setShouldTrySwitchContext(bool b); void requestTrySwitchContext();
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; } void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; } void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
@ -101,7 +104,7 @@ public:
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt); bool prompt(const QList<QString> &collectionList, const QString &prompt);
bool loadDefaultModel(); bool loadDefaultModel();
bool trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
bool loadModel(const ModelInfo &modelInfo); bool loadModel(const ModelInfo &modelInfo);
void modelChangeRequested(const ModelInfo &modelInfo); void modelChangeRequested(const ModelInfo &modelInfo);
void unloadModel(); void unloadModel();
@ -109,7 +112,6 @@ public Q_SLOTS:
void generateName(); void generateName();
void handleChatIdChanged(const QString &id); void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged(); void handleShouldBeLoadedChanged();
void handleShouldTrySwitchContextChanged();
void handleThreadStarted(); void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal); void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged(); void handleDeviceChanged();
@ -128,8 +130,8 @@ Q_SIGNALS:
void stateChanged(); void stateChanged();
void threadStarted(); void threadStarted();
void shouldBeLoadedChanged(); void shouldBeLoadedChanged();
void shouldTrySwitchContextChanged(); void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(bool); void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results); void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed); void reportSpeed(const QString &speed);
void reportDevice(const QString &device); void reportDevice(const QString &device);
@ -172,7 +174,6 @@ private:
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded; std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_shouldTrySwitchContext;
std::atomic<bool> m_isRecalc; std::atomic<bool> m_isRecalc;
std::atomic<bool> m_forceUnloadModel; std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion; std::atomic<bool> m_markedForDeletion;
@ -181,6 +182,10 @@ private:
bool m_reloadingToChangeVariant; bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt; bool m_processedSystemPrompt;
bool m_restoreStateFromText; bool m_restoreStateFromText;
// m_pristineLoadedState is set if saveSate is unnecessary, either because:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false;
QVector<QPair<QString, QString>> m_stateFromText; QVector<QPair<QString, QString>> m_stateFromText;
}; };

View File

@ -122,10 +122,6 @@ Rectangle {
return ModelList.modelInfo(currentChat.modelInfo.id).name; return ModelList.modelInfo(currentChat.modelInfo.id).name;
} }
property bool isCurrentlyLoading: false
property real modelLoadingPercentage: 0.0
property bool trySwitchContextInProgress: false
PopupDialog { PopupDialog {
id: errorCompatHardware id: errorCompatHardware
anchors.centerIn: parent anchors.centerIn: parent
@ -340,34 +336,18 @@ Rectangle {
implicitWidth: 575 implicitWidth: 575
width: window.width >= 750 ? implicitWidth : implicitWidth - (750 - window.width) width: window.width >= 750 ? implicitWidth : implicitWidth - (750 - window.width)
enabled: !currentChat.isServer enabled: !currentChat.isServer
&& !window.trySwitchContextInProgress && !currentChat.trySwitchContextInProgress
&& !window.isCurrentlyLoading && !currentChat.isCurrentlyLoading
model: ModelList.installedModels model: ModelList.installedModels
valueRole: "id" valueRole: "id"
textRole: "name" textRole: "name"
function changeModel(index) { function changeModel(index) {
window.modelLoadingPercentage = 0.0;
window.isCurrentlyLoading = true;
currentChat.stopGenerating() currentChat.stopGenerating()
currentChat.reset(); currentChat.reset();
currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index)) currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
} }
Connections {
target: currentChat
function onModelLoadingPercentageChanged() {
window.modelLoadingPercentage = currentChat.modelLoadingPercentage;
window.isCurrentlyLoading = currentChat.modelLoadingPercentage !== 0.0
&& currentChat.modelLoadingPercentage !== 1.0;
}
function onTrySwitchContextOfLoadedModelAttempted() {
window.trySwitchContextInProgress = true;
}
function onTrySwitchContextOfLoadedModelCompleted() {
window.trySwitchContextInProgress = false;
}
}
Connections { Connections {
target: switchModelDialog target: switchModelDialog
function onAccepted() { function onAccepted() {
@ -377,14 +357,14 @@ Rectangle {
background: ProgressBar { background: ProgressBar {
id: modelProgress id: modelProgress
value: window.modelLoadingPercentage value: currentChat.modelLoadingPercentage
background: Rectangle { background: Rectangle {
color: theme.mainComboBackground color: theme.mainComboBackground
radius: 10 radius: 10
} }
contentItem: Item { contentItem: Item {
Rectangle { Rectangle {
visible: window.isCurrentlyLoading visible: currentChat.isCurrentlyLoading
anchors.bottom: parent.bottom anchors.bottom: parent.bottom
width: modelProgress.visualPosition * parent.width width: modelProgress.visualPosition * parent.width
height: 10 height: 10
@ -406,13 +386,15 @@ Rectangle {
text: { text: {
if (currentChat.modelLoadingError !== "") if (currentChat.modelLoadingError !== "")
return qsTr("Model loading error...") return qsTr("Model loading error...")
if (window.trySwitchContextInProgress) if (currentChat.trySwitchContextInProgress == 1)
return qsTr("Waiting for model...")
if (currentChat.trySwitchContextInProgress == 2)
return qsTr("Switching context...") return qsTr("Switching context...")
if (currentModelName() === "") if (currentModelName() === "")
return qsTr("Choose a model...") return qsTr("Choose a model...")
if (currentChat.modelLoadingPercentage === 0.0) if (currentChat.modelLoadingPercentage === 0.0)
return qsTr("Reload \u00B7 ") + currentModelName() return qsTr("Reload \u00B7 ") + currentModelName()
if (window.isCurrentlyLoading) if (currentChat.isCurrentlyLoading)
return qsTr("Loading \u00B7 ") + currentModelName() return qsTr("Loading \u00B7 ") + currentModelName()
return currentModelName() return currentModelName()
} }
@ -456,7 +438,7 @@ Rectangle {
MyMiniButton { MyMiniButton {
id: ejectButton id: ejectButton
visible: currentChat.isModelLoaded && !window.isCurrentlyLoading visible: currentChat.isModelLoaded && !currentChat.isCurrentlyLoading
z: 500 z: 500
anchors.right: parent.right anchors.right: parent.right
anchors.rightMargin: 50 anchors.rightMargin: 50
@ -474,8 +456,8 @@ Rectangle {
MyMiniButton { MyMiniButton {
id: reloadButton id: reloadButton
visible: currentChat.modelLoadingError === "" visible: currentChat.modelLoadingError === ""
&& !window.trySwitchContextInProgress && !currentChat.trySwitchContextInProgress
&& !window.isCurrentlyLoading && !currentChat.isCurrentlyLoading
&& (currentChat.isModelLoaded || currentModelName() !== "") && (currentChat.isModelLoaded || currentModelName() !== "")
z: 500 z: 500
anchors.right: ejectButton.visible ? ejectButton.left : parent.right anchors.right: ejectButton.visible ? ejectButton.left : parent.right
@ -1344,8 +1326,9 @@ Rectangle {
textColor: theme.textColor textColor: theme.textColor
visible: !currentChat.isServer visible: !currentChat.isServer
&& !currentChat.isModelLoaded && !currentChat.isModelLoaded
&& !window.trySwitchContextInProgress && currentChat.modelLoadingError === ""
&& !window.isCurrentlyLoading && !currentChat.trySwitchContextInProgress
&& !currentChat.isCurrentlyLoading
&& currentModelName() !== "" && currentModelName() !== ""
Image { Image {