Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-26 15:31:55 +00:00)
chat: fix issues with quickly switching between multiple chats (#2343)

* prevent load progress from getting out of sync with the current chat
* fix memory leak on exit if the LLModelStore contains a model
* do not report cancellation as a failure in console/Mixpanel
* show "waiting for model" separately from "switching context" in UI
* do not show lower "reload" button on error
* skip context switch if unload is pending
* skip unnecessary calls to LLModel::saveState

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in: parent 7f1c3d4275, commit 7e1e00f331
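To make the store change in this commit concrete, here is a minimal standalone sketch of the single-slot pattern the diff moves to: one optional slot guarded by a mutex and a wait condition, a move-based release, and a destroy() hook so a model still parked in the store is freed on shutdown. The class and member names below are simplified stand-ins, not the actual gpt4all types.

```cpp
// Simplified sketch (not the gpt4all classes): a single-slot store that hands a
// loaded model back and forth between chat threads. acquire() blocks until the
// slot is filled, release() refills it, destroy() drops whatever is still held
// so nothing leaks at exit.
#include <cassert>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <optional>
#include <utility>

struct Model { /* stand-in for the real LLModel */ };

struct ModelSlot {                  // stand-in for LLModelInfo
    std::unique_ptr<Model> model;   // may be empty: "no model loaded yet"
};

class SingleSlotStore {
public:
    SingleSlotStore() { m_slot = ModelSlot{}; }  // seed with an empty slot

    ModelSlot acquire() {                        // blocks until a slot is available
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cond.wait(lock, [this] { return m_slot.has_value(); });
        ModelSlot taken = std::move(*m_slot);
        m_slot.reset();
        return taken;
    }

    void release(ModelSlot &&slot) {             // must be called once per acquire()
        std::lock_guard<std::mutex> lock(m_mutex);
        assert(!m_slot.has_value());             // only one slot may ever be parked here
        m_slot = std::move(slot);
        m_cond.notify_all();
    }

    void destroy() {                             // called once at shutdown
        std::lock_guard<std::mutex> lock(m_mutex);
        m_slot.reset();                          // unique_ptr frees any model still parked here
    }

private:
    std::optional<ModelSlot> m_slot;
    std::mutex m_mutex;
    std::condition_variable m_cond;
};
```

The move-only release mirrors the diff's releaseModel(LLModelInfo &&), and destroy() corresponds to what ChatLLM::destroyStore() forwards to, which is how the commit fixes the memory leak on exit.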
@@ -54,7 +54,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::trySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);

connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection);
@@ -95,16 +95,6 @@ void Chat::processSystemPrompt()
emit processSystemPromptRequested();
}

bool Chat::isModelLoaded() const
{
return m_modelLoadingPercentage == 1.0f;
}

float Chat::modelLoadingPercentage() const
{
return m_modelLoadingPercentage;
}

void Chat::resetResponseState()
{
if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@@ -167,9 +157,16 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
if (loadingPercentage == m_modelLoadingPercentage)
return;

bool wasLoading = isCurrentlyLoading();
bool wasLoaded = isModelLoaded();

m_modelLoadingPercentage = loadingPercentage;
emit modelLoadingPercentageChanged();
if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)

if (isCurrentlyLoading() != wasLoading)
emit isCurrentlyLoadingChanged();

if (isModelLoaded() != wasLoaded)
emit isModelLoadedChanged();
}

@@ -247,10 +244,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
if (m_modelInfo == modelInfo && isModelLoaded())
return;

m_modelLoadingPercentage = std::numeric_limits<float>::min(); // small non-zero positive value
emit isModelLoadedChanged();
m_modelLoadingError = QString();
emit modelLoadingErrorChanged();
m_modelInfo = modelInfo;
emit modelInfoChanged();
emit modelChangeRequested(modelInfo);
@@ -320,8 +313,9 @@ void Chat::forceReloadModel()

void Chat::trySwitchContextOfLoadedModel()
{
emit trySwitchContextOfLoadedModelAttempted();
m_llmodel->setShouldTrySwitchContext(true);
m_trySwitchContextInProgress = 1;
emit trySwitchContextInProgressChanged();
m_llmodel->requestTrySwitchContext();
}

void Chat::generatedNameChanged(const QString &name)
@@ -342,8 +336,10 @@ void Chat::handleRecalculating()

void Chat::handleModelLoadingError(const QString &error)
{
auto stream = qWarning().noquote() << "ERROR:" << error << "id";
stream.quote() << id();
if (!error.isEmpty()) {
auto stream = qWarning().noquote() << "ERROR:" << error << "id";
stream.quote() << id();
}
m_modelLoadingError = error;
emit modelLoadingErrorChanged();
}
@@ -380,6 +376,11 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
emit modelInfoChanged();
}

void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) {
m_trySwitchContextInProgress = value;
emit trySwitchContextInProgressChanged();
}

bool Chat::serialize(QDataStream &stream, int version) const
{
stream << m_creationDate;

@@ -17,6 +17,7 @@ class Chat : public QObject
Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(bool isCurrentlyLoading READ isCurrentlyLoading NOTIFY isCurrentlyLoadingChanged)
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
@@ -30,6 +31,8 @@ class Chat : public QObject
Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
// 0=no, 1=waiting, 2=working
Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")

@@ -62,8 +65,9 @@ public:

Q_INVOKABLE void reset();
Q_INVOKABLE void processSystemPrompt();
Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE float modelLoadingPercentage() const;
bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
Q_INVOKABLE void prompt(const QString &prompt);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void stopGenerating();
@@ -105,6 +109,8 @@ public:
QString device() const { return m_device; }
QString fallbackReason() const { return m_fallbackReason; }

int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }

public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt);

@@ -113,6 +119,7 @@ Q_SIGNALS:
void nameChanged();
void chatModelChanged();
void isModelLoadedChanged();
void isCurrentlyLoadingChanged();
void modelLoadingPercentageChanged();
void modelLoadingWarning(const QString &warning);
void responseChanged();
@@ -136,8 +143,7 @@ Q_SIGNALS:
void deviceChanged();
void fallbackReasonChanged();
void collectionModelChanged();
void trySwitchContextOfLoadedModelAttempted();
void trySwitchContextOfLoadedModelCompleted(bool);
void trySwitchContextInProgressChanged();

private Q_SLOTS:
void handleResponseChanged(const QString &response);
@@ -152,6 +158,7 @@ private Q_SLOTS:
void handleFallbackReasonChanged(const QString &device);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleTrySwitchContextOfLoadedModelCompleted(int value);

private:
QString m_id;
@@ -176,6 +183,8 @@ private:
float m_modelLoadingPercentage = 0.0f;
LocalDocsCollectionsModel *m_collectionModel;
bool m_firstResponse = true;
int m_trySwitchContextInProgress = 0;
bool m_isCurrentlyLoading = false;
};

#endif // CHAT_H
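The new trySwitchContextInProgress property declared above is a tri-state int rather than a bool (0=no, 1=waiting, 2=working). As a hedged illustration of how a consumer might map it to user-facing status text, the same way the QML further down in this diff does, here is a sketch; the helper name is hypothetical and not part of the codebase.

```cpp
#include <QString>

// Hypothetical helper: translates Chat's tri-state property into status text.
// 0 = no switch in progress, 1 = waiting for another chat to give the model up,
// 2 = actively switching the loaded model's context over to this chat.
QString switchStatusText(int trySwitchContextInProgress)
{
    switch (trySwitchContextInProgress) {
    case 1:  return QStringLiteral("Waiting for model...");
    case 2:  return QStringLiteral("Switching context...");
    default: return QString(); // 0: nothing to report
    }
}
```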

@@ -195,7 +195,11 @@ public:
int count() const { return m_chats.size(); }

// stop ChatLLM threads for clean shutdown
void destroyChats() { for (auto *chat: m_chats) { chat->destroy(); } }
void destroyChats()
{
for (auto *chat: m_chats) { chat->destroy(); }
ChatLLM::destroyStore();
}

void removeChatFile(Chat *chat) const;
Q_INVOKABLE void saveChats();

@@ -30,16 +30,17 @@ public:
static LLModelStore *globalInstance();

LLModelInfo acquireModel(); // will block until llmodel is ready
void releaseModel(const LLModelInfo &info); // must be called when you are done
void releaseModel(LLModelInfo &&info); // must be called when you are done
void destroy();

private:
LLModelStore()
{
// seed with empty model
m_availableModels.append(LLModelInfo());
m_availableModel = LLModelInfo();
}
~LLModelStore() {}
QVector<LLModelInfo> m_availableModels;
std::optional<LLModelInfo> m_availableModel;
QMutex m_mutex;
QWaitCondition m_condition;
friend class MyLLModelStore;
@@ -55,19 +56,27 @@ LLModelStore *LLModelStore::globalInstance()
LLModelInfo LLModelStore::acquireModel()
{
QMutexLocker locker(&m_mutex);
while (m_availableModels.isEmpty())
while (!m_availableModel)
m_condition.wait(locker.mutex());
return m_availableModels.takeFirst();
auto first = std::move(*m_availableModel);
m_availableModel.reset();
return first;
}

void LLModelStore::releaseModel(const LLModelInfo &info)
void LLModelStore::releaseModel(LLModelInfo &&info)
{
QMutexLocker locker(&m_mutex);
m_availableModels.append(info);
Q_ASSERT(m_availableModels.count() < 2);
Q_ASSERT(!m_availableModel);
m_availableModel = std::move(info);
m_condition.wakeAll();
}

void LLModelStore::destroy()
{
QMutexLocker locker(&m_mutex);
m_availableModel.reset();
}

ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr}
, m_promptResponseTokens(0)
@@ -76,7 +85,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_markedForDeletion(false)
, m_shouldTrySwitchContext(false)
, m_stopGenerating(false)
, m_timer(nullptr)
, m_isServer(isServer)
@@ -88,7 +96,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel,
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
@@ -108,7 +116,8 @@ ChatLLM::~ChatLLM()
destroy();
}

void ChatLLM::destroy() {
void ChatLLM::destroy()
{
m_stopGenerating = true;
m_llmThread.quit();
m_llmThread.wait();
@@ -116,11 +125,15 @@ void ChatLLM::destroy() {
// The only time we should have a model loaded here is on shutdown
// as we explicitly unload the model in all other circumstances
if (isModelLoaded()) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
}
}

void ChatLLM::destroyStore()
{
LLModelStore::globalInstance()->destroy();
}

void ChatLLM::handleThreadStarted()
{
m_timer = new TokenTimer(this);
@@ -161,7 +174,7 @@ bool ChatLLM::loadDefaultModel()
return loadModel(defaultModel);
}

bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
// We're trying to see if the store already has the model fully loaded that we wish to use
// and if so we just acquire it from the store and switch the context and return true. If the
@@ -169,10 +182,11 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)

// If we're already loaded or a server or we're reloading to change the variant/device or the
// modelInfo is empty, then this should fail
if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
if (
isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
) {
emit trySwitchContextOfLoadedModelCompleted(0);
return;
}

QString filePath = modelInfo.dirpath + modelInfo.filename();
@@ -180,33 +194,28 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)

m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif

// The store gave us no already loaded model, the wrong type of model, then give it back to the
// store and fail
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo || !m_shouldBeLoaded) {
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
emit trySwitchContextOfLoadedModelCompleted(0);
return;
}

#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif

// We should be loaded and now we are
m_shouldBeLoaded = true;
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(2);

// Restore, signal and process
restoreState();
emit modelLoadingPercentageChanged(1.0f);
emit trySwitchContextOfLoadedModelCompleted(true);
emit trySwitchContextOfLoadedModelCompleted(0);
processSystemPrompt();
return true;
}

bool ChatLLM::loadModel(const ModelInfo &modelInfo)
@@ -223,6 +232,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (isModelLoaded() && this->modelInfo() == modelInfo)
return true;

// reset status
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
emit modelLoadingError("");
emit reportFallbackReason("");
emit reportDevice("");
m_pristineLoadedState = false;

QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath);

@@ -231,28 +247,25 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (alreadyAcquired) {
resetContext();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
m_llModelInfo.model.reset();
} else if (!m_isServer) {
// This is a blocking call that tries to retrieve the model we need from the model store.
// If it succeeds, then we just have to restore state. If the store has never had a model
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
// At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model
// back into the store and quit loading
if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
emit modelLoadingPercentageChanged(0.0f);
return false;
}
@@ -260,7 +273,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Check if the store just gave us exactly the model we were looking for
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
restoreState();
emit modelLoadingPercentageChanged(1.0f);
@@ -274,10 +287,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
} else {
// Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
}
}

@@ -307,7 +319,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
model->setModelName(modelName);
model->setRequestURL(modelInfo.url());
model->setAPIKey(apiKey);
m_llModelInfo.model = model;
m_llModelInfo.model.reset(model);
} else {
QElapsedTimer modelLoadTimer;
modelLoadTimer.start();
@@ -322,9 +334,10 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
buildVariant = "metal";
#endif
QString constructError;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
try {
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
auto *model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
m_llModelInfo.model.reset(model);
} catch (const LLModel::MissingImplementationError &e) {
modelLoadProps.insert("error", "missing_model_impl");
constructError = e.what();
@@ -355,8 +368,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
return m_shouldBeLoaded;
});

emit reportFallbackReason(""); // no fallback yet

auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
float memGB = dev->heapSize / float(1024 * 1024 * 1024);
return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
@@ -407,6 +418,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit reportDevice(actualDevice);

bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);

if (!m_shouldBeLoaded) {
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}

if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU
} else if (!success) {
@@ -415,6 +436,15 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);

if (!m_shouldBeLoaded) {
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate
@@ -425,10 +455,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
}

if (!success) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
modelLoadProps.insert("error", "loadmodel_failed");
@@ -438,10 +467,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
case 'G': m_llModelType = LLModelType::GPTJ_; break;
default:
{
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
}
@@ -451,13 +479,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
}
} else {
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError));
}
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
restoreState();
#if defined(DEBUG)
@@ -471,7 +499,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
} else {
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
}
@@ -480,7 +508,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
setModelInfo(modelInfo);
processSystemPrompt();
}
return m_llModelInfo.model;
return bool(m_llModelInfo.model);
}

bool ChatLLM::isModelLoaded() const
@@ -700,22 +728,23 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
emit responseChanged(QString::fromStdString(m_response));
}
emit responseStopped(elapsed);
m_pristineLoadedState = false;
return true;
}

void ChatLLM::setShouldBeLoaded(bool b)
{
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
#endif
m_shouldBeLoaded = b; // atomic
emit shouldBeLoadedChanged();
}

void ChatLLM::setShouldTrySwitchContext(bool b)
void ChatLLM::requestTrySwitchContext()
{
m_shouldTrySwitchContext = b; // atomic
emit shouldTrySwitchContextChanged();
m_shouldBeLoaded = true; // atomic
emit trySwitchContextRequested(modelInfo());
}

void ChatLLM::handleShouldBeLoadedChanged()
@@ -726,12 +755,6 @@ void ChatLLM::handleShouldBeLoadedChanged()
unloadModel();
}

void ChatLLM::handleShouldTrySwitchContextChanged()
{
if (m_shouldTrySwitchContext)
trySwitchContextOfLoadedModel(modelInfo());
}

void ChatLLM::unloadModel()
{
if (!isModelLoaded() || m_isServer)
@@ -746,17 +769,16 @@ void ChatLLM::unloadModel()
saveState();

#if defined(DEBUG_MODEL_LOADING)
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif

if (m_forceUnloadModel) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_llModelInfo.model.reset();
m_forceUnloadModel = false;
}

LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_pristineLoadedState = false;
}

void ChatLLM::reloadModel()
@@ -768,7 +790,7 @@ void ChatLLM::reloadModel()
return;

#if defined(DEBUG_MODEL_LOADING)
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
const ModelInfo m = modelInfo();
if (m.name().isEmpty())
@@ -795,6 +817,7 @@ void ChatLLM::generateName()
m_nameResponse = trimmed;
emit generatedNameChanged(QString::fromStdString(m_nameResponse));
}
m_pristineLoadedState = false;
}

void ChatLLM::handleChatIdChanged(const QString &id)
@@ -934,7 +957,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;
if (!deserializeKV || discardKV) {
m_restoreStateFromText = true;
m_pristineLoadedState = true;
}

if (!deserializeKV) {
#if defined(DEBUG)
@@ -998,14 +1024,14 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,

void ChatLLM::saveState()
{
if (!isModelLoaded())
if (!isModelLoaded() || m_pristineLoadedState)
return;

if (m_llModelType == LLModelType::API_) {
m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
stream << chatAPI->context();
return;
}
@@ -1026,7 +1052,7 @@ void ChatLLM::restoreState()
if (m_llModelType == LLModelType::API_) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
QList<QString> context;
stream >> context;
chatAPI->setContext(context);
@@ -1045,13 +1071,18 @@ void ChatLLM::restoreState()
if (m_llModelInfo.model->stateSize() == m_state.size()) {
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
m_restoreStateFromText = true;
}

m_state.clear();
m_state.squeeze();
// free local state copy unless unload is pending
if (m_shouldBeLoaded) {
m_state.clear();
m_state.squeeze();
m_pristineLoadedState = false;
}
}

void ChatLLM::processSystemPrompt()
@@ -1105,6 +1136,7 @@ void ChatLLM::processSystemPrompt()
#endif

m_processedSystemPrompt = m_stopGenerating == false;
m_pristineLoadedState = false;
}

void ChatLLM::processRestoreStateFromText()
@@ -1163,4 +1195,6 @@ void ChatLLM::processRestoreStateFromText()

m_isRecalc = false;
emit recalcChanged();

m_pristineLoadedState = false;
}

@@ -5,6 +5,8 @@
#include <QThread>
#include <QFileInfo>

#include <memory>

#include "database.h"
#include "modellist.h"
#include "../gpt4all-backend/llmodel.h"
@@ -16,7 +18,7 @@ enum LLModelType {
};

struct LLModelInfo {
LLModel *model = nullptr;
std::unique_ptr<LLModel> model;
QFileInfo fileInfo;
// NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
// must be able to serialize the information even if it is in the unloaded state
@@ -72,6 +74,7 @@ public:
virtual ~ChatLLM();

void destroy();
static void destroyStore();
bool isModelLoaded() const;
void regenerateResponse();
void resetResponse();
@@ -81,7 +84,7 @@ public:

bool shouldBeLoaded() const { return m_shouldBeLoaded; }
void setShouldBeLoaded(bool b);
void setShouldTrySwitchContext(bool b);
void requestTrySwitchContext();
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }

@@ -101,7 +104,7 @@ public:
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
bool loadDefaultModel();
bool trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
bool loadModel(const ModelInfo &modelInfo);
void modelChangeRequested(const ModelInfo &modelInfo);
void unloadModel();
@@ -109,7 +112,6 @@ public Q_SLOTS:
void generateName();
void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged();
void handleShouldTrySwitchContextChanged();
void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
@@ -128,8 +130,8 @@ Q_SIGNALS:
void stateChanged();
void threadStarted();
void shouldBeLoadedChanged();
void shouldTrySwitchContextChanged();
void trySwitchContextOfLoadedModelCompleted(bool);
void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
@@ -172,7 +174,6 @@ private:
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_shouldTrySwitchContext;
std::atomic<bool> m_isRecalc;
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
@@ -181,6 +182,10 @@ private:
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
// m_pristineLoadedState is set if saveSate is unnecessary, either because:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false;
QVector<QPair<QString, QString>> m_stateFromText;
};

@@ -122,10 +122,6 @@ Rectangle {
return ModelList.modelInfo(currentChat.modelInfo.id).name;
}

property bool isCurrentlyLoading: false
property real modelLoadingPercentage: 0.0
property bool trySwitchContextInProgress: false

PopupDialog {
id: errorCompatHardware
anchors.centerIn: parent
@@ -340,34 +336,18 @@ Rectangle {
implicitWidth: 575
width: window.width >= 750 ? implicitWidth : implicitWidth - (750 - window.width)
enabled: !currentChat.isServer
&& !window.trySwitchContextInProgress
&& !window.isCurrentlyLoading
&& !currentChat.trySwitchContextInProgress
&& !currentChat.isCurrentlyLoading
model: ModelList.installedModels
valueRole: "id"
textRole: "name"

function changeModel(index) {
window.modelLoadingPercentage = 0.0;
window.isCurrentlyLoading = true;
currentChat.stopGenerating()
currentChat.reset();
currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
}

Connections {
target: currentChat
function onModelLoadingPercentageChanged() {
window.modelLoadingPercentage = currentChat.modelLoadingPercentage;
window.isCurrentlyLoading = currentChat.modelLoadingPercentage !== 0.0
&& currentChat.modelLoadingPercentage !== 1.0;
}
function onTrySwitchContextOfLoadedModelAttempted() {
window.trySwitchContextInProgress = true;
}
function onTrySwitchContextOfLoadedModelCompleted() {
window.trySwitchContextInProgress = false;
}
}
Connections {
target: switchModelDialog
function onAccepted() {
@@ -377,14 +357,14 @@ Rectangle {

background: ProgressBar {
id: modelProgress
value: window.modelLoadingPercentage
value: currentChat.modelLoadingPercentage
background: Rectangle {
color: theme.mainComboBackground
radius: 10
}
contentItem: Item {
Rectangle {
visible: window.isCurrentlyLoading
visible: currentChat.isCurrentlyLoading
anchors.bottom: parent.bottom
width: modelProgress.visualPosition * parent.width
height: 10
@@ -406,13 +386,15 @@ Rectangle {
text: {
if (currentChat.modelLoadingError !== "")
return qsTr("Model loading error...")
if (window.trySwitchContextInProgress)
if (currentChat.trySwitchContextInProgress == 1)
return qsTr("Waiting for model...")
if (currentChat.trySwitchContextInProgress == 2)
return qsTr("Switching context...")
if (currentModelName() === "")
return qsTr("Choose a model...")
if (currentChat.modelLoadingPercentage === 0.0)
return qsTr("Reload \u00B7 ") + currentModelName()
if (window.isCurrentlyLoading)
if (currentChat.isCurrentlyLoading)
return qsTr("Loading \u00B7 ") + currentModelName()
return currentModelName()
}
@@ -456,7 +438,7 @@ Rectangle {

MyMiniButton {
id: ejectButton
visible: currentChat.isModelLoaded && !window.isCurrentlyLoading
visible: currentChat.isModelLoaded && !currentChat.isCurrentlyLoading
z: 500
anchors.right: parent.right
anchors.rightMargin: 50
@@ -474,8 +456,8 @@ Rectangle {
MyMiniButton {
id: reloadButton
visible: currentChat.modelLoadingError === ""
&& !window.trySwitchContextInProgress
&& !window.isCurrentlyLoading
&& !currentChat.trySwitchContextInProgress
&& !currentChat.isCurrentlyLoading
&& (currentChat.isModelLoaded || currentModelName() !== "")
z: 500
anchors.right: ejectButton.visible ? ejectButton.left : parent.right
@@ -1344,8 +1326,9 @@ Rectangle {
textColor: theme.textColor
visible: !currentChat.isServer
&& !currentChat.isModelLoaded
&& !window.trySwitchContextInProgress
&& !window.isCurrentlyLoading
&& currentChat.modelLoadingError === ""
&& !currentChat.trySwitchContextInProgress
&& !currentChat.isCurrentlyLoading
&& currentModelName() !== ""

Image {