Make it possible to keep some chats after downgrading GPT4All (#3030)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>

parent b850e7c867
commit ec4e1e4812
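At a high level, the diff below does two things: the chat restore path now skips chat files whose format version is newer than the running build supports (instead of failing on them), and the save path only rewrites chats that are flagged as needing a save, so files written by a newer GPT4All are left on disk untouched after a downgrade. A minimal load-side sketch of that rule follows; the names and the version constant are illustrative, not the project's real API.

// Illustrative sketch only -- kChatFormatVersion and restoreChatFile() are
// placeholders, not GPT4All's actual identifiers.
#include <QDataStream>
#include <QDebug>
#include <QFile>

constexpr qint32 kChatFormatVersion = 10; // assumed value for illustration

bool restoreChatFile(const QString &path)
{
    QFile file(path);
    if (!file.open(QIODevice::ReadOnly))
        return false;

    QDataStream in(&file);
    qint32 version;
    in >> version;
    // Too old or too new: leave the file on disk untouched so a matching
    // build can still read it later.
    if (version < 1 || version > kChatFormatVersion) {
        qWarning() << "skipping unsupported chat file" << path << "version" << version;
        return false;
    }
    // ... deserialize the chat; it is then marked as not needing a save ...
    return in.status() == QDataStream::Ok;
}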
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix the local server rejecting min\_p/top\_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))
 - Fix "regenerate" always forgetting the most recent message ([#3011](https://github.com/nomic-ai/gpt4all/pull/3011))
 - Fix loaded chats forgetting context when there is a system prompt ([#3015](https://github.com/nomic-ai/gpt4all/pull/3015))
+- Make it possible to downgrade and keep some chats, and avoid crash for some model types ([#3030](https://github.com/nomic-ai/gpt4all/pull/3030))
 
 ## [3.3.1] - 2024-09-27 ([v3.3.y](https://github.com/nomic-ai/gpt4all/tree/v3.3.y))
 
@@ -101,6 +101,7 @@ void Chat::reset()
     // is to allow switching models but throwing up a dialog warning users if we switch between types
     // of models that a long recalculation will ensue.
     m_chatModel->clear();
+    m_needsSave = true;
 }
 
 void Chat::processSystemPrompt()
@@ -163,6 +164,7 @@ void Chat::prompt(const QString &prompt)
 {
     resetResponseState();
     emit promptRequested(m_collections, prompt);
+    m_needsSave = true;
 }
 
 void Chat::regenerateResponse()
@@ -170,6 +172,7 @@ void Chat::regenerateResponse()
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateSources(index, QList<ResultInfo>());
     emit regenerateResponseRequested();
+    m_needsSave = true;
 }
 
 void Chat::stopGenerating()
@@ -261,10 +264,12 @@ ModelInfo Chat::modelInfo() const
 
 void Chat::setModelInfo(const ModelInfo &modelInfo)
 {
-    if (m_modelInfo == modelInfo && isModelLoaded())
+    if (m_modelInfo != modelInfo) {
+        m_modelInfo = modelInfo;
+        m_needsSave = true;
+    } else if (isModelLoaded())
         return;
 
-    m_modelInfo = modelInfo;
     emit modelInfoChanged();
     emit modelChangeRequested(modelInfo);
 }
@@ -343,12 +348,14 @@ void Chat::generatedNameChanged(const QString &name)
     int wordCount = qMin(7, words.size());
     m_name = words.mid(0, wordCount).join(' ');
     emit nameChanged();
+    m_needsSave = true;
 }
 
 void Chat::generatedQuestionFinished(const QString &question)
 {
     m_generatedQuestions << question;
     emit generatedQuestionsChanged();
+    m_needsSave = true;
 }
 
 void Chat::handleRestoringFromText()
@@ -393,6 +400,7 @@ void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
     m_databaseResults = results;
     const int index = m_chatModel->count() - 1;
     m_chatModel->updateSources(index, m_databaseResults);
+    m_needsSave = true;
 }
 
 void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
@@ -402,6 +410,7 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
 
     m_modelInfo = modelInfo;
     emit modelInfoChanged();
+    m_needsSave = true;
 }
 
 void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value)
@@ -416,15 +425,15 @@ bool Chat::serialize(QDataStream &stream, int version) const
     stream << m_id;
     stream << m_name;
     stream << m_userName;
-    if (version > 4)
+    if (version >= 5)
         stream << m_modelInfo.id();
     else
         stream << m_modelInfo.filename();
-    if (version > 2)
+    if (version >= 3)
         stream << m_collections;
 
     const bool serializeKV = MySettings::globalInstance()->saveChatsContext();
-    if (version > 5)
+    if (version >= 6)
         stream << serializeKV;
     if (!m_llmodel->serialize(stream, version, serializeKV))
         return false;
@@ -445,7 +454,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
 
     QString modelId;
     stream >> modelId;
-    if (version > 4) {
+    if (version >= 5) {
         if (ModelList::globalInstance()->contains(modelId))
             m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
     } else {
@@ -457,13 +466,13 @@ bool Chat::deserialize(QDataStream &stream, int version)
 
     bool discardKV = m_modelInfo.id().isEmpty();
 
-    if (version > 2) {
+    if (version >= 3) {
         stream >> m_collections;
         emit collectionListChanged(m_collections);
     }
 
     bool deserializeKV = true;
-    if (version > 5)
+    if (version >= 6)
         stream >> deserializeKV;
 
     m_llmodel->setModelInfo(m_modelInfo);
@@ -473,7 +482,11 @@ bool Chat::deserialize(QDataStream &stream, int version)
         return false;
 
     emit chatModelChanged();
-    return stream.status() == QDataStream::Ok;
+    if (stream.status() != QDataStream::Ok)
+        return false;
+
+    m_needsSave = false;
+    return true;
 }
 
 QList<QString> Chat::collectionList() const
@@ -493,6 +506,7 @@ void Chat::addCollection(const QString &collection)
 
     m_collections.append(collection);
    emit collectionListChanged(m_collections);
+    m_needsSave = true;
 }
 
 void Chat::removeCollection(const QString &collection)
@@ -502,4 +516,5 @@ void Chat::removeCollection(const QString &collection)
 
     m_collections.removeAll(collection);
     emit collectionListChanged(m_collections);
+    m_needsSave = true;
 }
@@ -67,6 +67,7 @@ public:
     {
         m_userName = name;
         emit nameChanged();
+        m_needsSave = true;
     }
     ChatModel *chatModel() { return m_chatModel; }
 
@@ -124,6 +125,8 @@ public:
 
     QList<QString> generatedQuestions() const { return m_generatedQuestions; }
 
+    bool needsSave() const { return m_needsSave; }
+
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt, const QList<PromptAttachment> &attachments = {});
 
@@ -203,6 +206,10 @@ private:
     bool m_firstResponse = true;
     int m_trySwitchContextInProgress = 0;
     bool m_isCurrentlyLoading = false;
+    // True if we need to serialize the chat to disk, because of one of two reasons:
+    // - The chat was freshly created during this launch.
+    // - The chat was changed after loading it from disk.
+    bool m_needsSave = true;
 };
 
 #endif // CHAT_H
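The m_needsSave member added above is a dirty flag: it defaults to true (fresh chats always get written once), every mutating slot in chat.cpp sets it, and a successful deserialize clears it. A toy illustration of that invariant with made-up names, not part of the patch:

// Toy model of the dirty-flag behaviour; ChatStub and its members are invented.
struct ChatStub {
    bool needsSave = true;                      // fresh chats always need a first save
    void rename() { needsSave = true; }         // any mutation re-marks the chat dirty
    void finishRestore() { needsSave = false; } // a chat loaded unchanged from disk can be skipped
};

ChatSaver::saveChats in the next hunk consults exactly this flag and counts how many chats were actually written.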
@@ -99,7 +99,12 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
     QElapsedTimer timer;
     timer.start();
     const QString savePath = MySettings::globalInstance()->modelPath();
+    qsizetype nSavedChats = 0;
     for (Chat *chat : chats) {
+        if (!chat->needsSave())
+            continue;
+        ++nSavedChats;
+
         QString fileName = "gpt4all-" + chat->id() + ".chat";
         QString filePath = savePath + "/" + fileName;
         QFile originalFile(filePath);
@@ -129,7 +134,7 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
     }
 
     qint64 elapsedTime = timer.elapsed();
-    qDebug() << "serializing chats took:" << elapsedTime << "ms";
+    qDebug() << "serializing chats took" << elapsedTime << "ms, saved" << nSavedChats << "/" << chats.size() << "chats";
     emit saveChatsFinished();
 }
 
@@ -194,11 +199,16 @@ void ChatsRestoreThread::run()
             qint32 version;
             in >> version;
             if (version < 1) {
-                qWarning() << "ERROR: Chat file has non supported version:" << file.fileName();
+                qWarning() << "WARNING: Chat file version" << version << "is not supported:" << file.fileName();
                 continue;
             }
+            if (version > CHAT_FORMAT_VERSION) {
+                qWarning().nospace() << "WARNING: Chat file is from a future version (have " << version << " want "
+                                     << CHAT_FORMAT_VERSION << "): " << file.fileName();
+                continue;
+            }
 
-            if (version <= 1)
+            if (version < 2)
                 in.setVersion(QDataStream::Qt_6_2);
 
             FileInfo info;
@@ -239,7 +249,7 @@ void ChatsRestoreThread::run()
                 continue;
             }
 
-            if (version <= 1)
+            if (version < 2)
                 in.setVersion(QDataStream::Qt_6_2);
         }
 
@@ -42,8 +42,8 @@ using namespace Qt::Literals::StringLiterals;
 //#define DEBUG
 //#define DEBUG_MODEL_LOADING
 
-#define GPTJ_INTERNAL_STATE_VERSION 0 // GPT-J is gone but old chats still use this
-#define LLAMA_INTERNAL_STATE_VERSION 0
+static constexpr int LLAMA_INTERNAL_STATE_VERSION = 0;
+static constexpr int API_INTERNAL_STATE_VERSION = 0;
 
 class LLModelStore {
 public:
@@ -105,6 +105,7 @@ void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) {
 
 ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
+    , m_chat(parent)
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_restoringFromText(false)
@@ -355,7 +356,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 requestUrl = modelInfo.url();
             }
         }
-        m_llModelType = LLModelType::API_;
+        m_llModelType = LLModelTypeV1::API;
         ChatAPI *model = new ChatAPI();
         model->setModelName(modelName);
         model->setRequestURL(requestUrl);
@@ -571,7 +572,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
         }
 
         switch (m_llModelInfo.model->implementation().modelType()[0]) {
-        case 'L': m_llModelType = LLModelType::LLAMA_; break;
+        case 'L': m_llModelType = LLModelTypeV1::LLAMA; break;
         default:
             {
                 m_llModelInfo.resetModel(this);
@@ -624,7 +625,7 @@ void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
     // of n_past is of the number of prompt/response pairs, rather than for total tokens.
-    if (m_llModelType == LLModelType::API_)
+    if (m_llModelType == LLModelTypeV1::API)
         m_ctx.n_past -= 1;
     else
         m_ctx.n_past -= m_promptResponseTokens;
@@ -1045,12 +1046,18 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
 // we want to also serialize n_ctx, and read it at load time.
 bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 {
-    if (version > 1) {
+    if (version >= 2) {
+        if (m_llModelType == LLModelTypeV1::NONE) {
+            qWarning() << "ChatLLM ERROR: attempted to serialize a null model for chat id" << m_chat->id()
+                       << "name" << m_chat->name();
+            return false;
+        }
+
         stream << m_llModelType;
         switch (m_llModelType) {
-        case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
-        case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
-        default: Q_UNREACHABLE();
+        case LLModelTypeV1::LLAMA: stream << LLAMA_INTERNAL_STATE_VERSION; break;
+        case LLModelTypeV1::API: stream << API_INTERNAL_STATE_VERSION; break;
+        default: stream << 0; // models removed in v2.5.0
         }
     }
     stream << response();
@@ -1064,7 +1071,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
         return stream.status() == QDataStream::Ok;
     }
 
-    if (version <= 3) {
+    if (version < 4) {
         int responseLogits = 0;
         stream << responseLogits;
     }
@@ -1085,10 +1092,20 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 
 bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
 {
-    if (version > 1) {
-        int internalStateVersion;
-        stream >> m_llModelType;
-        stream >> internalStateVersion; // for future use
+    if (version >= 2) {
+        int llModelType;
+        stream >> llModelType;
+        m_llModelType = (version >= 6 ? parseLLModelTypeV1 : parseLLModelTypeV0)(llModelType);
+        if (m_llModelType == LLModelTypeV1::NONE) {
+            qWarning().nospace() << "error loading chat id " << m_chat->id() << ": unrecognized model type: "
+                                 << llModelType;
+            return false;
+        }
+
+        /* note: prior to chat version 10, API models and chats with models removed in v2.5.0 only wrote this because of
+         * undefined behavior in Release builds */
+        int internalStateVersion; // for future use
+        stream >> internalStateVersion;
     }
     QString response;
     stream >> response;
@@ -1114,7 +1131,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
         return stream.status() == QDataStream::Ok;
     }
 
-    if (version <= 3) {
+    if (version < 4) {
         int responseLogits;
         stream >> responseLogits;
     }
@@ -1144,7 +1161,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
         stream.skipRawData(tokensSize * sizeof(int));
     }
 
-    if (version > 0) {
+    if (version >= 1) {
         QByteArray compressed;
         stream >> compressed;
         if (!discardKV)
@@ -1169,7 +1186,7 @@ void ChatLLM::saveState()
     if (!isModelLoaded() || m_pristineLoadedState)
         return;
 
-    if (m_llModelType == LLModelType::API_) {
+    if (m_llModelType == LLModelTypeV1::API) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
@@ -1197,7 +1214,7 @@ void ChatLLM::restoreState()
     if (!isModelLoaded())
         return;
 
-    if (m_llModelType == LLModelType::API_) {
+    if (m_llModelType == LLModelTypeV1::API) {
         QDataStream stream(m_state);
         stream.setVersion(QDataStream::Qt_6_4);
         ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model.get());
@@ -28,12 +28,60 @@ using namespace Qt::Literals::StringLiterals;
 class QDataStream;
 
 // NOTE: values serialized to disk, do not change or reuse
-enum LLModelType {
-    GPTJ_ = 0, // no longer used
-    LLAMA_ = 1,
-    API_ = 2,
-    BERT_ = 3, // no longer used
+enum class LLModelTypeV0 { // chat versions 2-5
+    MPT = 0,
+    GPTJ = 1,
+    LLAMA = 2,
+    CHATGPT = 3,
+    REPLIT = 4,
+    FALCON = 5,
+    BERT = 6, // not used
+    STARCODER = 7,
 };
+enum class LLModelTypeV1 { // since chat version 6 (v2.5.0)
+    GPTJ = 0, // not for new chats
+    LLAMA = 1,
+    API = 2,
+    BERT = 3, // not used
+    // none of the below are used in new chats
+    REPLIT = 4,
+    FALCON = 5,
+    MPT = 6,
+    STARCODER = 7,
+    NONE = -1, // no state
+};
+
+static LLModelTypeV1 parseLLModelTypeV1(int type)
+{
+    switch (LLModelTypeV1(type)) {
+    case LLModelTypeV1::GPTJ:
+    case LLModelTypeV1::LLAMA:
+    case LLModelTypeV1::API:
+    // case LLModelTypeV1::BERT: -- not used
+    case LLModelTypeV1::REPLIT:
+    case LLModelTypeV1::FALCON:
+    case LLModelTypeV1::MPT:
+    case LLModelTypeV1::STARCODER:
+        return LLModelTypeV1(type);
+    default:
+        return LLModelTypeV1::NONE;
+    }
+}
+
+static LLModelTypeV1 parseLLModelTypeV0(int v0)
+{
+    switch (LLModelTypeV0(v0)) {
+    case LLModelTypeV0::MPT: return LLModelTypeV1::MPT;
+    case LLModelTypeV0::GPTJ: return LLModelTypeV1::GPTJ;
+    case LLModelTypeV0::LLAMA: return LLModelTypeV1::LLAMA;
+    case LLModelTypeV0::CHATGPT: return LLModelTypeV1::API;
+    case LLModelTypeV0::REPLIT: return LLModelTypeV1::REPLIT;
+    case LLModelTypeV0::FALCON: return LLModelTypeV1::FALCON;
+    // case LLModelTypeV0::BERT: -- not used
+    case LLModelTypeV0::STARCODER: return LLModelTypeV1::STARCODER;
+    default: return LLModelTypeV1::NONE;
+    }
+}
 
 class ChatLLM;
 class ChatModel;
@@ -219,12 +267,13 @@ protected:
 private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
+    const Chat *m_chat;
     std::string m_response;
     std::string m_trimmedResponse;
     std::string m_nameResponse;
     QString m_questionResponse;
     LLModelInfo m_llModelInfo;
-    LLModelType m_llModelType;
+    LLModelTypeV1 m_llModelType = LLModelTypeV1::NONE;
     ModelInfo m_modelInfo;
     TokenTimer *m_timer;
     QByteArray m_state;
@@ -386,7 +386,7 @@ public:
         stream << c.stopped;
         stream << c.thumbsUpState;
         stream << c.thumbsDownState;
-        if (version > 7) {
+        if (version >= 8) {
             stream << c.sources.size();
             for (const ResultInfo &info : c.sources) {
                 Q_ASSERT(!info.file.isEmpty());
@@ -401,7 +401,7 @@ public:
                 stream << info.from;
                 stream << info.to;
             }
-        } else if (version > 2) {
+        } else if (version >= 3) {
             QList<QString> references;
             QList<QString> referencesContext;
             int validReferenceNumber = 1;
@@ -468,7 +468,7 @@ public:
         stream >> c.stopped;
         stream >> c.thumbsUpState;
         stream >> c.thumbsDownState;
-        if (version > 7) {
+        if (version >= 8) {
             qsizetype count;
             stream >> count;
             QList<ResultInfo> sources;
@@ -488,7 +488,7 @@ public:
             }
             c.sources = sources;
             c.consolidatedSources = consolidateSources(sources);
-        } else if (version > 2) {
+        } else if (version >= 3) {
             QString references;
             QList<QString> referencesContext;
             stream >> references;