Much better memory management for multi-threaded model loading/unloading.

This commit is contained in:
Adam Treat
2023-05-13 19:05:35 -04:00
committed by AT
parent 2989b74d43
commit ddc24acf33
6 changed files with 243 additions and 74 deletions

View File

@@ -3,9 +3,23 @@
#include <QObject>
#include <QThread>
#include <QFileInfo>
#include "../gpt4all-backend/llmodel.h"
// Tag distinguishing the supported model types. This is a plain (unscoped) enum, so the
// enumerators are visible in the enclosing scope and take the implicit values 0, 1 and 2
// in declaration order.
// NOTE(review): the names presumably correspond to the MPT, GPT-J and LLaMA model families.
enum LLModelType { MPT_, GPTJ_, LLAMA_ };
// Pairs a model pointer with the QFileInfo that accompanies it.
//
// The model type and name are deliberately NOT kept here: that is left to ChatLLM, which
// must be able to serialize the information even while it is in the unloaded state.
struct LLModelInfo {
    LLModel *model{nullptr}; // null unless a model has been assigned
    QFileInfo fileInfo;
};
class Chat;
class ChatLLM : public QObject
{
@@ -17,12 +31,6 @@ class ChatLLM : public QObject
Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
public:
enum ModelType {
MPT_,
GPTJ_,
LLAMA_
};
ChatLLM(Chat *parent);
virtual ~ChatLLM();
@@ -33,6 +41,9 @@ public:
void stopGenerating() { m_stopGenerating = true; }
bool shouldBeLoaded() const { return m_shouldBeLoaded; }
void setShouldBeLoaded(bool b);
QString response() const;
QString modelName() const;
@@ -52,10 +63,12 @@ public Q_SLOTS:
bool loadDefaultModel();
bool loadModel(const QString &modelName);
void modelNameChangeRequested(const QString &modelName);
void forceUnloadModel();
void unloadModel();
void reloadModel(const QString &modelName);
void reloadModel();
void generateName();
void handleChatIdChanged();
void handleShouldBeLoadedChanged();
Q_SIGNALS:
void isModelLoadedChanged();
@@ -71,6 +84,7 @@ Q_SIGNALS:
void generatedNameChanged();
void stateChanged();
void threadStarted();
void shouldBeLoadedChanged();
protected:
LLModel::PromptContext m_ctx;
@@ -89,16 +103,17 @@ private:
void restoreState();
private:
LLModel *m_llmodel;
LLModelInfo m_modelInfo;
LLModelType m_modelType;
std::string m_response;
std::string m_nameResponse;
quint32 m_responseLogits;
QString m_modelName;
ModelType m_modelType;
Chat *m_chat;
QByteArray m_state;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
bool m_isRecalc;
};