Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-29 08:46:10 +00:00)
Preliminary support for chatgpt models.
gpt4all-chat/chatllm.cpp

@@ -5,6 +5,7 @@
 #include "../gpt4all-backend/gptj.h"
 #include "../gpt4all-backend/llamamodel.h"
 #include "../gpt4all-backend/mpt.h"
+#include "chatgpt.h"
 
 #include <QCoreApplication>
 #include <QDir>
@@ -21,17 +22,15 @@
 #define GPTJ_INTERNAL_STATE_VERSION 0
 #define LLAMA_INTERNAL_STATE_VERSION 0
 
-static QString modelFilePath(const QString &modelName)
+static QString modelFilePath(const QString &modelName, bool isChatGPT)
 {
-    QString appPath = QCoreApplication::applicationDirPath()
-        + "/ggml-" + modelName + ".bin";
+    QString modelFilename = isChatGPT ? modelName + ".txt" : "/ggml-" + modelName + ".bin";
+    QString appPath = QCoreApplication::applicationDirPath() + modelFilename;
     QFileInfo infoAppPath(appPath);
     if (infoAppPath.exists())
         return appPath;
 
-    QString downloadPath = Download::globalInstance()->downloadLocalModelsPath()
-        + "/ggml-" + modelName + ".bin";
+    QString downloadPath = Download::globalInstance()->downloadLocalModelsPath() + modelFilename;
     QFileInfo infoLocalPath(downloadPath);
     if (infoLocalPath.exists())
         return downloadPath;
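A note on the two naming schemes above: for a chatgpt- model the "model file" is not ggml weights at all, just a plain-text file whose entire contents are an OpenAI API key. A standalone sketch of the same resolution logic, with invented model names (modelFilenameFor is a hypothetical helper, not part of the commit):

#include <QDebug>
#include <QString>

// Hypothetical helper mirroring the ternary in modelFilePath() above.
static QString modelFilenameFor(const QString &modelName, bool isChatGPT)
{
    return isChatGPT ? modelName + ".txt" : "/ggml-" + modelName + ".bin";
}

int main()
{
    qDebug() << modelFilenameFor("chatgpt-gpt-4", true);  // "chatgpt-gpt-4.txt"
    qDebug() << modelFilenameFor("mpt-7b-chat", false);   // "/ggml-mpt-7b-chat.bin"
    return 0;
}

Note that, as written in the hunk, the chatgpt branch omits the leading '/' that the ggml branch carries, so the .txt name is appended directly to the directory string.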
@@ -139,7 +138,8 @@ bool ChatLLM::loadModel(const QString &modelName)
     if (isModelLoaded() && m_modelName == modelName)
         return true;
 
-    QString filePath = modelFilePath(modelName);
+    const bool isChatGPT = modelName.startsWith("chatgpt-");
+    QString filePath = modelFilePath(modelName, isChatGPT);
     QFileInfo fileInfo(filePath);
 
     // We have a live model, but it isn't the one we want
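Detection is purely name-based: any model whose name starts with "chatgpt-" is routed to the OpenAI-backed code path. For such a model to load, the resolved .txt file must already exist and contain the API key that the next hunk reads back. A minimal Qt 6 sketch of producing that file (the helper name and its arguments are illustrative, not part of the commit):

#include <QFile>
#include <QString>
#include <QTextStream>

// Sketch: write the plain-text "model file" that loadModel() later slurps
// back as the API key. e.g. writeAPIKeyFile(dir, "chatgpt-gpt-4", key)
// produces "<dir>/chatgpt-gpt-4.txt".
static bool writeAPIKeyFile(const QString &dir, const QString &modelName, const QString &apiKey)
{
    QFile file(dir + "/" + modelName + ".txt");
    if (!file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text))
        return false;
    QTextStream stream(&file);
    stream << apiKey;
    file.close();
    return true;
}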
@@ -198,25 +198,42 @@ bool ChatLLM::loadModel(const QString &modelName)
     m_modelInfo.fileInfo = fileInfo;
     if (fileInfo.exists()) {
-        auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
-        uint32_t magic;
-        fin.read((char *) &magic, sizeof(magic));
-        fin.seekg(0);
-        fin.close();
-        const bool isGPTJ = magic == 0x67676d6c;
-        const bool isMPT = magic == 0x67676d6d;
-        if (isGPTJ) {
-            m_modelType = LLModelType::GPTJ_;
-            m_modelInfo.model = new GPTJ;
-            m_modelInfo.model->loadModel(filePath.toStdString());
-        } else if (isMPT) {
-            m_modelType = LLModelType::MPT_;
-            m_modelInfo.model = new MPT;
-            m_modelInfo.model->loadModel(filePath.toStdString());
-        } else {
-            m_modelType = LLModelType::LLAMA_;
-            m_modelInfo.model = new LLamaModel;
-            m_modelInfo.model->loadModel(filePath.toStdString());
-        }
+        if (isChatGPT) {
+            QString apiKey;
+            QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
+            {
+                QFile file(filePath);
+                file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
+                QTextStream stream(&file);
+                apiKey = stream.readAll();
+                file.close();
+            }
+            m_modelType = LLModelType::CHATGPT_;
+            ChatGPT *model = new ChatGPT();
+            model->setModelName(chatGPTModel);
+            model->setAPIKey(apiKey);
+            m_modelInfo.model = model;
+        } else {
+            auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
+            uint32_t magic;
+            fin.read((char *) &magic, sizeof(magic));
+            fin.seekg(0);
+            fin.close();
+            const bool isGPTJ = magic == 0x67676d6c;
+            const bool isMPT = magic == 0x67676d6d;
+            if (isGPTJ) {
+                m_modelType = LLModelType::GPTJ_;
+                m_modelInfo.model = new GPTJ;
+                m_modelInfo.model->loadModel(filePath.toStdString());
+            } else if (isMPT) {
+                m_modelType = LLModelType::MPT_;
+                m_modelInfo.model = new MPT;
+                m_modelInfo.model->loadModel(filePath.toStdString());
+            } else {
+                m_modelType = LLModelType::LLAMA_;
+                m_modelInfo.model = new LLamaModel;
+                m_modelInfo.model->loadModel(filePath.toStdString());
+            }
+        }
     }
 #if defined(DEBUG_MODEL_LOADING)
     qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
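chatgpt.h itself is not shown in this diff. Below is a rough sketch of the interface the ChatGPT class would need, inferred solely from the call sites above; the base-class name, header path, and member layout are assumptions, not the commit's actual code:

#include <QString>
#include "../gpt4all-backend/llmodel.h" // assumed location of the LLModel base

// Inferred interface: ChatGPT is stored in m_modelInfo.model alongside
// GPTJ/MPT/LLamaModel, so it must share their base class, and it is
// configured with a model name and an API key before use.
class ChatGPT : public LLModel {
public:
    ChatGPT() = default;
    void setModelName(const QString &modelName) { m_modelName = modelName; }
    void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
    // ... plus the LLModel virtual overrides, which would forward prompts
    // to the OpenAI HTTP API instead of running local inference ...
private:
    QString m_modelName; // e.g. "gpt-4" (the "chatgpt-" prefix is stripped)
    QString m_apiKey;    // contents of the resolved .txt file
};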
@@ -241,8 +258,10 @@ bool ChatLLM::loadModel(const QString &modelName)
         emit modelLoadingError(error);
     }
 
-    if (m_modelInfo.model)
-        setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix
+    if (m_modelInfo.model) {
+        QString basename = fileInfo.completeBaseName();
+        setModelName(isChatGPT ? basename : basename.remove(0, 5)); // remove the ggml- prefix
+    }
 
     return m_modelInfo.model;
 }
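The displayed model name is derived differently per backend: ggml files drop their "ggml-" prefix, while chatgpt files keep the full basename so the UI still shows e.g. "chatgpt-gpt-4". A small illustration of that basename handling (file names invented):

#include <QDebug>
#include <QFileInfo>

int main()
{
    qDebug() << QFileInfo("ggml-mpt-7b-chat.bin").completeBaseName().remove(0, 5); // "mpt-7b-chat"
    qDebug() << QFileInfo("chatgpt-gpt-4.txt").completeBaseName();                 // "chatgpt-gpt-4"
    return 0;
}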
@@ -440,7 +459,7 @@ void ChatLLM::forceUnloadModel()
 
 void ChatLLM::unloadModel()
 {
-    if (!isModelLoaded() || m_isServer)
+    if (!isModelLoaded() || m_isServer) // FIXME: What if server switches models?
         return;
 
     saveState();
@@ -454,7 +473,7 @@ void ChatLLM::unloadModel()
 
 void ChatLLM::reloadModel()
 {
-    if (isModelLoaded() || m_isServer)
+    if (isModelLoaded() || m_isServer) // FIXME: What if server switches models?
         return;
 
 #if defined(DEBUG_MODEL_LOADING)