mirror of https://github.com/nomic-ai/gpt4all.git
chat: fix issues with quickly switching between multiple chats (#2343)
* prevent load progress from getting out of sync with the current chat
* fix memory leak on exit if the LLModelStore contains a model
* do not report cancellation as a failure in console/Mixpanel
* show "waiting for model" separately from "switching context" in UI
* do not show lower "reload" button on error
* skip context switch if unload is pending
* skip unnecessary calls to LLModel::saveState

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
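The heart of the change is that LLModelStore stops holding a QVector of LLModelInfo entries and instead holds a single std::optional slot: acquireModel() blocks until the slot is filled and takes its contents by move, releaseModel() hands the (now move-only) info back, and a new destroy() lets ChatLLM empty the store on shutdown so the contained model is actually freed. Below is a minimal, Qt-free sketch of that single-slot pattern for orientation only; it uses std::mutex/std::condition_variable in place of QMutex/QWaitCondition, and ModelSlot is a hypothetical stand-in for LLModelInfo, not gpt4all code.

    // Single-slot blocking store: at most one ModelSlot circulates between chats.
    #include <condition_variable>
    #include <memory>
    #include <mutex>
    #include <optional>
    #include <utility>

    struct ModelSlot {
        std::unique_ptr<int> model; // placeholder for the real LLModel pointer
    };

    class SingleSlotStore {
    public:
        SingleSlotStore() { m_slot = ModelSlot(); } // seed with an empty slot

        // Blocks until a slot is available, then hands it out by move.
        ModelSlot acquire() {
            std::unique_lock<std::mutex> lock(m_mutex);
            m_cond.wait(lock, [this] { return m_slot.has_value(); });
            ModelSlot taken = std::move(*m_slot);
            m_slot.reset();
            return taken;
        }

        // Must be called exactly once for every acquire().
        void release(ModelSlot &&slot) {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_slot = std::move(slot);
            m_cond.notify_all();
        }

        // Drop whatever is stored so the model is freed before shutdown.
        void destroy() {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_slot.reset();
        }

    private:
        std::optional<ModelSlot> m_slot;
        std::mutex m_mutex;
        std::condition_variable m_cond;
    };

With a move-only LLModelInfo, the releaseModel(std::move(m_llModelInfo)) calls in the diff below also make the ownership hand-off explicit: after the call the local info no longer owns the model, which is why some call sites immediately reassign a fresh LLModelInfo().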
@@ -30,16 +30,17 @@ public:
     static LLModelStore *globalInstance();
 
     LLModelInfo acquireModel(); // will block until llmodel is ready
-    void releaseModel(const LLModelInfo &info); // must be called when you are done
+    void releaseModel(LLModelInfo &&info); // must be called when you are done
+    void destroy();
 
 private:
     LLModelStore()
     {
         // seed with empty model
-        m_availableModels.append(LLModelInfo());
+        m_availableModel = LLModelInfo();
     }
     ~LLModelStore() {}
-    QVector<LLModelInfo> m_availableModels;
+    std::optional<LLModelInfo> m_availableModel;
     QMutex m_mutex;
     QWaitCondition m_condition;
     friend class MyLLModelStore;
@@ -55,19 +56,27 @@ LLModelStore *LLModelStore::globalInstance()
 LLModelInfo LLModelStore::acquireModel()
 {
     QMutexLocker locker(&m_mutex);
-    while (m_availableModels.isEmpty())
+    while (!m_availableModel)
         m_condition.wait(locker.mutex());
-    return m_availableModels.takeFirst();
+    auto first = std::move(*m_availableModel);
+    m_availableModel.reset();
+    return first;
 }
 
-void LLModelStore::releaseModel(const LLModelInfo &info)
+void LLModelStore::releaseModel(LLModelInfo &&info)
 {
     QMutexLocker locker(&m_mutex);
-    m_availableModels.append(info);
-    Q_ASSERT(m_availableModels.count() < 2);
+    Q_ASSERT(!m_availableModel);
+    m_availableModel = std::move(info);
     m_condition.wakeAll();
 }
 
+void LLModelStore::destroy()
+{
+    QMutexLocker locker(&m_mutex);
+    m_availableModel.reset();
+}
+
 ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
     , m_promptResponseTokens(0)
@@ -76,7 +85,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_shouldBeLoaded(false)
     , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
-    , m_shouldTrySwitchContext(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
@@ -88,7 +96,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
-    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
+    connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel,
         Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
@@ -108,7 +116,8 @@ ChatLLM::~ChatLLM()
     destroy();
 }
 
-void ChatLLM::destroy() {
+void ChatLLM::destroy()
+{
     m_stopGenerating = true;
     m_llmThread.quit();
     m_llmThread.wait();
@@ -116,11 +125,15 @@ void ChatLLM::destroy() {
     // The only time we should have a model loaded here is on shutdown
     // as we explicitly unload the model in all other circumstances
     if (isModelLoaded()) {
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
+        m_llModelInfo.model.reset();
     }
 }
 
+void ChatLLM::destroyStore()
+{
+    LLModelStore::globalInstance()->destroy();
+}
+
 void ChatLLM::handleThreadStarted()
 {
     m_timer = new TokenTimer(this);
@@ -161,7 +174,7 @@ bool ChatLLM::loadDefaultModel()
     return loadModel(defaultModel);
 }
 
-bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
+void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
     // and if so we just acquire it from the store and switch the context and return true. If the
@@ -169,10 +182,11 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 
     // If we're already loaded or a server or we're reloading to change the variant/device or the
     // modelInfo is empty, then this should fail
-    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
-        m_shouldTrySwitchContext = false;
-        emit trySwitchContextOfLoadedModelCompleted(false);
-        return false;
+    if (
+        isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
+    ) {
+        emit trySwitchContextOfLoadedModelCompleted(0);
+        return;
     }
 
     QString filePath = modelInfo.dirpath + modelInfo.filename();
@@ -180,33 +194,28 @@ bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 
     m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
     // The store gave us no already loaded model, the wrong type of model, then give it back to the
     // store and fail
-    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
-        LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-        m_llModelInfo = LLModelInfo();
-        m_shouldTrySwitchContext = false;
-        emit trySwitchContextOfLoadedModelCompleted(false);
-        return false;
+    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo || !m_shouldBeLoaded) {
+        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+        emit trySwitchContextOfLoadedModelCompleted(0);
+        return;
     }
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
-    // We should be loaded and now we are
-    m_shouldBeLoaded = true;
-    m_shouldTrySwitchContext = false;
+    emit trySwitchContextOfLoadedModelCompleted(2);
 
     // Restore, signal and process
     restoreState();
     emit modelLoadingPercentageChanged(1.0f);
-    emit trySwitchContextOfLoadedModelCompleted(true);
+    emit trySwitchContextOfLoadedModelCompleted(0);
     processSystemPrompt();
-    return true;
 }
 
 bool ChatLLM::loadModel(const ModelInfo &modelInfo)
@@ -223,6 +232,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (isModelLoaded() && this->modelInfo() == modelInfo)
         return true;
 
+    // reset status
+    emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    emit modelLoadingError("");
+    emit reportFallbackReason("");
+    emit reportDevice("");
+    m_pristineLoadedState = false;
+
     QString filePath = modelInfo.dirpath + modelInfo.filename();
     QFileInfo fileInfo(filePath);
 
@@ -231,28 +247,25 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (alreadyAcquired) {
         resetContext();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+        m_llModelInfo.model.reset();
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
        // If it succeeds, then we just have to restore state. If the store has never had a model
        // returned to it, then the modelInfo.model pointer should be null which will happen on startup
         m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
         // At this point it is possible that while we were blocked waiting to acquire the model from the
         // store, that our state was changed to not be loaded. If this is the case, release the model
         // back into the store and quit loading
         if (!m_shouldBeLoaded) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-            LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-            m_llModelInfo = LLModelInfo();
+            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
            emit modelLoadingPercentageChanged(0.0f);
            return false;
         }
@@ -260,7 +273,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         // Check if the store just gave us exactly the model we were looking for
         if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
             restoreState();
             emit modelLoadingPercentageChanged(1.0f);
@@ -274,10 +287,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         } else {
             // Release the memory since we have to switch to a different model.
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-            delete m_llModelInfo.model;
-            m_llModelInfo.model = nullptr;
+            m_llModelInfo.model.reset();
         }
     }
 
@@ -307,7 +319,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             model->setModelName(modelName);
             model->setRequestURL(modelInfo.url());
             model->setAPIKey(apiKey);
-            m_llModelInfo.model = model;
+            m_llModelInfo.model.reset(model);
         } else {
             QElapsedTimer modelLoadTimer;
             modelLoadTimer.start();
@@ -322,9 +334,10 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             buildVariant = "metal";
 #endif
             QString constructError;
-            m_llModelInfo.model = nullptr;
+            m_llModelInfo.model.reset();
             try {
-                m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+                auto *model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+                m_llModelInfo.model.reset(model);
             } catch (const LLModel::MissingImplementationError &e) {
                 modelLoadProps.insert("error", "missing_model_impl");
                 constructError = e.what();
@@ -355,8 +368,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                     return m_shouldBeLoaded;
                 });
 
-                emit reportFallbackReason(""); // no fallback yet
-
                 auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
                     float memGB = dev->heapSize / float(1024 * 1024 * 1024);
                     return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
@@ -407,6 +418,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 emit reportDevice(actualDevice);
 
                 bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
+
+                if (!m_shouldBeLoaded) {
+                    m_llModelInfo.model.reset();
+                    if (!m_isServer)
+                        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+                    m_llModelInfo = LLModelInfo();
+                    emit modelLoadingPercentageChanged(0.0f);
+                    return false;
+                }
+
                 if (actualDevice == "CPU") {
                     // we asked llama.cpp to use the CPU
                 } else if (!success) {
@@ -415,6 +436,15 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                     emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
                     modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
                     success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
+
+                    if (!m_shouldBeLoaded) {
+                        m_llModelInfo.model.reset();
+                        if (!m_isServer)
+                            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+                        m_llModelInfo = LLModelInfo();
+                        emit modelLoadingPercentageChanged(0.0f);
+                        return false;
+                    }
                 } else if (!m_llModelInfo.model->usingGPUDevice()) {
                     // ggml_vk_init was not called in llama.cpp
                     // We might have had to fallback to CPU after load if the model is not possible to accelerate
@@ -425,10 +455,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 }
 
                 if (!success) {
-                    delete m_llModelInfo.model;
-                    m_llModelInfo.model = nullptr;
+                    m_llModelInfo.model.reset();
                     if (!m_isServer)
-                        LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+                        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
                     m_llModelInfo = LLModelInfo();
                     emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
                     modelLoadProps.insert("error", "loadmodel_failed");
@@ -438,10 +467,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 case 'G': m_llModelType = LLModelType::GPTJ_; break;
                 default:
                     {
-                        delete m_llModelInfo.model;
-                        m_llModelInfo.model = nullptr;
+                        m_llModelInfo.model.reset();
                         if (!m_isServer)
-                            LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+                            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
                         m_llModelInfo = LLModelInfo();
                         emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
                     }
@@ -451,13 +479,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 }
             } else {
                 if (!m_isServer)
-                    LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+                    LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
                 m_llModelInfo = LLModelInfo();
                 emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError));
             }
         }
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
         restoreState();
 #if defined(DEBUG)
@@ -471,7 +499,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
     } else {
         if (!m_isServer)
-            LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
+            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store
         m_llModelInfo = LLModelInfo();
         emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
     }
@@ -480,7 +508,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         setModelInfo(modelInfo);
         processSystemPrompt();
     }
-    return m_llModelInfo.model;
+    return bool(m_llModelInfo.model);
 }
 
 bool ChatLLM::isModelLoaded() const
@@ -700,22 +728,23 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         emit responseChanged(QString::fromStdString(m_response));
     }
     emit responseStopped(elapsed);
+    m_pristineLoadedState = false;
     return true;
 }
 
 void ChatLLM::setShouldBeLoaded(bool b)
 {
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
+    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
 #endif
     m_shouldBeLoaded = b; // atomic
     emit shouldBeLoadedChanged();
 }
 
-void ChatLLM::setShouldTrySwitchContext(bool b)
+void ChatLLM::requestTrySwitchContext()
 {
-    m_shouldTrySwitchContext = b; // atomic
-    emit shouldTrySwitchContextChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit trySwitchContextRequested(modelInfo());
 }
 
 void ChatLLM::handleShouldBeLoadedChanged()
@@ -726,12 +755,6 @@ void ChatLLM::handleShouldBeLoadedChanged()
         unloadModel();
 }
 
-void ChatLLM::handleShouldTrySwitchContextChanged()
-{
-    if (m_shouldTrySwitchContext)
-        trySwitchContextOfLoadedModel(modelInfo());
-}
-
 void ChatLLM::unloadModel()
 {
     if (!isModelLoaded() || m_isServer)
@@ -746,17 +769,16 @@ void ChatLLM::unloadModel()
     saveState();
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
     if (m_forceUnloadModel) {
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
+        m_llModelInfo.model.reset();
         m_forceUnloadModel = false;
     }
 
-    LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-    m_llModelInfo = LLModelInfo();
+    LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+    m_pristineLoadedState = false;
 }
 
 void ChatLLM::reloadModel()
@@ -768,7 +790,7 @@ void ChatLLM::reloadModel()
         return;
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
     const ModelInfo m = modelInfo();
     if (m.name().isEmpty())
@@ -795,6 +817,7 @@ void ChatLLM::generateName()
         m_nameResponse = trimmed;
         emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     }
+    m_pristineLoadedState = false;
 }
 
 void ChatLLM::handleChatIdChanged(const QString &id)
@@ -934,7 +957,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
     // text only. This will be a costly operation, but the chat has to be restored from the text archive
     // alone.
-    m_restoreStateFromText = !deserializeKV || discardKV;
+    if (!deserializeKV || discardKV) {
+        m_restoreStateFromText = true;
+        m_pristineLoadedState = true;
+    }
 
     if (!deserializeKV) {
 #if defined(DEBUG)
@@ -998,14 +1024,14 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
 
 void ChatLLM::saveState()
 {
-    if (!isModelLoaded())
+    if (!isModelLoaded() || m_pristineLoadedState)
         return;
 
     if (m_llModelType == LLModelType::API_) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
         stream << chatAPI->context();
         return;
     }
@@ -1026,7 +1052,7 @@ void ChatLLM::restoreState()
     if (m_llModelType == LLModelType::API_) {
         QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
         QList<QString> context;
         stream >> context;
         chatAPI->setContext(context);
@@ -1045,13 +1071,18 @@ void ChatLLM::restoreState()
     if (m_llModelInfo.model->stateSize() == m_state.size()) {
         m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
         m_processedSystemPrompt = true;
+        m_pristineLoadedState = true;
     } else {
         qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
         m_restoreStateFromText = true;
     }
 
-    m_state.clear();
-    m_state.squeeze();
+    // free local state copy unless unload is pending
+    if (m_shouldBeLoaded) {
+        m_state.clear();
+        m_state.squeeze();
+        m_pristineLoadedState = false;
+    }
 }
 
 void ChatLLM::processSystemPrompt()
@@ -1105,6 +1136,7 @@ void ChatLLM::processSystemPrompt()
 #endif
 
     m_processedSystemPrompt = m_stopGenerating == false;
+    m_pristineLoadedState = false;
 }
 
 void ChatLLM::processRestoreStateFromText()
@@ -1163,4 +1195,6 @@ void ChatLLM::processRestoreStateFromText()
 
     m_isRecalc = false;
     emit recalcChanged();
+
+    m_pristineLoadedState = false;
 }
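A note on the m_pristineLoadedState flag threaded through the hunks above: saveState() is skipped while the in-memory model state is still exactly what was last restored or deserialized, and the flag is cleared as soon as generation or a system prompt touches the model. The sketch below is a hypothetical, standalone illustration of that dirty-flag idea; Session, serializeState and onTokenGenerated are invented names, not gpt4all API.

    #include <string>
    #include <utility>
    #include <vector>

    class Session {
    public:
        void restoreState(std::vector<unsigned char> state) {
            m_state = std::move(state);
            m_pristine = true; // nothing has touched the model since this restore
        }

        void onTokenGenerated(const std::string &token) {
            m_transcript += token;
            m_pristine = false; // model state has diverged from the saved copy
        }

        void saveState() {
            if (m_pristine)
                return; // saved copy is still current; skip the costly snapshot
            m_state = serializeState();
        }

    private:
        std::vector<unsigned char> serializeState() const {
            // stand-in for LLModel::saveState(), which snapshots the full model state
            return std::vector<unsigned char>(m_transcript.begin(), m_transcript.end());
        }

        std::vector<unsigned char> m_state;
        std::string m_transcript;
        bool m_pristine = false;
    };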