diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md
index ebd2633b..21d2ce5b 100644
--- a/gpt4all-chat/CHANGELOG.md
+++ b/gpt4all-chat/CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix the local server rejecting min\_p/top\_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))
 - Fix "regenerate" always forgetting the most recent message ([#3011](https://github.com/nomic-ai/gpt4all/pull/3011))
 - Fix loaded chats forgetting context when there is a system prompt ([#3015](https://github.com/nomic-ai/gpt4all/pull/3015))
+- Make it possible to downgrade and keep some chats, and avoid crash for some model types ([#3030](https://github.com/nomic-ai/gpt4all/pull/3030))
 
 ## [3.3.1] - 2024-09-27 ([v3.3.y](https://github.com/nomic-ai/gpt4all/tree/v3.3.y))
 
diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp
index 5eb18473..347a9cd2 100644
--- a/gpt4all-chat/src/chat.cpp
+++ b/gpt4all-chat/src/chat.cpp
@@ -101,6 +101,7 @@ void Chat::reset()
     // is to allow switching models but throwing up a dialog warning users if we switch between types
     // of models that a long recalculation will ensue.
     m_chatModel->clear();
+    m_needsSave = true;
 }
 
 void Chat::processSystemPrompt()
@@ -163,6 +164,7 @@ void Chat::prompt(const QString &prompt)
 {
     resetResponseState();
     emit promptRequested(m_collections, prompt);
+    m_needsSave = true;
 }
 
 void Chat::regenerateResponse()
@@ -170,6 +172,7 @@ void Chat::regenerateResponse()
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateSources(index, QList<ResultInfo>());
     emit regenerateResponseRequested();
+    m_needsSave = true;
 }
 
 void Chat::stopGenerating()
@@ -224,7 +227,7 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
 void Chat::promptProcessing()
 {
     m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
-    emit responseStateChanged();
+    emit responseStateChanged();
 }
 
 void Chat::generatingQuestions()
@@ -261,10 +264,12 @@ ModelInfo Chat::modelInfo() const
 
 void Chat::setModelInfo(const ModelInfo &modelInfo)
 {
-    if (m_modelInfo == modelInfo && isModelLoaded())
+    if (m_modelInfo != modelInfo) {
+        m_modelInfo = modelInfo;
+        m_needsSave = true;
+    } else if (isModelLoaded())
         return;
 
-    m_modelInfo = modelInfo;
     emit modelInfoChanged();
     emit modelChangeRequested(modelInfo);
 }
@@ -343,12 +348,14 @@
     int wordCount = qMin(7, words.size());
     m_name = words.mid(0, wordCount).join(' ');
     emit nameChanged();
+    m_needsSave = true;
 }
 
 void Chat::generatedQuestionFinished(const QString &question)
 {
     m_generatedQuestions << question;
     emit generatedQuestionsChanged();
+    m_needsSave = true;
 }
 
 void Chat::handleRestoringFromText()
@@ -393,6 +400,7 @@ void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
     m_databaseResults = results;
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateSources(index, m_databaseResults);
+    m_needsSave = true;
 }
 
 void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
@@ -402,6 +410,7 @@
 
     m_modelInfo = modelInfo;
     emit modelInfoChanged();
+    m_needsSave = true;
 }
 
 void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value)
@@ -416,15 +425,15 @@ bool Chat::serialize(QDataStream &stream, int version) const
     stream << m_id;
     stream << m_name;
     stream << m_userName;
-    if (version > 4)
+    if (version >= 5)
        stream << m_modelInfo.id();
     else
        stream << m_modelInfo.filename();
-    if (version > 2)
+    if (version >= 3)
        stream << m_collections;
 
     const bool serializeKV = MySettings::globalInstance()->saveChatsContext();
-    if (version > 5)
+    if (version >= 6)
        stream << serializeKV;
     if (!m_llmodel->serialize(stream, version, serializeKV))
        return false;
@@ -445,7 +454,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
 
     QString modelId;
     stream >> modelId;
-    if (version > 4) {
+    if (version >= 5) {
        if (ModelList::globalInstance()->contains(modelId))
            m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
     } else {
@@ -457,13 +466,13 @@
 
     bool discardKV = m_modelInfo.id().isEmpty();
 
-    if (version > 2) {
+    if (version >= 3) {
        stream >> m_collections;
        emit collectionListChanged(m_collections);
     }
 
     bool deserializeKV = true;
-    if (version > 5)
+    if (version >= 6)
        stream >> deserializeKV;
 
     m_llmodel->setModelInfo(m_modelInfo);
@@ -473,7 +482,11 @@
        return false;
 
     emit chatModelChanged();
-    return stream.status() == QDataStream::Ok;
+    if (stream.status() != QDataStream::Ok)
+        return false;
+
+    m_needsSave = false;
+    return true;
 }
 
 QList<QString> Chat::collectionList() const
@@ -493,6 +506,7 @@
 
     m_collections.append(collection);
     emit collectionListChanged(m_collections);
+    m_needsSave = true;
 }
 
 void Chat::removeCollection(const QString &collection)
@@ -502,4 +516,5 @@
 
     m_collections.removeAll(collection);
     emit collectionListChanged(m_collections);
+    m_needsSave = true;
 }
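
The chat.cpp hunks above all follow a single pattern: every mutation of user-visible chat state sets m_needsSave, and a successful deserialize clears it. A minimal standalone sketch of that dirty-flag idea, with illustrative names that are not part of this PR:

#include <iostream>
#include <string>
#include <utility>

class Document {
public:
    // Every mutation marks the object dirty.
    void setText(std::string text) { m_text = std::move(text); m_dirty = true; }
    void markLoaded() { m_dirty = false; }     // counterpart of Chat::deserialize() succeeding
    bool needsSave() const { return m_dirty; } // counterpart of Chat::needsSave()

private:
    std::string m_text;
    bool m_dirty = true; // freshly created objects always need an initial save
};

int main()
{
    Document doc;
    doc.markLoaded();                     // pretend it was just restored from disk
    std::cout << doc.needsSave() << '\n'; // 0 -- the saver can skip it
    doc.setText("hello");
    std::cout << doc.needsSave() << '\n'; // 1 -- it will be written out again
}
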
diff --git a/gpt4all-chat/src/chat.h b/gpt4all-chat/src/chat.h
index b7caeb0c..da644c89 100644
--- a/gpt4all-chat/src/chat.h
+++ b/gpt4all-chat/src/chat.h
@@ -67,6 +67,7 @@ public:
     {
         m_userName = name;
         emit nameChanged();
+        m_needsSave = true;
     }
 
     ChatModel *chatModel() { return m_chatModel; }
@@ -124,6 +125,8 @@ public:
 
     QList<QString> generatedQuestions() const { return m_generatedQuestions; }
 
+    bool needsSave() const { return m_needsSave; }
+
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
 
@@ -203,6 +206,10 @@ private:
     bool m_firstResponse = true;
     int m_trySwitchContextInProgress = 0;
     bool m_isCurrentlyLoading = false;
+    // True if we need to serialize the chat to disk, because of one of two reasons:
+    // - The chat was freshly created during this launch.
+    // - The chat was changed after loading it from disk.
+    bool m_needsSave = true;
 };
 
 #endif // CHAT_H
diff --git a/gpt4all-chat/src/chatlistmodel.cpp b/gpt4all-chat/src/chatlistmodel.cpp
index c5be4338..fb5ef68e 100644
--- a/gpt4all-chat/src/chatlistmodel.cpp
+++ b/gpt4all-chat/src/chatlistmodel.cpp
@@ -99,7 +99,12 @@ void ChatSaver::saveChats(const QVector<Chat*> &chats)
     QElapsedTimer timer;
     timer.start();
     const QString savePath = MySettings::globalInstance()->modelPath();
+    qsizetype nSavedChats = 0;
     for (Chat *chat : chats) {
+        if (!chat->needsSave())
+            continue;
+        ++nSavedChats;
+
         QString fileName = "gpt4all-" + chat->id() + ".chat";
         QString filePath = savePath + "/" + fileName;
         QFile originalFile(filePath);
@@ -129,7 +134,7 @@ void ChatSaver::saveChats(const QVector<Chat*> &chats)
     }
 
     qint64 elapsedTime = timer.elapsed();
-    qDebug() << "serializing chats took:" << elapsedTime << "ms";
+    qDebug() << "serializing chats took" << elapsedTime << "ms, saved" << nSavedChats << "/" << chats.size() << "chats";
     emit saveChatsFinished();
 }
 
@@ -194,11 +199,16 @@ void ChatsRestoreThread::run()
             qint32 version;
             in >> version;
             if (version < 1) {
-                qWarning() << "ERROR: Chat file has non supported version:" << file.fileName();
+                qWarning() << "WARNING: Chat file version" << version << "is not supported:" << file.fileName();
+                continue;
+            }
+            if (version > CHAT_FORMAT_VERSION) {
+                qWarning().nospace() << "WARNING: Chat file is from a future version (have " << version << " want "
+                                     << CHAT_FORMAT_VERSION << "): " << file.fileName();
                 continue;
             }
 
-            if (version <= 1)
+            if (version < 2)
                 in.setVersion(QDataStream::Qt_6_2);
 
             FileInfo info;
@@ -239,7 +249,7 @@
                     continue;
                 }
 
-                if (version <= 1)
+                if (version < 2)
                     in.setVersion(QDataStream::Qt_6_2);
             }
 
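
The restore-thread change above is what makes downgrades safe: chat files written by a newer CHAT_FORMAT_VERSION are now skipped with a warning instead of being misparsed. A simplified sketch of that gate, not the actual gpt4all code, and with a made-up value for the version constant:

#include <cstdio>

constexpr int CHAT_FORMAT_VERSION = 10; // illustrative value, not taken from the PR

bool canRead(int fileVersion)
{
    if (fileVersion < 1) {
        std::fprintf(stderr, "WARNING: chat file version %d is not supported\n", fileVersion);
        return false;
    }
    if (fileVersion > CHAT_FORMAT_VERSION) {
        std::fprintf(stderr, "WARNING: chat file is from a future version (have %d want %d)\n",
                     fileVersion, CHAT_FORMAT_VERSION);
        return false;
    }
    return true; // anything in [1, CHAT_FORMAT_VERSION] is readable
}

int main()
{
    std::printf("%d %d %d\n", canRead(0), canRead(9), canRead(12)); // prints: 0 1 0
}
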
diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp
index 91fabe89..36c3a269 100644
--- a/gpt4all-chat/src/chatllm.cpp
+++ b/gpt4all-chat/src/chatllm.cpp
@@ -42,8 +42,8 @@ using namespace Qt::Literals::StringLiterals;
 //#define DEBUG
 //#define DEBUG_MODEL_LOADING
 
-#define GPTJ_INTERNAL_STATE_VERSION 0 // GPT-J is gone but old chats still use this
-#define LLAMA_INTERNAL_STATE_VERSION 0
+static constexpr int LLAMA_INTERNAL_STATE_VERSION = 0;
+static constexpr int API_INTERNAL_STATE_VERSION = 0;
 
 class LLModelStore {
 public:
@@ -105,6 +105,7 @@ void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) {
 
 ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
+    , m_chat(parent)
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_restoringFromText(false)
@@ -355,7 +356,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 requestUrl = modelInfo.url();
             }
         }
-        m_llModelType = LLModelType::API_;
+        m_llModelType = LLModelTypeV1::API;
         ChatAPI *model = new ChatAPI();
         model->setModelName(modelName);
         model->setRequestURL(requestUrl);
@@ -571,7 +572,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
         }
 
         switch (m_llModelInfo.model->implementation().modelType()[0]) {
-        case 'L': m_llModelType = LLModelType::LLAMA_; break;
+        case 'L': m_llModelType = LLModelTypeV1::LLAMA; break;
         default:
             {
                 m_llModelInfo.resetModel(this);
@@ -624,7 +625,7 @@ void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
     // of n_past is of the number of prompt/response pairs, rather than for total tokens.
-    if (m_llModelType == LLModelType::API_)
+    if (m_llModelType == LLModelTypeV1::API)
         m_ctx.n_past -= 1;
     else
         m_ctx.n_past -= m_promptResponseTokens;
@@ -1045,12 +1046,18 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
 // we want to also serialize n_ctx, and read it at load time.
 bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 {
-    if (version > 1) {
+    if (version >= 2) {
+        if (m_llModelType == LLModelTypeV1::NONE) {
+            qWarning() << "ChatLLM ERROR: attempted to serialize a null model for chat id" << m_chat->id()
+                       << "name" << m_chat->name();
+            return false;
+        }
+
         stream << m_llModelType;
         switch (m_llModelType) {
-        case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
-        case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
-        default: Q_UNREACHABLE();
+        case LLModelTypeV1::LLAMA: stream << LLAMA_INTERNAL_STATE_VERSION; break;
+        case LLModelTypeV1::API: stream << API_INTERNAL_STATE_VERSION; break;
+        default: stream << 0; // models removed in v2.5.0
         }
     }
     stream << response();
@@ -1064,7 +1071,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
         return stream.status() == QDataStream::Ok;
     }
 
-    if (version <= 3) {
+    if (version < 4) {
         int responseLogits = 0;
         stream << responseLogits;
     }
@@ -1085,10 +1092,20 @@
 
 bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
 {
-    if (version > 1) {
-        int internalStateVersion;
-        stream >> m_llModelType;
-        stream >> internalStateVersion; // for future use
+    if (version >= 2) {
+        int llModelType;
+        stream >> llModelType;
+        m_llModelType = (version >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(llModelType);
+        if (m_llModelType == LLModelTypeV1::NONE) {
+            qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized model type: "
+                                 << llModelType;
+            return false;
+        }
+
+        /* note: prior to chat version 10, API models and chats with models removed in v2.5.0 only wrote this because of
+         * undefined behavior in Release builds */
+        int internalStateVersion; // for future use
+        stream >> internalStateVersion;
     }
     QString response;
     stream >> response;
@@ -1114,7 +1131,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
         return stream.status() == QDataStream::Ok;
     }
 
-    if (version <= 3) {
+    if (version < 4) {
         int responseLogits;
         stream >> responseLogits;
     }
@@ -1144,7 +1161,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
         stream.skipRawData(tokensSize * sizeof(int));
     }
 
-    if (version > 0) {
+    if (version >= 1) {
         QByteArray compressed;
         stream >> compressed;
         if (!discardKV)
@@ -1169,7 +1186,7 @@ void ChatLLM::saveState()
     if (!isModelLoaded() || m_pristineLoadedState)
         return;
 
-    if (m_llModelType == LLModelType::API_) {
+    if (m_llModelType == LLModelTypeV1::API) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
@@ -1197,7 +1214,7 @@ void ChatLLM::restoreState()
     if (!isModelLoaded())
         return;
 
-    if (m_llModelType == LLModelType::API_) {
+    if (m_llModelType == LLModelTypeV1::API) {
         QDataStream stream(m_state);
         stream.setVersion(QDataStream::Qt_6_4);
         ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model.get());
diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h
index c486b06c..201c7779 100644
--- a/gpt4all-chat/src/chatllm.h
+++ b/gpt4all-chat/src/chatllm.h
@@ -28,12 +28,60 @@ using namespace Qt::Literals::StringLiterals;
 class QDataStream;
 
 // NOTE: values serialized to disk, do not change or reuse
-enum LLModelType {
-    GPTJ_ = 0, // no longer used
-    LLAMA_ = 1,
-    API_ = 2,
-    BERT_ = 3, // no longer used
+enum class LLModelTypeV0 { // chat versions 2-5
+    MPT = 0,
+    GPTJ = 1,
+    LLAMA = 2,
+    CHATGPT = 3,
+    REPLIT = 4,
+    FALCON = 5,
+    BERT = 6, // not used
+    STARCODER = 7,
 };
+
+enum class LLModelTypeV1 { // since chat version 6 (v2.5.0)
+    GPTJ = 0, // not for new chats
+    LLAMA = 1,
+    API = 2,
+    BERT = 3, // not used
+    // none of the below are used in new chats
+    REPLIT = 4,
+    FALCON = 5,
+    MPT = 6,
+    STARCODER = 7,
+    NONE = -1, // no state
+};
+
+static LLModelTypeV1 parseLLModelTypeV1(int type)
+{
+    switch (LLModelTypeV1(type)) {
+    case LLModelTypeV1::GPTJ:
+    case LLModelTypeV1::LLAMA:
+    case LLModelTypeV1::API:
+    // case LLModelTypeV1::BERT: -- not used
+    case LLModelTypeV1::REPLIT:
+    case LLModelTypeV1::FALCON:
+    case LLModelTypeV1::MPT:
+    case LLModelTypeV1::STARCODER:
+        return LLModelTypeV1(type);
+    default:
+        return LLModelTypeV1::NONE;
+    }
+}
+
+static LLModelTypeV1 parseLLModelTypeV0(int v0)
+{
+    switch (LLModelTypeV0(v0)) {
+    case LLModelTypeV0::MPT:       return LLModelTypeV1::MPT;
+    case LLModelTypeV0::GPTJ:      return LLModelTypeV1::GPTJ;
+    case LLModelTypeV0::LLAMA:     return LLModelTypeV1::LLAMA;
+    case LLModelTypeV0::CHATGPT:   return LLModelTypeV1::API;
+    case LLModelTypeV0::REPLIT:    return LLModelTypeV1::REPLIT;
+    case LLModelTypeV0::FALCON:    return LLModelTypeV1::FALCON;
+    // case LLModelTypeV0::BERT: -- not used
+    case LLModelTypeV0::STARCODER: return LLModelTypeV1::STARCODER;
+    default:                       return LLModelTypeV1::NONE;
+    }
+}
 
 class ChatLLM;
 class ChatModel;
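
A condensed illustration of how the two parsers above are selected during deserialization: chats saved before format version 6 use the V0 numbering, newer chats use V1, and anything unrecognized maps to NONE so the load fails cleanly instead of crashing later. The enum and helpers below are trimmed stand-ins for the real declarations, not copies of them:

#include <cstdio>

enum class LLModelTypeV1 { GPTJ = 0, LLAMA = 1, API = 2, NONE = -1 };

// Trimmed stand-ins for parseLLModelTypeV1/parseLLModelTypeV0; only a few of
// the mappings are reproduced here.
LLModelTypeV1 parseV1(int t) { return (t >= 0 && t <= 2) ? LLModelTypeV1(t) : LLModelTypeV1::NONE; }
LLModelTypeV1 parseV0(int t) { return t == 2 ? LLModelTypeV1::LLAMA
                                             : t == 3 ? LLModelTypeV1::API : LLModelTypeV1::NONE; }

// Mirrors the dispatch in ChatLLM::deserialize: the V0 table applies to chats
// saved before format version 6, the V1 table afterwards.
LLModelTypeV1 parseStoredType(int chatVersion, int storedType)
{
    return (chatVersion >= 6 ? parseV1 : parseV0)(storedType);
}

int main()
{
    // A v5 chat that stored CHATGPT (= 3) maps to API; an unknown value maps to
    // NONE, which the caller treats as "refuse to load" rather than a crash.
    std::printf("%d %d\n", (int)parseStoredType(5, 3), (int)parseStoredType(6, 99)); // 2 -1
}
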
@@ -219,12 +267,13 @@ protected:
 private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
+    const Chat *m_chat;
     std::string m_response;
     std::string m_trimmedResponse;
     std::string m_nameResponse;
     QString m_questionResponse;
     LLModelInfo m_llModelInfo;
-    LLModelType m_llModelType;
+    LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
     ModelInfo m_modelInfo;
     TokenTimer *m_timer;
     QByteArray m_state;
diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h
index 6ab3ac8d..d9e93493 100644
--- a/gpt4all-chat/src/chatmodel.h
+++ b/gpt4all-chat/src/chatmodel.h
@@ -386,7 +386,7 @@ public:
         stream << c.stopped;
         stream << c.thumbsUpState;
         stream << c.thumbsDownState;
-        if (version > 7) {
+        if (version >= 8) {
             stream << c.sources.size();
             for (const ResultInfo &info : c.sources) {
                 Q_ASSERT(!info.file.isEmpty());
@@ -401,7 +401,7 @@ public:
                 stream << info.from;
                 stream << info.to;
             }
-        } else if (version > 2) {
+        } else if (version >= 3) {
             QList<QString> references;
             QList<QString> referencesContext;
             int validReferenceNumber = 1;
@@ -468,7 +468,7 @@ public:
         stream >> c.stopped;
         stream >> c.thumbsUpState;
         stream >> c.thumbsDownState;
-        if (version > 7) {
+        if (version >= 8) {
             qsizetype count;
             stream >> count;
             QList<ResultInfo> sources;
@@ -488,7 +488,7 @@ public:
             }
             c.sources = sources;
             c.consolidatedSources = consolidateSources(sources);
-        } else if (version > 2) {
+        } else if (version >= 3) {
             QString references;
             QList<QString> referencesContext;
             stream >> references;
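
The `version > N` to `version >= N+1` rewrites in chatmodel.h (and in chat.cpp and chatllm.cpp above) are behavior-preserving for integral version numbers; the `>=` spelling simply names the first format revision that actually carries the field. A trivial check, included only to make that equivalence explicit:

#include <cassert>

int main()
{
    // "version > 7" and "version >= 8" admit exactly the same versions,
    // and likewise for the other rewritten comparisons.
    for (int version = -5; version <= 20; ++version) {
        assert((version > 7) == (version >= 8));
        assert((version > 2) == (version >= 3));
        assert((version > 1) == (version >= 2));
    }
    return 0;
}
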