Make it possible to keep some chats after downgrading GPT4All (#3030)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel, 2024-10-04 14:25:17 -04:00 (committed by GitHub)
parent b850e7c867
commit ec4e1e4812
GPG Key ID: B5690EEEBB952194
7 changed files with 141 additions and 42 deletions

File: CHANGELOG.md

@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix the local server rejecting min_p/top_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))
 - Fix "regenerate" always forgetting the most recent message ([#3011](https://github.com/nomic-ai/gpt4all/pull/3011))
 - Fix loaded chats forgetting context when there is a system prompt ([#3015](https://github.com/nomic-ai/gpt4all/pull/3015))
+- Make it possible to downgrade and keep some chats, and avoid crash for some model types ([#3030](https://github.com/nomic-ai/gpt4all/pull/3030))
 
 ## [3.3.1] - 2024-09-27 ([v3.3.y](https://github.com/nomic-ai/gpt4all/tree/v3.3.y))

File: chat.cpp

@@ -101,6 +101,7 @@ void Chat::reset()
     // is to allow switching models but throwing up a dialog warning users if we switch between types
     // of models that a long recalculation will ensue.
     m_chatModel->clear();
+    m_needsSave = true;
 }
 
 void Chat::processSystemPrompt()
@@ -163,6 +164,7 @@ void Chat::prompt(const QString &prompt)
 {
     resetResponseState();
     emit promptRequested(m_collections, prompt);
+    m_needsSave = true;
 }
 
 void Chat::regenerateResponse()
@@ -170,6 +172,7 @@ void Chat::regenerateResponse()
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateSources(index, QList<ResultInfo>());
     emit regenerateResponseRequested();
+    m_needsSave = true;
 }
 
 void Chat::stopGenerating()
@@ -224,7 +227,7 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
 void Chat::promptProcessing()
 {
     m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
     emit responseStateChanged();
 }
 
 void Chat::generatingQuestions()
@@ -261,10 +264,12 @@ ModelInfo Chat::modelInfo() const
 void Chat::setModelInfo(const ModelInfo &modelInfo)
 {
-    if (m_modelInfo == modelInfo && isModelLoaded())
+    if (m_modelInfo != modelInfo) {
+        m_modelInfo = modelInfo;
+        m_needsSave = true;
+    } else if (isModelLoaded())
         return;
 
-    m_modelInfo = modelInfo;
     emit modelInfoChanged();
     emit modelChangeRequested(modelInfo);
 }
@@ -343,12 +348,14 @@ void Chat::generatedNameChanged(const QString &name)
     int wordCount = qMin(7, words.size());
     m_name = words.mid(0, wordCount).join(' ');
     emit nameChanged();
+    m_needsSave = true;
 }
 
 void Chat::generatedQuestionFinished(const QString &question)
 {
     m_generatedQuestions << question;
     emit generatedQuestionsChanged();
+    m_needsSave = true;
 }
 
 void Chat::handleRestoringFromText()
@@ -393,6 +400,7 @@ void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
     m_databaseResults = results;
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateSources(index, m_databaseResults);
+    m_needsSave = true;
 }
 
 void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
@@ -402,6 +410,7 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
     m_modelInfo = modelInfo;
     emit modelInfoChanged();
+    m_needsSave = true;
 }
 
 void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value)
@@ -416,15 +425,15 @@ bool Chat::serialize(QDataStream &stream, int version) const
     stream << m_id;
     stream << m_name;
     stream << m_userName;
-    if (version > 4)
+    if (version >= 5)
         stream << m_modelInfo.id();
     else
         stream << m_modelInfo.filename();
-    if (version > 2)
+    if (version >= 3)
         stream << m_collections;
     const bool serializeKV = MySettings::globalInstance()->saveChatsContext();
-    if (version > 5)
+    if (version >= 6)
         stream << serializeKV;
     if (!m_llmodel->serialize(stream, version, serializeKV))
         return false;
@@ -445,7 +454,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
     QString modelId;
     stream >> modelId;
-    if (version > 4) {
+    if (version >= 5) {
         if (ModelList::globalInstance()->contains(modelId))
             m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
     } else {
@@ -457,13 +466,13 @@ bool Chat::deserialize(QDataStream &stream, int version)
     bool discardKV = m_modelInfo.id().isEmpty();
 
-    if (version > 2) {
+    if (version >= 3) {
         stream >> m_collections;
         emit collectionListChanged(m_collections);
     }
 
     bool deserializeKV = true;
-    if (version > 5)
+    if (version >= 6)
         stream >> deserializeKV;
 
     m_llmodel->setModelInfo(m_modelInfo);
@@ -473,7 +482,11 @@ bool Chat::deserialize(QDataStream &stream, int version)
         return false;
     emit chatModelChanged();
-    return stream.status() == QDataStream::Ok;
+    if (stream.status() != QDataStream::Ok)
+        return false;
+    m_needsSave = false;
+    return true;
 }
 
 QList<QString> Chat::collectionList() const
@@ -493,6 +506,7 @@ void Chat::addCollection(const QString &collection)
     m_collections.append(collection);
     emit collectionListChanged(m_collections);
+    m_needsSave = true;
 }
 
 void Chat::removeCollection(const QString &collection)
@@ -502,4 +516,5 @@ void Chat::removeCollection(const QString &collection)
     m_collections.removeAll(collection);
     emit collectionListChanged(m_collections);
+    m_needsSave = true;
 }

File: chat.h

@@ -67,6 +67,7 @@ public:
     {
         m_userName = name;
         emit nameChanged();
+        m_needsSave = true;
     }
 
     ChatModel *chatModel() { return m_chatModel; }
@@ -124,6 +125,8 @@ public:
     QList<QString> generatedQuestions() const { return m_generatedQuestions; }
 
+    bool needsSave() const { return m_needsSave; }
+
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
@@ -203,6 +206,10 @@ private:
     bool m_firstResponse = true;
     int m_trySwitchContextInProgress = 0;
     bool m_isCurrentlyLoading = false;
+    // True if we need to serialize the chat to disk, because of one of two reasons:
+    // - The chat was freshly created during this launch.
+    // - The chat was changed after loading it from disk.
+    bool m_needsSave = true;
 };
 
 #endif // CHAT_H
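
The heart of this change is a textbook dirty flag: every mutation sets m_needsSave, a fully successful load from disk clears it, and the saver skips chats that are still clean. A minimal standalone sketch of the pattern, using illustrative names rather than the real GPT4All types:

// Dirty-flag sketch; Document and saveAll are illustrative, not GPT4All API.
#include <iostream>
#include <string>
#include <vector>

class Document {
public:
    void setTitle(std::string t) { m_title = std::move(t); m_needsSave = true; }
    void markCleanAfterLoad() { m_needsSave = false; } // only after a fully successful load
    bool needsSave() const { return m_needsSave; }
private:
    std::string m_title;
    bool m_needsSave = true; // a freshly created document must be written at least once
};

void saveAll(const std::vector<Document> &docs)
{
    int saved = 0;
    for (const auto &d : docs) {
        if (!d.needsSave())
            continue; // unchanged since load: leave the file on disk untouched
        ++saved;
        // ... serialize d to its file here ...
    }
    std::cout << "saved " << saved << "/" << docs.size() << " documents\n";
}

Leaving the on-disk file untouched is what makes downgrading possible: if a newer build never rewrites an unmodified chat, an older build can still read that chat's file.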

File: chatlistmodel.cpp

@@ -99,7 +99,12 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
     QElapsedTimer timer;
     timer.start();
     const QString savePath = MySettings::globalInstance()->modelPath();
+    qsizetype nSavedChats = 0;
     for (Chat *chat : chats) {
+        if (!chat->needsSave())
+            continue;
+
+        ++nSavedChats;
         QString fileName = "gpt4all-" + chat->id() + ".chat";
         QString filePath = savePath + "/" + fileName;
         QFile originalFile(filePath);
@@ -129,7 +134,7 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
     }
     qint64 elapsedTime = timer.elapsed();
-    qDebug() << "serializing chats took:" << elapsedTime << "ms";
+    qDebug() << "serializing chats took" << elapsedTime << "ms, saved" << nSavedChats << "/" << chats.size() << "chats";
     emit saveChatsFinished();
 }
@@ -194,11 +199,16 @@ void ChatsRestoreThread::run()
             qint32 version;
             in >> version;
             if (version < 1) {
-                qWarning() << "ERROR: Chat file has non supported version:" << file.fileName();
+                qWarning() << "WARNING: Chat file version" << version << "is not supported:" << file.fileName();
+                continue;
+            }
+            if (version > CHAT_FORMAT_VERSION) {
+                qWarning().nospace() << "WARNING: Chat file is from a future version (have " << version << " want "
+                                     << CHAT_FORMAT_VERSION << "): " << file.fileName();
                 continue;
             }
 
-            if (version <= 1)
+            if (version < 2)
                 in.setVersion(QDataStream::Qt_6_2);
 
             FileInfo info;
@@ -239,7 +249,7 @@ void ChatsRestoreThread::run()
                 continue;
             }
 
-            if (version <= 1)
+            if (version < 2)
                 in.setVersion(QDataStream::Qt_6_2);
         }
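
CHAT_FORMAT_VERSION is the constant the writer bumps whenever the save format changes; the new guard is what turns a file written by a future version into a cleanly skipped file instead of garbage reads or a crash. A distilled sketch of the read-side gate, under the assumption that the constant's value here is merely illustrative:

// Distilled version gate; the constant's value is illustrative (the note in
// chatllm.cpp suggests this commit bumps the format to 10).
#include <QDataStream>
#include <QFile>
#include <QtDebug>

static constexpr qint32 CHAT_FORMAT_VERSION = 10;

bool readChatFile(QFile &file)
{
    QDataStream in(&file);
    qint32 version;
    in >> version;
    if (version < 1 || version > CHAT_FORMAT_VERSION) {
        // Too old to know about, or written by a newer build: skip it, don't crash.
        qWarning() << "unsupported chat file version" << version << "in" << file.fileName();
        return false;
    }
    if (version < 2)
        in.setVersion(QDataStream::Qt_6_2); // the earliest files used the Qt 6.2 wire format
    // ... deserialize the chat payload here ...
    return in.status() == QDataStream::Ok;
}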

File: chatllm.cpp

@@ -42,8 +42,8 @@ using namespace Qt::Literals::StringLiterals;
 //#define DEBUG
 //#define DEBUG_MODEL_LOADING
 
-#define GPTJ_INTERNAL_STATE_VERSION 0 // GPT-J is gone but old chats still use this
-#define LLAMA_INTERNAL_STATE_VERSION 0
+static constexpr int LLAMA_INTERNAL_STATE_VERSION = 0;
+static constexpr int API_INTERNAL_STATE_VERSION = 0;
 
 class LLModelStore {
 public:
@@ -105,6 +105,7 @@ void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) {
 ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
+    , m_chat(parent)
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_restoringFromText(false)
@@ -355,7 +356,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 requestUrl = modelInfo.url();
             }
         }
-        m_llModelType = LLModelType::API_;
+        m_llModelType = LLModelTypeV1::API;
         ChatAPI *model = new ChatAPI();
         model->setModelName(modelName);
         model->setRequestURL(requestUrl);
@@ -571,7 +572,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
         }
 
         switch (m_llModelInfo.model->implementation().modelType()[0]) {
-        case 'L': m_llModelType = LLModelType::LLAMA_; break;
+        case 'L': m_llModelType = LLModelTypeV1::LLAMA; break;
         default:
             {
                 m_llModelInfo.resetModel(this);
@@ -624,7 +625,7 @@ void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
     // of n_past is of the number of prompt/response pairs, rather than for total tokens.
-    if (m_llModelType == LLModelType::API_)
+    if (m_llModelType == LLModelTypeV1::API)
         m_ctx.n_past -= 1;
     else
         m_ctx.n_past -= m_promptResponseTokens;
@@ -1045,12 +1046,18 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
 // we want to also serialize n_ctx, and read it at load time.
 bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 {
-    if (version > 1) {
+    if (version >= 2) {
+        if (m_llModelType == LLModelTypeV1::NONE) {
+            qWarning() << "ChatLLM ERROR: attempted to serialize a null model for chat id" << m_chat->id()
+                       << "name" << m_chat->name();
+            return false;
+        }
+
         stream << m_llModelType;
         switch (m_llModelType) {
-        case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
-        case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
-        default: Q_UNREACHABLE();
+        case LLModelTypeV1::LLAMA: stream << LLAMA_INTERNAL_STATE_VERSION; break;
+        case LLModelTypeV1::API: stream << API_INTERNAL_STATE_VERSION; break;
+        default: stream << 0; // models removed in v2.5.0
         }
     }
     stream << response();
@@ -1064,7 +1071,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
         return stream.status() == QDataStream::Ok;
     }
 
-    if (version <= 3) {
+    if (version < 4) {
         int responseLogits = 0;
         stream << responseLogits;
     }
@@ -1085,10 +1092,20 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
 {
-    if (version > 1) {
-        int internalStateVersion;
-        stream >> m_llModelType;
-        stream >> internalStateVersion; // for future use
+    if (version >= 2) {
+        int llModelType;
+        stream >> llModelType;
+        m_llModelType = (version >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(llModelType);
+        if (m_llModelType == LLModelTypeV1::NONE) {
+            qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized model type: "
+                                 << llModelType;
+            return false;
+        }
+
+        /* note: prior to chat version 10, API models and chats with models removed in v2.5.0 only wrote this because of
+         * undefined behavior in Release builds */
+        int internalStateVersion; // for future use
+        stream >> internalStateVersion;
     }
     QString response;
     stream >> response;
@@ -1114,7 +1131,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
         return stream.status() == QDataStream::Ok;
     }
 
-    if (version <= 3) {
+    if (version < 4) {
         int responseLogits;
         stream >> responseLogits;
     }
@@ -1144,7 +1161,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
         stream.skipRawData(tokensSize * sizeof(int));
     }
 
-    if (version > 0) {
+    if (version >= 1) {
         QByteArray compressed;
         stream >> compressed;
         if (!discardKV)
@@ -1169,7 +1186,7 @@ void ChatLLM::saveState()
     if (!isModelLoaded() || m_pristineLoadedState)
         return;
 
-    if (m_llModelType == LLModelType::API_) {
+    if (m_llModelType == LLModelTypeV1::API) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
@@ -1197,7 +1214,7 @@ void ChatLLM::restoreState()
     if (!isModelLoaded())
         return;
 
-    if (m_llModelType == LLModelType::API_) {
+    if (m_llModelType == LLModelTypeV1::API) {
         QDataStream stream(m_state);
         stream.setVersion(QDataStream::Qt_6_4);
         ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
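
One subtlety above: the writer still streams the enum directly, but the reader now pulls a raw int and runs it through a parse function, so an unknown on-disk value becomes LLModelTypeV1::NONE instead of an out-of-range enum. A self-contained sketch of that defensive round-trip, with an illustrative enum rather than the GPT4All one:

// Defensive enum round-trip; ModelKind and parseModelKind are illustrative.
#include <QByteArray>
#include <QDataStream>
#include <QtDebug>

enum class ModelKind : qint32 { Llama = 1, Api = 2, None = -1 };

static ModelKind parseModelKind(qint32 raw)
{
    switch (ModelKind(raw)) {
    case ModelKind::Llama:
    case ModelKind::Api:
        return ModelKind(raw); // known value: safe to adopt
    default:
        return ModelKind::None; // unknown value: refuse rather than misinterpret
    }
}

int main()
{
    QByteArray blob;
    {
        QDataStream out(&blob, QIODeviceBase::WriteOnly);
        out << qint32(7); // pretend an older build wrote a type this build no longer knows
    }
    QDataStream in(blob);
    qint32 raw;
    in >> raw;
    if (parseModelKind(raw) == ModelKind::None)
        qWarning() << "unrecognized model type:" << raw; // the real code aborts the load here
    return 0;
}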

File: chatllm.h

@@ -28,12 +28,60 @@ using namespace Qt::Literals::StringLiterals;
 class QDataStream;
 
 // NOTE: values serialized to disk, do not change or reuse
-enum LLModelType {
-    GPTJ_ = 0, // no longer used
-    LLAMA_ = 1,
-    API_ = 2,
-    BERT_ = 3, // no longer used
+enum class LLModelTypeV0 { // chat versions 2-5
+    MPT = 0,
+    GPTJ = 1,
+    LLAMA = 2,
+    CHATGPT = 3,
+    REPLIT = 4,
+    FALCON = 5,
+    BERT = 6, // not used
+    STARCODER = 7,
 };
+
+enum class LLModelTypeV1 { // since chat version 6 (v2.5.0)
+    GPTJ = 0, // not for new chats
+    LLAMA = 1,
+    API = 2,
+    BERT = 3, // not used
+    // none of the below are used in new chats
+    REPLIT = 4,
+    FALCON = 5,
+    MPT = 6,
+    STARCODER = 7,
+    NONE = -1, // no state
+};
+
+static LLModelTypeV1 parseLLModelTypeV1(int type)
+{
+    switch (LLModelTypeV1(type)) {
+    case LLModelTypeV1::GPTJ:
+    case LLModelTypeV1::LLAMA:
+    case LLModelTypeV1::API:
+    // case LLModelTypeV1::BERT: -- not used
+    case LLModelTypeV1::REPLIT:
+    case LLModelTypeV1::FALCON:
+    case LLModelTypeV1::MPT:
+    case LLModelTypeV1::STARCODER:
+        return LLModelTypeV1(type);
+    default:
+        return LLModelTypeV1::NONE;
+    }
+}
+
+static LLModelTypeV1 parseLLModelTypeV0(int v0)
+{
+    switch (LLModelTypeV0(v0)) {
+    case LLModelTypeV0::MPT: return LLModelTypeV1::MPT;
+    case LLModelTypeV0::GPTJ: return LLModelTypeV1::GPTJ;
+    case LLModelTypeV0::LLAMA: return LLModelTypeV1::LLAMA;
+    case LLModelTypeV0::CHATGPT: return LLModelTypeV1::API;
+    case LLModelTypeV0::REPLIT: return LLModelTypeV1::REPLIT;
+    case LLModelTypeV0::FALCON: return LLModelTypeV1::FALCON;
+    // case LLModelTypeV0::BERT: -- not used
+    case LLModelTypeV0::STARCODER: return LLModelTypeV1::STARCODER;
+    default: return LLModelTypeV1::NONE;
+    }
+}
 
 class ChatLLM;
 class ChatModel;
@@ -219,12 +267,13 @@ protected:
 private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
+    const Chat *m_chat;
     std::string m_response;
     std::string m_trimmedResponse;
     std::string m_nameResponse;
     QString m_questionResponse;
     LLModelInfo m_llModelInfo;
-    LLModelType m_llModelType;
+    LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
     ModelInfo m_modelInfo;
     TokenTimer *m_timer;
     QByteArray m_state;
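
Note that the two tables assign different raw values to the same model (MPT is 0 in V0 but 6 in V1, and V0's CHATGPT 3 becomes V1's API 2), so a plain cast of an old file's value would silently pick the wrong model type; the V0 parser exists to translate. A hypothetical spot-check of that mapping, assuming the enums and parsers above are in scope:

// Hypothetical sanity checks of the V0 -> V1 translation shown above.
#include <cassert>

int main()
{
    // Same model, different raw value in each table; a bare cast would be wrong.
    assert(parseLLModelTypeV0(int(LLModelTypeV0::MPT)) == LLModelTypeV1::MPT);     // raw 0 -> raw 6
    assert(parseLLModelTypeV0(int(LLModelTypeV0::CHATGPT)) == LLModelTypeV1::API); // raw 3 -> raw 2
    // Values neither parser recognizes collapse to NONE, which makes the caller abort the load.
    assert(parseLLModelTypeV0(42) == LLModelTypeV1::NONE);
    assert(parseLLModelTypeV1(42) == LLModelTypeV1::NONE);
    return 0;
}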

File: chatmodel.h

@@ -386,7 +386,7 @@ public:
         stream << c.stopped;
         stream << c.thumbsUpState;
         stream << c.thumbsDownState;
-        if (version > 7) {
+        if (version >= 8) {
             stream << c.sources.size();
             for (const ResultInfo &info : c.sources) {
                 Q_ASSERT(!info.file.isEmpty());
@@ -401,7 +401,7 @@ public:
                 stream << info.from;
                 stream << info.to;
             }
-        } else if (version > 2) {
+        } else if (version >= 3) {
             QList<QString> references;
             QList<QString> referencesContext;
             int validReferenceNumber = 1;
@@ -468,7 +468,7 @@ public:
         stream >> c.stopped;
         stream >> c.thumbsUpState;
         stream >> c.thumbsDownState;
-        if (version > 7) {
+        if (version >= 8) {
             qsizetype count;
             stream >> count;
             QList<ResultInfo> sources;
@@ -488,7 +488,7 @@ public:
             }
             c.sources = sources;
             c.consolidatedSources = consolidateSources(sources);
-        } else if (version > 2) {
+        } else if (version >= 3) {
             QString references;
             QList<QString> referencesContext;
             stream >> references;
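
chatmodel.h gets the same mechanical cleanup from "version > N" to "version >= N+1"; the invariant worth stating is that writer and reader must gate each field on the exact version that introduced it, or every later field is read out of alignment. A minimal illustration with an invented struct:

// Version-gated field pairing; Item and its version numbers are invented for illustration.
#include <QDataStream>
#include <QString>

struct Item {
    QString text;
    QString extra; // field added in (hypothetical) format version 8
};

void writeItem(QDataStream &out, const Item &it, int version)
{
    out << it.text;
    if (version >= 8)
        out << it.extra; // gate on the version that introduced the field
}

void readItem(QDataStream &in, Item &it, int version)
{
    in >> it.text;
    if (version >= 8)
        in >> it.extra; // must mirror the writer exactly, or the stream desynchronizes
}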