Make it possible to keep some chats after downgrading GPT4All (#3030)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel authored on 2024-10-04 14:25:17 -04:00; committed by GitHub
parent b850e7c867
commit ec4e1e4812
7 changed files with 141 additions and 42 deletions

CHANGELOG.md

@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix the local server rejecting min_p/top_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))
- Fix "regenerate" always forgetting the most recent message ([#3011](https://github.com/nomic-ai/gpt4all/pull/3011))
- Fix loaded chats forgetting context when there is a system prompt ([#3015](https://github.com/nomic-ai/gpt4all/pull/3015))
- Make it possible to downgrade and keep some chats, and avoid crash for some model types ([#3030](https://github.com/nomic-ai/gpt4all/pull/3030))
## [3.3.1] - 2024-09-27 ([v3.3.y](https://github.com/nomic-ai/gpt4all/tree/v3.3.y))
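The changelog entry above is the user-facing summary; the mechanism, visible in the chatlistmodel.cpp hunks further down, is to compare each chat file's stored format version against what the running build understands and to skip anything newer instead of misreading it. Below is a minimal sketch of that guard under stated assumptions: `tryRestoreChat` and `kChatFormatVersion` are hypothetical stand-ins, not the project's actual identifiers, and the on-disk framing is simplified to just a leading version field.

```cpp
#include <QDataStream>
#include <QDebug>
#include <QFile>
#include <QString>

// Hypothetical stand-in for the newest chat format version this build can read.
static constexpr qint32 kChatFormatVersion = 9;

// Returns false, leaving the file untouched on disk, when the file was written by an
// unknown (typically newer) format version.
static bool tryRestoreChat(const QString &path)
{
    QFile file(path);
    if (!file.open(QIODevice::ReadOnly))
        return false;

    QDataStream in(&file);
    qint32 version = -1;
    in >> version;

    if (version < 1 || version > kChatFormatVersion) {
        qWarning() << "skipping chat file with unsupported format version" << version << ":" << path;
        return false; // a build that understands this version can still load the file later
    }

    // ... deserialize the rest of the chat as usual ...
    return in.status() == QDataStream::Ok;
}
```

Because a skipped file never becomes a Chat object, the save path should leave it on disk for a newer build to pick up again.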

chat.cpp

@@ -101,6 +101,7 @@ void Chat::reset()
// is to allow switching models but throwing up a dialog warning users if we switch between types
// of models that a long recalculation will ensue.
m_chatModel->clear();
m_needsSave = true;
}
void Chat::processSystemPrompt()
@@ -163,6 +164,7 @@ void Chat::prompt(const QString &prompt)
{
resetResponseState();
emit promptRequested(m_collections, prompt);
m_needsSave = true;
}
void Chat::regenerateResponse()
@@ -170,6 +172,7 @@ void Chat::regenerateResponse()
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, QList<ResultInfo>());
emit regenerateResponseRequested();
m_needsSave = true;
}
void Chat::stopGenerating()
@@ -224,7 +227,7 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
void Chat::promptProcessing()
{
m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
emit responseStateChanged();
}
void Chat::generatingQuestions()
@@ -261,10 +264,12 @@ ModelInfo Chat::modelInfo() const
void Chat::setModelInfo(const ModelInfo &modelInfo)
{
if (m_modelInfo == modelInfo && isModelLoaded())
if (m_modelInfo != modelInfo) {
m_modelInfo = modelInfo;
m_needsSave = true;
} else if (isModelLoaded())
return;
m_modelInfo = modelInfo;
emit modelInfoChanged();
emit modelChangeRequested(modelInfo);
}
@@ -343,12 +348,14 @@ void Chat::generatedNameChanged(const QString &name)
int wordCount = qMin(7, words.size());
m_name = words.mid(0, wordCount).join(' ');
emit nameChanged();
m_needsSave = true;
}
void Chat::generatedQuestionFinished(const QString &question)
{
m_generatedQuestions << question;
emit generatedQuestionsChanged();
m_needsSave = true;
}
void Chat::handleRestoringFromText()
@@ -393,6 +400,7 @@ void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
m_databaseResults = results;
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, m_databaseResults);
m_needsSave = true;
}
void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
@@ -402,6 +410,7 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
m_modelInfo = modelInfo;
emit modelInfoChanged();
m_needsSave = true;
}
void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value)
@@ -416,15 +425,15 @@ bool Chat::serialize(QDataStream &stream, int version) const
stream << m_id;
stream << m_name;
stream << m_userName;
if (version > 4)
if (version >= 5)
stream << m_modelInfo.id();
else
stream << m_modelInfo.filename();
if (version > 2)
if (version >= 3)
stream << m_collections;
const bool serializeKV = MySettings::globalInstance()->saveChatsContext();
if (version > 5)
if (version >= 6)
stream << serializeKV;
if (!m_llmodel->serialize(stream, version, serializeKV))
return false;
@@ -445,7 +454,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
QString modelId;
stream >> modelId;
if (version > 4) {
if (version >= 5) {
if (ModelList::globalInstance()->contains(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
} else {
@@ -457,13 +466,13 @@ bool Chat::deserialize(QDataStream &stream, int version)
bool discardKV = m_modelInfo.id().isEmpty();
if (version > 2) {
if (version >= 3) {
stream >> m_collections;
emit collectionListChanged(m_collections);
}
bool deserializeKV = true;
if (version > 5)
if (version >= 6)
stream >> deserializeKV;
m_llmodel->setModelInfo(m_modelInfo);
@@ -473,7 +482,11 @@ bool Chat::deserialize(QDataStream &stream, int version)
return false;
emit chatModelChanged();
return stream.status() == QDataStream::Ok;
if (stream.status() != QDataStream::Ok)
return false;
m_needsSave = false;
return true;
}
QList<QString> Chat::collectionList() const
@@ -493,6 +506,7 @@ void Chat::addCollection(const QString &collection)
m_collections.append(collection);
emit collectionListChanged(m_collections);
m_needsSave = true;
}
void Chat::removeCollection(const QString &collection)
@@ -502,4 +516,5 @@ void Chat::removeCollection(const QString &collection)
m_collections.removeAll(collection);
emit collectionListChanged(m_collections);
m_needsSave = true;
}

chat.h

@@ -67,6 +67,7 @@ public:
{
m_userName = name;
emit nameChanged();
m_needsSave = true;
}
ChatModel *chatModel() { return m_chatModel; }
@@ -124,6 +125,8 @@ public:
QList<QString> generatedQuestions() const { return m_generatedQuestions; }
bool needsSave() const { return m_needsSave; }
public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
@@ -203,6 +206,10 @@ private:
bool m_firstResponse = true;
int m_trySwitchContextInProgress = 0;
bool m_isCurrentlyLoading = false;
// True if we need to serialize the chat to disk, because of one of two reasons:
// - The chat was freshly created during this launch.
// - The chat was changed after loading it from disk.
bool m_needsSave = true;
};
#endif // CHAT_H
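The m_needsSave comment above describes a simple dirty-flag scheme: a chat created during this session starts dirty, every user-visible mutation re-dirties it, and a successful deserialize() clears it so untouched chats are never rewritten. A compact sketch of the pattern follows; ChatSketch and saveAll are illustrative names, not the real Chat or ChatSaver.

```cpp
#include <QString>
#include <QVector>

class ChatSketch {
public:
    // Any mutation marks the chat as needing a write-back.
    void setName(const QString &name) { m_name = name; m_needsSave = true; }

    // A successful load means the in-memory state already matches the file on disk.
    bool deserialize(/* QDataStream &in */)
    {
        // ... read fields ...
        m_needsSave = false;
        return true;
    }

    bool needsSave() const { return m_needsSave; }

private:
    QString m_name;
    bool m_needsSave = true; // chats created during this session always need an initial save
};

static void saveAll(const QVector<ChatSketch *> &chats)
{
    for (ChatSketch *chat : chats) {
        if (!chat->needsSave())
            continue; // unchanged since load: leave its file alone
        // ... serialize *chat to its .chat file ...
    }
}
```

The payoff for downgrades is that merely opening chats in a newer build no longer rewrites them in the newer format, so an older build can still read them afterwards.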

chatlistmodel.cpp

@@ -99,7 +99,12 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
QElapsedTimer timer;
timer.start();
const QString savePath = MySettings::globalInstance()->modelPath();
qsizetype nSavedChats = 0;
for (Chat *chat : chats) {
if (!chat->needsSave())
continue;
++nSavedChats;
QString fileName = "gpt4all-" + chat->id() + ".chat";
QString filePath = savePath + "/" + fileName;
QFile originalFile(filePath);
@@ -129,7 +134,7 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
}
qint64 elapsedTime = timer.elapsed();
qDebug() << "serializing chats took:" << elapsedTime << "ms";
qDebug() << "serializing chats took" << elapsedTime << "ms, saved" << nSavedChats << "/" << chats.size() << "chats";
emit saveChatsFinished();
}
@@ -194,11 +199,16 @@ void ChatsRestoreThread::run()
qint32 version;
in >> version;
if (version < 1) {
qWarning() << "ERROR: Chat file has non supported version:" << file.fileName();
qWarning() << "WARNING: Chat file version" << version << "is not supported:" << file.fileName();
continue;
}
if (version > CHAT_FORMAT_VERSION) {
qWarning().nospace() << "WARNING: Chat file is from a future version (have " << version << " want "
<< CHAT_FORMAT_VERSION << "): " << file.fileName();
continue;
}
if (version <= 1)
if (version < 2)
in.setVersion(QDataStream::Qt_6_2);
FileInfo info;
@@ -239,7 +249,7 @@ void ChatsRestoreThread::run()
continue;
}
if (version <= 1)
if (version < 2)
in.setVersion(QDataStream::Qt_6_2);
}
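One detail in the restore path above: the qint32 version field is read before in.setVersion() is called. That ordering is safe because QDataStream encodes plain integers the same way in every stream version; only once the chat's format version is known does the reader pin the stream encoding, presumably to match how the oldest (format-1) files were written. A trimmed sketch of that ordering, with a hypothetical helper name:

```cpp
#include <QDataStream>

// Hypothetical helper mirroring the ordering in ChatsRestoreThread::run(): read the
// format version first, then pin the stream encoding for the oldest files.
static qint32 readChatFormatVersion(QDataStream &in)
{
    qint32 version = -1;
    in >> version;                          // integer encoding does not depend on the stream version
    if (version < 2)
        in.setVersion(QDataStream::Qt_6_2); // decode the rest as the old writer encoded it
    return version;
}
```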

chatllm.cpp

@@ -42,8 +42,8 @@ using namespace Qt::Literals::StringLiterals;
//#define DEBUG
//#define DEBUG_MODEL_LOADING
#define GPTJ_INTERNAL_STATE_VERSION 0 // GPT-J is gone but old chats still use this
#define LLAMA_INTERNAL_STATE_VERSION 0
static constexpr int LLAMA_INTERNAL_STATE_VERSION = 0;
static constexpr int API_INTERNAL_STATE_VERSION = 0;
class LLModelStore {
public:
@@ -105,6 +105,7 @@ void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) {
ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr}
, m_chat(parent)
, m_promptResponseTokens(0)
, m_promptTokens(0)
, m_restoringFromText(false)
@@ -355,7 +356,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
requestUrl = modelInfo.url();
}
}
m_llModelType = LLModelType::API_;
m_llModelType = LLModelTypeV1::API;
ChatAPI *model = new ChatAPI();
model->setModelName(modelName);
model->setRequestURL(requestUrl);
@@ -571,7 +572,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
}
switch (m_llModelInfo.model->implementation().modelType()[0]) {
case 'L': m_llModelType = LLModelType::LLAMA_; break;
case 'L': m_llModelType = LLModelTypeV1::LLAMA; break;
default:
{
m_llModelInfo.resetModel(this);
@@ -624,7 +625,7 @@ void ChatLLM::regenerateResponse()
{
// ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
// of n_past is the number of prompt/response pairs, rather than the total number of tokens.
if (m_llModelType == LLModelType::API_)
if (m_llModelType == LLModelTypeV1::API)
m_ctx.n_past -= 1;
else
m_ctx.n_past -= m_promptResponseTokens;
@@ -1045,12 +1046,18 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
if (version >= 2) {
if (m_llModelType == LLModelTypeV1::NONE) {
qWarning() << "ChatLLM ERROR: attempted to serialize a null model for chat id" << m_chat->id()
<< "name" << m_chat->name();
return false;
}
stream << m_llModelType;
switch (m_llModelType) {
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
default: Q_UNREACHABLE();
case LLModelTypeV1::LLAMA: stream << LLAMA_INTERNAL_STATE_VERSION; break;
case LLModelTypeV1::API: stream << API_INTERNAL_STATE_VERSION; break;
default: stream << 0; // models removed in v2.5.0
}
}
stream << response();
@@ -1064,7 +1071,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
return stream.status() == QDataStream::Ok;
}
if (version <= 3) {
if (version < 4) {
int responseLogits = 0;
stream << responseLogits;
}
@@ -1085,10 +1092,20 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
if (version > 1) {
int internalStateVersion;
stream >> m_llModelType;
stream >> internalStateVersion; // for future use
if (version >= 2) {
int llModelType;
stream >> llModelType;
m_llModelType = (version >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(llModelType);
if (m_llModelType == LLModelTypeV1::NONE) {
qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized model type: "
<< llModelType;
return false;
}
/* note: prior to chat version 10, API models and chats with models removed in v2.5.0 only wrote this because of
* undefined behavior in Release builds */
int internalStateVersion; // for future use
stream >> internalStateVersion;
}
QString response;
stream >> response;
@@ -1114,7 +1131,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
return stream.status() == QDataStream::Ok;
}
if (version <= 3) {
if (version < 4) {
int responseLogits;
stream >> responseLogits;
}
@@ -1144,7 +1161,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
stream.skipRawData(tokensSize * sizeof(int));
}
if (version > 0) {
if (version >= 1) {
QByteArray compressed;
stream >> compressed;
if (!discardKV)
@@ -1169,7 +1186,7 @@ void ChatLLM::saveState()
if (!isModelLoaded() || m_pristineLoadedState)
return;
if (m_llModelType == LLModelType::API_) {
if (m_llModelType == LLModelTypeV1::API) {
m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_4);
@@ -1197,7 +1214,7 @@ void ChatLLM::restoreState()
if (!isModelLoaded())
return;
if (m_llModelType == LLModelType::API_) {
if (m_llModelType == LLModelTypeV1::API) {
QDataStream stream(m_state);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());

chatllm.h

@@ -28,12 +28,60 @@ using namespace Qt::Literals::StringLiterals;
class QDataStream;
// NOTE: values serialized to disk, do not change or reuse
enum LLModelType {
GPTJ_ = 0, // no longer used
LLAMA_ = 1,
API_ = 2,
BERT_ = 3, // no longer used
enum class LLModelTypeV0 { // chat versions 2-5
MPT = 0,
GPTJ = 1,
LLAMA = 2,
CHATGPT = 3,
REPLIT = 4,
FALCON = 5,
BERT = 6, // not used
STARCODER = 7,
};
enum class LLModelTypeV1 { // since chat version 6 (v2.5.0)
GPTJ = 0, // not for new chats
LLAMA = 1,
API = 2,
BERT = 3, // not used
// none of the below are used in new chats
REPLIT = 4,
FALCON = 5,
MPT = 6,
STARCODER = 7,
NONE = -1, // no state
};
static LLModelTypeV1 parseLLModelTypeV1(int type)
{
switch (LLModelTypeV1(type)) {
case LLModelTypeV1::GPTJ:
case LLModelTypeV1::LLAMA:
case LLModelTypeV1::API:
// case LLModelTypeV1::BERT: -- not used
case LLModelTypeV1::REPLIT:
case LLModelTypeV1::FALCON:
case LLModelTypeV1::MPT:
case LLModelTypeV1::STARCODER:
return LLModelTypeV1(type);
default:
return LLModelTypeV1::NONE;
}
}
static LLModelTypeV1 parseLLModelTypeV0(int v0)
{
switch (LLModelTypeV0(v0)) {
case LLModelTypeV0::MPT: return LLModelTypeV1::MPT;
case LLModelTypeV0::GPTJ: return LLModelTypeV1::GPTJ;
case LLModelTypeV0::LLAMA: return LLModelTypeV1::LLAMA;
case LLModelTypeV0::CHATGPT: return LLModelTypeV1::API;
case LLModelTypeV0::REPLIT: return LLModelTypeV1::REPLIT;
case LLModelTypeV0::FALCON: return LLModelTypeV1::FALCON;
// case LLModelTypeV0::BERT: -- not used
case LLModelTypeV0::STARCODER: return LLModelTypeV1::STARCODER;
default: return LLModelTypeV1::NONE;
}
}
class ChatLLM;
class ChatModel;
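The two enums above exist because chat format versions 2 through 5 persisted one numbering (LLModelTypeV0) while version 6 and later persist another (LLModelTypeV1); the parse helpers map whatever integer is on disk onto the current enum and collapse anything unrecognized to NONE, letting the loader refuse the chat instead of crashing on an unexpected value. A small usage sketch follows, assuming the declarations above are in scope; readModelType is a hypothetical helper, not part of the patch.

```cpp
#include <QDataStream>
#include <QDebug>

// Hypothetical helper: normalize a stored model type to LLModelTypeV1, whichever
// numbering scheme the chat file was written with.
static bool readModelType(QDataStream &stream, int chatVersion, LLModelTypeV1 &typeOut)
{
    int rawType = -1;
    stream >> rawType;
    typeOut = (chatVersion >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(rawType);
    if (typeOut == LLModelTypeV1::NONE) {
        qWarning() << "unrecognized model type" << rawType << "in chat format version" << chatVersion;
        return false; // caller aborts the load rather than continuing with bad state
    }
    return true;
}
```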
@@ -219,12 +267,13 @@ protected:
private:
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
const Chat *m_chat;
std::string m_response;
std::string m_trimmedResponse;
std::string m_nameResponse;
QString m_questionResponse;
LLModelInfo m_llModelInfo;
LLModelType m_llModelType;
LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
ModelInfo m_modelInfo;
TokenTimer *m_timer;
QByteArray m_state;

chatmodel.h

@@ -386,7 +386,7 @@ public:
stream << c.stopped;
stream << c.thumbsUpState;
stream << c.thumbsDownState;
if (version > 7) {
if (version >= 8) {
stream << c.sources.size();
for (const ResultInfo &info : c.sources) {
Q_ASSERT(!info.file.isEmpty());
@@ -401,7 +401,7 @@ public:
stream << info.from;
stream << info.to;
}
} else if (version > 2) {
} else if (version >= 3) {
QList<QString> references;
QList<QString> referencesContext;
int validReferenceNumber = 1;
@@ -468,7 +468,7 @@ public:
stream >> c.stopped;
stream >> c.thumbsUpState;
stream >> c.thumbsDownState;
if (version > 7) {
if (version >= 8) {
qsizetype count;
stream >> count;
QList<ResultInfo> sources;
@@ -488,7 +488,7 @@ public:
}
c.sources = sources;
c.consolidatedSources = consolidateSources(sources);
} else if (version > 2) {
} else if (version >= 3) {
QString references;
QList<QString> referencesContext;
stream >> references;