Modellist temp

Adam Treat
2023-06-22 15:44:49 -04:00
parent c1794597a7
commit 7f01b153b3
25 changed files with 1784 additions and 1108 deletions


@@ -1,9 +1,9 @@
#include "chatllm.h"
#include "chat.h"
#include "download.h"
#include "chatgpt.h"
#include "modellist.h"
#include "network.h"
#include "../gpt4all-backend/llmodel.h"
#include "chatgpt.h"
#include <QCoreApplication>
#include <QDir>
@@ -20,29 +20,6 @@
#define REPLIT_INTERNAL_STATE_VERSION 0
#define LLAMA_INTERNAL_STATE_VERSION 0
static QString modelFilePath(const QString &modelName, bool isChatGPT)
{
QVector<QString> possibleFilePaths;
if (isChatGPT)
possibleFilePaths << "/" + modelName + ".txt";
else {
possibleFilePaths << "/ggml-" + modelName + ".bin";
possibleFilePaths << "/" + modelName + ".bin";
}
for (const QString &modelFilename : possibleFilePaths) {
QString appPath = QCoreApplication::applicationDirPath() + modelFilename;
QFileInfo infoAppPath(appPath);
if (infoAppPath.exists())
return appPath;
QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + modelFilename;
QFileInfo infoLocalPath(downloadPath);
if (infoLocalPath.exists())
return downloadPath;
}
return QString();
}
class LLModelStore {
public:
static LLModelStore *globalInstance();
@@ -102,7 +79,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(parent, &Chat::defaultModelChanged, this, &ChatLLM::handleDefaultModelChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
// The following are blocking operations and will block the llm thread
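The two comments above capture the threading model: ChatLLM's work runs on its own QThread (m_llmThread), and shouldBeLoadedChanged is connected with Qt::QueuedConnection so that handleShouldBeLoadedChanged executes in that thread's event loop rather than on the emitter's thread. A minimal, self-contained sketch of this general Qt pattern follows; it is illustrative only, and the names used are stand-ins, not code from this commit.

// Illustrative sketch (not from this commit): the Qt worker-thread pattern the
// comments above rely on. A QObject moved onto its own QThread has queued
// signal/slot invocations delivered in that thread's event loop, so blocking
// work there does not stall the main/gui thread.
#include <QCoreApplication>
#include <QThread>
#include <QTimer>
#include <QDebug>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QThread llmThread;                    // stand-in for ChatLLM's m_llmThread
    llmThread.setObjectName("llm thread");

    QObject worker;                       // stand-in for the ChatLLM worker object
    worker.moveToThread(&llmThread);

    // Explicitly queued: the functor runs in llmThread's event loop, not on the emitter's thread.
    QObject::connect(&llmThread, &QThread::started, &worker, [] {
        qDebug() << "blocking work runs on" << QThread::currentThread()->objectName();
    }, Qt::QueuedConnection);

    llmThread.start();
    QTimer::singleShot(100, [] { QCoreApplication::quit(); });
    int rc = app.exec();
    llmThread.quit();
    llmThread.wait();
    return rc;
}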
@@ -121,8 +97,8 @@ ChatLLM::~ChatLLM()
// The only time we should have a model loaded here is on shutdown
// as we explicitly unload the model in all other circumstances
if (isModelLoaded()) {
delete m_modelInfo.model;
m_modelInfo.model = nullptr;
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
}
}
@@ -135,14 +111,15 @@ void ChatLLM::handleThreadStarted()
bool ChatLLM::loadDefaultModel()
{
if (m_defaultModel.isEmpty()) {
emit modelLoadingError(QString("Could not find default model to load"));
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
if (defaultModel.filename.isEmpty()) {
emit modelLoadingError(QString("Could not find any model to load"));
return false;
}
return loadModel(m_defaultModel);
return loadModel(defaultModel);
}
bool ChatLLM::loadModel(const QString &modelName)
bool ChatLLM::loadModel(const ModelInfo &modelInfo)
{
// This is a complicated method because N different possible threads are interested in the outcome
// of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
@@ -153,11 +130,11 @@ bool ChatLLM::loadModel(const QString &modelName)
// to provide an overview of what we're doing here.
// We're already loaded with this model
if (isModelLoaded() && this->modelName() == modelName)
if (isModelLoaded() && this->modelInfo() == modelInfo)
return true;
bool isChatGPT = modelName.startsWith("chatgpt-");
QString filePath = modelFilePath(modelName, isChatGPT);
bool isChatGPT = modelInfo.isChatGPT;
QString filePath = modelInfo.dirpath + modelInfo.filename;
QFileInfo fileInfo(filePath);
// We have a live model, but it isn't the one we want
@@ -165,36 +142,36 @@ bool ChatLLM::loadModel(const QString &modelName)
if (alreadyAcquired) {
resetContext();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
delete m_modelInfo.model;
m_modelInfo.model = nullptr;
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
emit isModelLoadedChanged(false);
} else if (!m_isServer) {
// This is a blocking call that tries to retrieve the model we need from the model store.
// If it succeeds, then we just have to restore state. If the store has never had a model
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_modelInfo = LLModelStore::globalInstance()->acquireModel();
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo3.model;
#endif
// At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model
// back into the store and quit loading
if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "no longer need model" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
m_modelInfo = LLModelInfo();
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
emit isModelLoadedChanged(false);
return false;
}
// Check if the store just gave us exactly the model we were looking for
if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
restoreState();
emit isModelLoadedChanged(true);
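The comments in this hunk describe LLModelStore as a single-slot, blocking exchange point shared by the chat threads: acquireModel() blocks until a model slot is available (on first startup the slot holds an LLModelInfo with a null model pointer), and releaseModel() returns a slot and wakes any waiter. A sketch of such a store under those assumptions, using QMutex and QWaitCondition, is below; the names and details are illustrative, not the commit's actual implementation.

// Hedged sketch (an assumption, not the commit's code): a single-slot blocking
// store in the spirit of LLModelStore. acquireModel() blocks the calling chat
// thread until a slot is available; at startup the slot holds an LLModelInfo
// whose model pointer is null, as the comment above describes.
#include <QMutex>
#include <QWaitCondition>
#include <QVector>

struct LLModelInfoSketch {                // hypothetical stand-in for LLModelInfo
    void *model = nullptr;                // loaded model, or null on first startup
};

class LLModelStoreSketch {
public:
    LLModelStoreSketch() { m_availableModels.append(LLModelInfoSketch()); } // one empty slot at startup

    LLModelInfoSketch acquireModel() {
        QMutexLocker locker(&m_mutex);
        while (m_availableModels.isEmpty())
            m_condition.wait(&m_mutex);   // block until some chat releases a model
        return m_availableModels.takeFirst();
    }

    void releaseModel(const LLModelInfoSketch &info) {
        QMutexLocker locker(&m_mutex);
        m_availableModels.append(info);
        m_condition.wakeAll();            // unblock a waiting acquireModel()
    }

private:
    QMutex m_mutex;
    QWaitCondition m_condition;
    QVector<LLModelInfoSketch> m_availableModels;
};

This matches the control flow above: after the blocking acquire, loadModel() re-checks m_shouldBeLoaded and, if the request was cancelled while it waited, releases the slot straight back into the store before returning.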
@@ -202,18 +179,18 @@ bool ChatLLM::loadModel(const QString &modelName)
} else {
// Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "deleting model" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
delete m_modelInfo.model;
m_modelInfo.model = nullptr;
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
}
}
// Guarantee we've released the previous models memory
Q_ASSERT(!m_modelInfo.model);
Q_ASSERT(!m_llModelInfo.model);
// Store the file info in the modelInfo in case we have an error loading
m_modelInfo.fileInfo = fileInfo;
m_llModelInfo.fileInfo = fileInfo;
if (fileInfo.exists()) {
if (isChatGPT) {
@@ -226,46 +203,46 @@ bool ChatLLM::loadModel(const QString &modelName)
apiKey = stream.readAll();
file.close();
}
m_modelType = LLModelType::CHATGPT_;
m_llModelType = LLModelType::CHATGPT_;
ChatGPT *model = new ChatGPT();
model->setModelName(chatGPTModel);
model->setAPIKey(apiKey);
m_modelInfo.model = model;
m_llModelInfo.model = model;
} else {
m_modelInfo.model = LLModel::construct(filePath.toStdString());
if (m_modelInfo.model) {
bool success = m_modelInfo.model->loadModel(filePath.toStdString());
m_llModelInfo.model = LLModel::construct(filePath.toStdString());
if (m_llModelInfo.model) {
bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
if (!success) {
delete std::exchange(m_modelInfo.model, nullptr);
delete std::exchange(m_llModelInfo.model, nullptr);
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
m_modelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelName));
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename));
} else {
switch (m_modelInfo.model->implementation().modelType[0]) {
case 'L': m_modelType = LLModelType::LLAMA_; break;
case 'G': m_modelType = LLModelType::GPTJ_; break;
case 'M': m_modelType = LLModelType::MPT_; break;
case 'R': m_modelType = LLModelType::REPLIT_; break;
switch (m_llModelInfo.model->implementation().modelType[0]) {
case 'L': m_llModelType = LLModelType::LLAMA_; break;
case 'G': m_llModelType = LLModelType::GPTJ_; break;
case 'M': m_llModelType = LLModelType::MPT_; break;
case 'R': m_llModelType = LLModelType::REPLIT_; break;
default:
{
delete std::exchange(m_modelInfo.model, nullptr);
delete std::exchange(m_llModelInfo.model, nullptr);
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
m_modelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelName));
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename));
}
}
}
} else {
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
m_modelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid format for %1").arg(modelName));
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid format for %1").arg(modelInfo.filename));
}
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "new model" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
restoreState();
#if defined(DEBUG)
@@ -282,31 +259,27 @@ bool ChatLLM::loadModel(const QString &modelName)
emit sendModelLoaded();
} else {
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
m_modelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not find file for model %1").arg(modelName));
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename));
}
if (m_modelInfo.model) {
QString basename = fileInfo.completeBaseName();
if (basename.startsWith("ggml-")) // remove the ggml- prefix
basename.remove(0, 5);
setModelName(basename);
}
if (m_llModelInfo.model)
setModelInfo(modelInfo);
return m_modelInfo.model;
return m_llModelInfo.model;
}
bool ChatLLM::isModelLoaded() const
{
return m_modelInfo.model && m_modelInfo.model->isModelLoaded();
return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded();
}
void ChatLLM::regenerateResponse()
{
// ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
// of n_past is of the number of prompt/response pairs, rather than for total tokens.
if (m_modelType == LLModelType::CHATGPT_)
if (m_llModelType == LLModelType::CHATGPT_)
m_ctx.n_past -= 1;
else
m_ctx.n_past -= m_promptResponseTokens;
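As a worked illustration of the comment above (the numbers are illustrative, not from the commit): if the last turn consumed 30 prompt tokens and 50 response tokens, a local model rewinds n_past by m_promptResponseTokens = 80 tokens, whereas ChatGPT counts n_past in prompt/response pairs, so regenerating rewinds it by exactly 1.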
@@ -357,20 +330,20 @@ QString ChatLLM::response() const
return QString::fromStdString(remove_leading_whitespace(m_response));
}
QString ChatLLM::modelName() const
ModelInfo ChatLLM::modelInfo() const
{
return m_modelName;
return m_modelInfo;
}
void ChatLLM::setModelName(const QString &modelName)
void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
{
m_modelName = modelName;
emit modelNameChanged();
m_modelInfo = modelInfo;
emit modelInfoChanged(modelInfo);
}
void ChatLLM::modelNameChangeRequested(const QString &modelName)
void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
{
loadModel(modelName);
loadModel(modelInfo);
}
bool ChatLLM::handlePrompt(int32_t token)
@@ -454,13 +427,13 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_modelInfo.model->setThreadCount(n_threads);
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_timer->start();
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -478,7 +451,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
void ChatLLM::setShouldBeLoaded(bool b)
{
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_modelInfo.model;
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
#endif
m_shouldBeLoaded = b; // atomic
emit shouldBeLoadedChanged();
@@ -505,10 +478,10 @@ void ChatLLM::unloadModel()
saveState();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "unloadModel" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
m_modelInfo = LLModelInfo();
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
emit isModelLoadedChanged(false);
}
@@ -518,10 +491,10 @@ void ChatLLM::reloadModel()
return;
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
const QString m = modelName();
if (m.isEmpty())
const ModelInfo m = modelInfo();
if (m.name.isEmpty())
loadDefaultModel();
else
loadModel(m);
@@ -545,7 +518,7 @@ void ChatLLM::generateName()
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
#endif
m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
m_llModelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -562,11 +535,6 @@ void ChatLLM::handleChatIdChanged(const QString &id)
m_llmThread.setObjectName(id);
}
void ChatLLM::handleDefaultModelChanged(const QString &defaultModel)
{
m_defaultModel = defaultModel;
}
bool ChatLLM::handleNamePrompt(int32_t token)
{
Q_UNUSED(token);
@@ -595,8 +563,8 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
bool ChatLLM::serialize(QDataStream &stream, int version)
{
if (version > 1) {
stream << m_modelType;
switch (m_modelType) {
stream << m_llModelType;
switch (m_llModelType) {
case REPLIT_: stream << REPLIT_INTERNAL_STATE_VERSION; break;
case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break;
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
@@ -629,7 +597,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
{
if (version > 1) {
int internalStateVersion;
stream >> m_modelType;
stream >> m_llModelType;
stream >> internalStateVersion; // for future use
}
QString response;
@@ -670,21 +638,21 @@ void ChatLLM::saveState()
if (!isModelLoaded())
return;
if (m_modelType == LLModelType::CHATGPT_) {
if (m_llModelType == LLModelType::CHATGPT_) {
m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_5);
ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
ChatGPT *chatGPT = static_cast<ChatGPT*>(m_llModelInfo.model);
stream << chatGPT->context();
return;
}
const size_t stateSize = m_modelInfo.model->stateSize();
const size_t stateSize = m_llModelInfo.model->stateSize();
m_state.resize(stateSize);
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_llModelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}
void ChatLLM::restoreState()
@@ -692,10 +660,10 @@ void ChatLLM::restoreState()
if (!isModelLoaded() || m_state.isEmpty())
return;
if (m_modelType == LLModelType::CHATGPT_) {
if (m_llModelType == LLModelType::CHATGPT_) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
stream.setVersion(QDataStream::Qt_6_5);
ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
ChatGPT *chatGPT = static_cast<ChatGPT*>(m_llModelInfo.model);
QList<QString> context;
stream >> context;
chatGPT->setContext(context);
@@ -707,7 +675,7 @@ void ChatLLM::restoreState()
#if defined(DEBUG)
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_state.clear();
m_state.resize(0);
}