chat: fix issues with quickly switching between multiple chats (#2343)

* prevent load progress from getting out of sync with the current chat
* fix memory leak on exit if the LLModelStore contains a model
* do not report cancellation as a failure in console/Mixpanel
* show "waiting for model" separately from "switching context" in UI
* do not show lower "reload" button on error
* skip context switch if unload is pending
* skip unnecessary calls to LLModel::saveState

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
commit 7e1e00f331 (parent 7f1c3d4275)
Author: Jared Van Bortel
Date:   2024-05-15 14:07:03 -04:00 (committed by GitHub)
6 changed files with 179 additions and 143 deletions
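The heart of the change is LLModelStore: a QVector that could nominally hold several models becomes a single std::optional slot with move-only hand-off, so at most one loaded model can ever be parked in the store. A minimal, self-contained sketch of that pattern (simplified names; Slot stands in for LLModelInfo, and the real implementation is in the diff below):

#include <QMutex>
#include <QWaitCondition>
#include <QtGlobal>
#include <optional>
#include <utility>

struct Slot { /* stand-in for LLModelInfo */ };

class SingleSlotStore {
public:
    // Blocks until a slot is available, then takes ownership of it.
    Slot acquire() {
        QMutexLocker locker(&m_mutex);
        while (!m_slot)
            m_condition.wait(locker.mutex());
        Slot taken = std::move(*m_slot);
        m_slot.reset(); // the store is now empty
        return taken;
    }
    // Returns ownership to the store; at most one slot may be resident.
    void release(Slot &&s) {
        QMutexLocker locker(&m_mutex);
        Q_ASSERT(!m_slot); // double-release is a logic error
        m_slot = std::move(s);
        m_condition.wakeAll(); // wake any thread blocked in acquire()
    }
    // Drops the resident slot, if any (used at shutdown).
    void destroy() {
        QMutexLocker locker(&m_mutex);
        m_slot.reset();
    }
private:
    std::optional<Slot> m_slot{Slot{}}; // seeded with an empty slot
    QMutex m_mutex;
    QWaitCondition m_condition;
};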


@@ -30,16 +30,17 @@ public:
     static LLModelStore *globalInstance();
     LLModelInfo acquireModel(); // will block until llmodel is ready
-    void releaseModel(const LLModelInfo &info); // must be called when you are done
+    void releaseModel(LLModelInfo &&info); // must be called when you are done
+    void destroy();
 private:
     LLModelStore()
     {
         // seed with empty model
-        m_availableModels.append(LLModelInfo());
+        m_availableModel = LLModelInfo();
     }
     ~LLModelStore() {}
-    QVector<LLModelInfo> m_availableModels;
+    std::optional<LLModelInfo> m_availableModel;
     QMutex m_mutex;
     QWaitCondition m_condition;
     friend class MyLLModelStore;
@@ -55,19 +56,27 @@ LLModelStore *LLModelStore::globalInstance()
 LLModelInfo LLModelStore::acquireModel()
 {
     QMutexLocker locker(&m_mutex);
-    while (m_availableModels.isEmpty())
+    while (!m_availableModel)
         m_condition.wait(locker.mutex());
-    return m_availableModels.takeFirst();
+    auto first = std::move(*m_availableModel);
+    m_availableModel.reset();
+    return first;
 }
-void LLModelStore::releaseModel(const LLModelInfo &info)
+void LLModelStore::releaseModel(LLModelInfo &&info)
 {
     QMutexLocker locker(&m_mutex);
-    m_availableModels.append(info);
-    Q_ASSERT(m_availableModels.count() < 2);
+    Q_ASSERT(!m_availableModel);
+    m_availableModel = std::move(info);
     m_condition.wakeAll();
 }
+void LLModelStore::destroy()
+{
+    QMutexLocker locker(&m_mutex);
+    m_availableModel.reset();
+}
 ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
     , m_promptResponseTokens(0)
@@ -76,7 +85,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_shouldBeLoaded(false)
     , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
-    , m_shouldTrySwitchContext(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
@@ -88,7 +96,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
-    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
+    connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel,
         Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
@@ -108,7 +116,8 @@ ChatLLM::~ChatLLM()
     destroy();
 }
-void ChatLLM::destroy() {
+void ChatLLM::destroy()
+{
     m_stopGenerating = true;
     m_llmThread.quit();
     m_llmThread.wait();
@@ -116,11 +125,15 @@ void ChatLLM::destroy() {
     // The only time we should have a model loaded here is on shutdown
     // as we explicitly unload the model in all other circumstances
     if (isModelLoaded()) {
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
+        m_llModelInfo.model.reset();
     }
 }
+void ChatLLM::destroyStore()
+{
+    LLModelStore::globalInstance()->destroy();
+}
 void ChatLLM::handleThreadStarted()
 {
     m_timer = new TokenTimer(this);
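The new destroy()/destroyStore() pair is what fixes the "memory leak on exit" bullet: a model parked in the store belongs to no ChatLLM, so no destructor would otherwise free it. A plausible shutdown sequence (the exact call site is outside this diff, so the ordering here is an assumption):

// each ChatLLM stops its worker thread and frees its own model, if loaded
chatLLM->destroy();
// then drain the store itself, freeing any model still parked there
ChatLLM::destroyStore();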
@@ -161,7 +174,7 @@ bool ChatLLM::loadDefaultModel()
     return loadModel(defaultModel);
 }
-bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
+void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
     // and if so we just acquire it from the store and switch the context and return true. If the
@@ -169,10 +182,11 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
     // If we're already loaded or a server or we're reloading to change the variant/device or the
     // modelInfo is empty, then this should fail
-    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
-        m_shouldTrySwitchContext = false;
-        emit trySwitchContextOfLoadedModelCompleted(false);
-        return false;
+    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded) {
+        emit trySwitchContextOfLoadedModelCompleted(0);
+        return;
     }
     QString filePath = modelInfo.dirpath + modelInfo.filename();
@@ -180,33 +194,28 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
     m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
     // The store gave us no already loaded model, the wrong type of model, then give it back to the
     // store and fail
-    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
-        LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-        m_llModelInfo = LLModelInfo();
-        m_shouldTrySwitchContext = false;
-        emit trySwitchContextOfLoadedModelCompleted(false);
-        return false;
+    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo || !m_shouldBeLoaded) {
+        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+        emit trySwitchContextOfLoadedModelCompleted(0);
+        return;
     }
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-    // We should be loaded and now we are
-    m_shouldBeLoaded = true;
-    m_shouldTrySwitchContext = false;
+    emit trySwitchContextOfLoadedModelCompleted(2);
     // Restore, signal and process
     restoreState();
     emit modelLoadingPercentageChanged(1.0f);
-    emit trySwitchContextOfLoadedModelCompleted(true);
+    emit trySwitchContextOfLoadedModelCompleted(0);
     processSystemPrompt();
-    return true;
 }
 bool ChatLLM::loadModel(const ModelInfo &modelInfo)
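trySwitchContextOfLoadedModelCompleted now carries an int rather than a bool. Read together with the commit message ("waiting for model" vs. "switching context"), the payload looks like a small UI state code; a hypothetical naming, inferred rather than taken from the source:

// inferred meaning of the signal's int payload (assumption, not verbatim source)
enum TrySwitchContextState {
    Idle             = 0, // no switch in progress, or the attempt is finished
    WaitingForModel  = 1, // presumably set by the caller while blocked on the store
    SwitchingContext = 2, // store had our model; its context is being restored
};

The function accordingly emits 2 as soon as the store hands back the right model, then 0 once restoreState() and the system prompt have run.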
@@ -223,6 +232,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (isModelLoaded() && this->modelInfo() == modelInfo)
         return true;
+    // reset status
+    emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    emit modelLoadingError("");
+    emit reportFallbackReason("");
+    emit reportDevice("");
+    m_pristineLoadedState = false;
     QString filePath = modelInfo.dirpath + modelInfo.filename();
     QFileInfo fileInfo(filePath);
@@ -231,28 +247,25 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (alreadyAcquired) {
         resetContext();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+        m_llModelInfo.model.reset();
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
         // If it succeeds, then we just have to restore state. If the store has never had a model
         // returned to it, then the modelInfo.model pointer should be null which will happen on startup
         m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
         // At this point it is possible that while we were blocked waiting to acquire the model from the
         // store, that our state was changed to not be loaded. If this is the case, release the model
         // back into the store and quit loading
         if (!m_shouldBeLoaded) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-            LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-            m_llModelInfo = LLModelInfo();
+            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
             emit modelLoadingPercentageChanged(0.0f);
             return false;
         }
@@ -260,7 +273,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         // Check if the store just gave us exactly the model we were looking for
         if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
             restoreState();
             emit modelLoadingPercentageChanged(1.0f);
@@ -274,10 +287,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         } else {
             // Release the memory since we have to switch to a different model.
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-            delete m_llModelInfo.model;
-            m_llModelInfo.model = nullptr;
+            m_llModelInfo.model.reset();
         }
     }
@@ -307,7 +319,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             model->setModelName(modelName);
             model->setRequestURL(modelInfo.url());
             model->setAPIKey(apiKey);
-            m_llModelInfo.model = model;
+            m_llModelInfo.model.reset(model);
         } else {
             QElapsedTimer modelLoadTimer;
             modelLoadTimer.start();
@@ -322,9 +334,10 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 buildVariant = "metal";
 #endif
             QString constructError;
-            m_llModelInfo.model = nullptr;
+            m_llModelInfo.model.reset();
             try {
-                m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+                auto *model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+                m_llModelInfo.model.reset(model);
             } catch (const LLModel::MissingImplementationError &e) {
                 modelLoadProps.insert("error", "missing_model_impl");
                 constructError = e.what();
@@ -355,8 +368,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                     return m_shouldBeLoaded;
                 });
-                emit reportFallbackReason(""); // no fallback yet
                 auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
                     float memGB = dev->heapSize / float(1024 * 1024 * 1024);
                     return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
@@ -407,6 +418,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 emit reportDevice(actualDevice);
                 bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
+                if (!m_shouldBeLoaded) {
+                    m_llModelInfo.model.reset();
+                    if (!m_isServer)
+                        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+                    m_llModelInfo = LLModelInfo();
+                    emit modelLoadingPercentageChanged(0.0f);
+                    return false;
+                }
                 if (actualDevice == "CPU") {
                     // we asked llama.cpp to use the CPU
                 } else if (!success) {
@@ -415,6 +436,15 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                     emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
                     modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
                     success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
+                    if (!m_shouldBeLoaded) {
+                        m_llModelInfo.model.reset();
+                        if (!m_isServer)
+                            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+                        m_llModelInfo = LLModelInfo();
+                        emit modelLoadingPercentageChanged(0.0f);
+                        return false;
+                    }
                 } else if (!m_llModelInfo.model->usingGPUDevice()) {
                     // ggml_vk_init was not called in llama.cpp
                     // We might have had to fallback to CPU after load if the model is not possible to accelerate
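Both the GPU load and the CPU-fallback load now re-check m_shouldBeLoaded immediately after the blocking loadModel() call; the two added blocks above share this shape (condensed here for reference):

if (!m_shouldBeLoaded) {            // the user switched chats mid-load
    m_llModelInfo.model.reset();    // discard the partially loaded model
    if (!m_isServer)                // hand the (now empty) slot back
        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
    m_llModelInfo = LLModelInfo();
    emit modelLoadingPercentageChanged(0.0f); // report "not loaded", not an error
    return false;
}

Emitting 0% progress instead of modelLoadingError() is what keeps a mid-load cancellation out of the console and Mixpanel failure reporting.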
@@ -425,10 +455,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 }
                 if (!success) {
-                    delete m_llModelInfo.model;
-                    m_llModelInfo.model = nullptr;
+                    m_llModelInfo.model.reset();
                     if (!m_isServer)
-                        LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+                        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
                     m_llModelInfo = LLModelInfo();
                     emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
                     modelLoadProps.insert("error", "loadmodel_failed");
@@ -438,10 +467,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 case 'G': m_llModelType = LLModelType::GPTJ_; break;
                 default:
                     {
-                        delete m_llModelInfo.model;
-                        m_llModelInfo.model = nullptr;
+                        m_llModelInfo.model.reset();
                         if (!m_isServer)
-                            LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+                            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
                         m_llModelInfo = LLModelInfo();
                         emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
                     }
@@ -451,13 +479,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 }
             } else {
                 if (!m_isServer)
-                    LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+                    LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
                 m_llModelInfo = LLModelInfo();
                 emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError));
             }
         }
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
         restoreState();
 #if defined(DEBUG)
@@ -471,7 +499,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
     } else {
         if (!m_isServer)
-            LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store
         m_llModelInfo = LLModelInfo();
         emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
     }
@@ -480,7 +508,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         setModelInfo(modelInfo);
         processSystemPrompt();
     }
-    return m_llModelInfo.model;
+    return bool(m_llModelInfo.model);
 }
 bool ChatLLM::isModelLoaded() const
@@ -700,22 +728,23 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         emit responseChanged(QString::fromStdString(m_response));
     }
     emit responseStopped(elapsed);
+    m_pristineLoadedState = false;
     return true;
 }
 void ChatLLM::setShouldBeLoaded(bool b)
 {
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
+    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
 #endif
     m_shouldBeLoaded = b; // atomic
     emit shouldBeLoadedChanged();
 }
-void ChatLLM::setShouldTrySwitchContext(bool b)
+void ChatLLM::requestTrySwitchContext()
 {
-    m_shouldTrySwitchContext = b; // atomic
-    emit shouldTrySwitchContextChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit trySwitchContextRequested(modelInfo());
 }
 void ChatLLM::handleShouldBeLoadedChanged()
@@ -726,12 +755,6 @@ void ChatLLM::handleShouldBeLoadedChanged()
         unloadModel();
 }
-void ChatLLM::handleShouldTrySwitchContextChanged()
-{
-    if (m_shouldTrySwitchContext)
-        trySwitchContextOfLoadedModel(modelInfo());
-}
 void ChatLLM::unloadModel()
 {
     if (!isModelLoaded() || m_isServer)
@@ -746,17 +769,16 @@ void ChatLLM::unloadModel()
         saveState();
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
     if (m_forceUnloadModel) {
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
+        m_llModelInfo.model.reset();
         m_forceUnloadModel = false;
     }
-    LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-    m_llModelInfo = LLModelInfo();
+    LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+    m_pristineLoadedState = false;
 }
 void ChatLLM::reloadModel()
void ChatLLM::reloadModel()
@@ -768,7 +790,7 @@ void ChatLLM::reloadModel()
         return;
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
     const ModelInfo m = modelInfo();
     if (m.name().isEmpty())
@@ -795,6 +817,7 @@ void ChatLLM::generateName()
         m_nameResponse = trimmed;
         emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     }
+    m_pristineLoadedState = false;
 }
 void ChatLLM::handleChatIdChanged(const QString &id)
@@ -934,7 +957,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
     // text only. This will be a costly operation, but the chat has to be restored from the text archive
     // alone.
-    m_restoreStateFromText = !deserializeKV || discardKV;
+    if (!deserializeKV || discardKV) {
+        m_restoreStateFromText = true;
+        m_pristineLoadedState = true;
+    }
     if (!deserializeKV) {
 #if defined(DEBUG)
@@ -998,14 +1024,14 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
 void ChatLLM::saveState()
 {
-    if (!isModelLoaded())
+    if (!isModelLoaded() || m_pristineLoadedState)
         return;
     if (m_llModelType == LLModelType::API_) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
         stream << chatAPI->context();
         return;
     }
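The m_pristineLoadedState flag added throughout this file is what implements the "skip unnecessary calls to LLModel::saveState" bullet: while set, the model's in-memory state is known to match the serialized chat, so saveState() returns early. A summary of the flag's lifecycle as this diff sets it (an interpretation of the hunks, not verbatim source):

// true:  after restoreState() succeeds, or when the chat will be rebuilt
//        from text on the next load (deserialize path)
// false: set again whenever the KV state diverges (a prompt response,
//        generateName(), processSystemPrompt(), or once the local state
//        copy is freed following a restore)
//
// saveState() is then a cheap no-op while the flag holds:
void ChatLLM::saveState()
{
    if (!isModelLoaded() || m_pristineLoadedState) // nothing new to save
        return;
    // ... expensive state serialization follows ...
}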
@@ -1026,7 +1052,7 @@ void ChatLLM::restoreState()
     if (m_llModelType == LLModelType::API_) {
         QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
         QList<QString> context;
         stream >> context;
         chatAPI->setContext(context);
@@ -1045,13 +1071,18 @@ void ChatLLM::restoreState()
     if (m_llModelInfo.model->stateSize() == m_state.size()) {
         m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
         m_processedSystemPrompt = true;
+        m_pristineLoadedState = true;
     } else {
         qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
         m_restoreStateFromText = true;
     }
-    m_state.clear();
-    m_state.squeeze();
+    // free local state copy unless unload is pending
+    if (m_shouldBeLoaded) {
+        m_state.clear();
+        m_state.squeeze();
+        m_pristineLoadedState = false;
+    }
 }
 void ChatLLM::processSystemPrompt()
@@ -1105,6 +1136,7 @@ void ChatLLM::processSystemPrompt()
 #endif
     m_processedSystemPrompt = m_stopGenerating == false;
+    m_pristineLoadedState = false;
 }
 void ChatLLM::processRestoreStateFromText()
void ChatLLM::processRestoreStateFromText()
@@ -1163,4 +1195,6 @@ void ChatLLM::processRestoreStateFromText()
     m_isRecalc = false;
     emit recalcChanged();
+    m_pristineLoadedState = false;
 }