Preliminary support for chatgpt models.

author Adam Treat
date 2023-05-14 20:12:15 -04:00
committed by AT
parent 22dc2bc5d2
commit fc6ab1f776
11 changed files with 439 additions and 34 deletions
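From the diff below, a chatgpt model is represented on disk as a plain-text file named chatgpt-<model>.txt whose entire contents are the API key; the loader detects the chatgpt- prefix, reads the key out of the file, and instantiates the new ChatGPT backend with the remainder of the name as the remote model. A minimal sketch of writing such a file (the helper, the directory handling, and the gpt-3.5-turbo example are illustrative, not part of this commit):

    // Hypothetical helper mirroring the on-disk convention this commit reads back
    // in ChatLLM::loadModel(): the whole file body is treated as the API key.
    #include <QFile>
    #include <QString>
    #include <QTextStream>

    static bool writeChatGPTModelFile(const QString &modelsDir, const QString &model, const QString &apiKey)
    {
        // e.g. <modelsDir>/chatgpt-gpt-3.5-turbo.txt -- the model name here is just an example
        QFile file(modelsDir + "/chatgpt-" + model + ".txt");
        if (!file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text))
            return false;
        QTextStream stream(&file);
        stream << apiKey;
        file.close();
        return true;
    }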


@@ -5,6 +5,7 @@
#include "../gpt4all-backend/gptj.h"
#include "../gpt4all-backend/llamamodel.h"
#include "../gpt4all-backend/mpt.h"
#include "chatgpt.h"
#include <QCoreApplication>
#include <QDir>
@@ -21,17 +22,15 @@
 #define GPTJ_INTERNAL_STATE_VERSION 0
 #define LLAMA_INTERNAL_STATE_VERSION 0
-static QString modelFilePath(const QString &modelName)
+static QString modelFilePath(const QString &modelName, bool isChatGPT)
 {
-    QString appPath = QCoreApplication::applicationDirPath()
-        + "/ggml-" + modelName + ".bin";
+    QString modelFilename = isChatGPT ? modelName + ".txt" : "/ggml-" + modelName + ".bin";
+    QString appPath = QCoreApplication::applicationDirPath() + modelFilename;
     QFileInfo infoAppPath(appPath);
     if (infoAppPath.exists())
         return appPath;
-    QString downloadPath = Download::globalInstance()->downloadLocalModelsPath()
-        + "/ggml-" + modelName + ".bin";
+    QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + modelFilename;
     QFileInfo infoLocalPath(downloadPath);
     if (infoLocalPath.exists())
         return downloadPath;
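The reworked helper now resolves a model name to either <name>.txt (chatgpt) or ggml-<name>.bin (local ggml weights), checking next to the executable first and then the downloaded-models directory; note that only the ggml branch carries a leading path separator in this commit. A standalone sketch of the same lookup order, with the search directories passed in explicitly for illustration:

    #include <QFileInfo>
    #include <QString>
    #include <QStringList>

    // Sketch only: the real code uses QCoreApplication::applicationDirPath() and
    // Download::globalInstance()->downloadLocalModelsPath() as the two candidates.
    static QString resolveModelFile(const QString &modelName, bool isChatGPT, const QStringList &searchDirs)
    {
        const QString fileName = isChatGPT ? modelName + ".txt" : "ggml-" + modelName + ".bin";
        for (const QString &dir : searchDirs) {
            const QString candidate = dir + "/" + fileName;
            if (QFileInfo::exists(candidate))
                return candidate;            // first hit wins, like the appPath check above
        }
        return QString();                    // not found in any search directory
    }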
@@ -139,7 +138,8 @@ bool ChatLLM::loadModel(const QString &modelName)
if (isModelLoaded() && m_modelName == modelName)
return true;
QString filePath = modelFilePath(modelName);
const bool isChatGPT = modelName.startsWith("chatgpt-");
QString filePath = modelFilePath(modelName, isChatGPT);
QFileInfo fileInfo(filePath);
// We have a live model, but it isn't the one we want
@@ -198,25 +198,42 @@ bool ChatLLM::loadModel(const QString &modelName)
     m_modelInfo.fileInfo = fileInfo;
     if (fileInfo.exists()) {
-        auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
-        uint32_t magic;
-        fin.read((char *) &magic, sizeof(magic));
-        fin.seekg(0);
-        fin.close();
-        const bool isGPTJ = magic == 0x67676d6c;
-        const bool isMPT = magic == 0x67676d6d;
-        if (isGPTJ) {
-            m_modelType = LLModelType::GPTJ_;
-            m_modelInfo.model = new GPTJ;
-            m_modelInfo.model->loadModel(filePath.toStdString());
-        } else if (isMPT) {
-            m_modelType = LLModelType::MPT_;
-            m_modelInfo.model = new MPT;
-            m_modelInfo.model->loadModel(filePath.toStdString());
+        if (isChatGPT) {
+            QString apiKey;
+            QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
+            {
+                QFile file(filePath);
+                file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
+                QTextStream stream(&file);
+                apiKey = stream.readAll();
+                file.close();
+            }
+            m_modelType = LLModelType::CHATGPT_;
+            ChatGPT *model = new ChatGPT();
+            model->setModelName(chatGPTModel);
+            model->setAPIKey(apiKey);
+            m_modelInfo.model = model;
         } else {
-            m_modelType = LLModelType::LLAMA_;
-            m_modelInfo.model = new LLamaModel;
-            m_modelInfo.model->loadModel(filePath.toStdString());
+            auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
+            uint32_t magic;
+            fin.read((char *) &magic, sizeof(magic));
+            fin.seekg(0);
+            fin.close();
+            const bool isGPTJ = magic == 0x67676d6c;
+            const bool isMPT = magic == 0x67676d6d;
+            if (isGPTJ) {
+                m_modelType = LLModelType::GPTJ_;
+                m_modelInfo.model = new GPTJ;
+                m_modelInfo.model->loadModel(filePath.toStdString());
+            } else if (isMPT) {
+                m_modelType = LLModelType::MPT_;
+                m_modelInfo.model = new MPT;
+                m_modelInfo.model->loadModel(filePath.toStdString());
+            } else {
+                m_modelType = LLModelType::LLAMA_;
+                m_modelInfo.model = new LLamaModel;
+                m_modelInfo.model->loadModel(filePath.toStdString());
+            }
         }
 #if defined(DEBUG_MODEL_LOADING)
         qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
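For anything that is not a chatgpt model, the loader still peeks at the first four bytes of the file to pick a backend: 0x67676d6c (the hex digit pairs are ASCII "ggml") selects GPT-J, 0x67676d6d ("ggmm") selects MPT, and anything else falls through to the llama loader. A standalone sketch of that probe, with an illustrative enum standing in for LLModelType:

    #include <cstdint>
    #include <fstream>
    #include <string>

    enum class ProbedType { GPTJ, MPT, LLAMA };   // stand-in for LLModelType in this sketch

    static ProbedType probeModelFile(const std::string &path)
    {
        std::ifstream fin(path, std::ios::binary);
        uint32_t magic = 0;
        fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
        if (magic == 0x67676d6c)
            return ProbedType::GPTJ;
        if (magic == 0x67676d6d)
            return ProbedType::MPT;
        return ProbedType::LLAMA;                 // default branch, as in the diff
    }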
@@ -241,8 +258,10 @@ bool ChatLLM::loadModel(const QString &modelName)
         emit modelLoadingError(error);
     }
-    if (m_modelInfo.model)
-        setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix
+    if (m_modelInfo.model) {
+        QString basename = fileInfo.completeBaseName();
+        setModelName(isChatGPT ? basename : basename.remove(0, 5)); // remove the ggml- prefix
+    }
     return m_modelInfo.model;
 }
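After this hunk the displayed name keeps the chatgpt- prefix, while the ggml- prefix is still stripped; the chatgpt- prefix is only removed earlier, when deriving the remote model name during load. Illustrative file names (not from the commit itself):

    #include <QFileInfo>
    #include <QString>

    static void modelNameExamples()
    {
        QFileInfo ggml("ggml-gpt4all-j-v1.3-groovy.bin");
        QString shown = ggml.completeBaseName().remove(0, 5);     // "gpt4all-j-v1.3-groovy" (ggml- stripped)

        QFileInfo chat("chatgpt-gpt-4.txt");
        QString display = chat.completeBaseName();                 // "chatgpt-gpt-4" (kept as-is for the UI)
        QString remote  = chat.completeBaseName().remove(0, 8);    // "gpt-4" (what ChatGPT::setModelName() receives)
    }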
@@ -440,7 +459,7 @@ void ChatLLM::forceUnloadModel()
 void ChatLLM::unloadModel()
 {
-    if (!isModelLoaded() || m_isServer)
+    if (!isModelLoaded() || m_isServer) // FIXME: What if server switches models?
         return;
     saveState();
@@ -454,7 +473,7 @@ void ChatLLM::unloadModel()
 void ChatLLM::reloadModel()
 {
-    if (isModelLoaded() || m_isServer)
+    if (isModelLoaded() || m_isServer) // FIXME: What if server switches models?
         return;
 #if defined(DEBUG_MODEL_LOADING)