feat: Add support for Mistral API models (#2053)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Cédric Sazos <cedric.sazos@tutanota.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Olyxz16 committed 2024-03-13 23:23:57 +01:00 (committed by GitHub)
parent 406e88b59a
commit 2c0a660e6e
7 changed files with 242 additions and 98 deletions

File: CMakeLists.txt

@@ -73,7 +73,7 @@ qt_add_executable(chat
     chat.h chat.cpp
     chatllm.h chatllm.cpp
     chatmodel.h chatlistmodel.h chatlistmodel.cpp
-    chatgpt.h chatgpt.cpp
+    chatapi.h chatapi.cpp
     database.h database.cpp
     embeddings.h embeddings.cpp
     download.h download.cpp

File: chatapi.cpp (renamed from chatgpt.cpp)

@@ -1,4 +1,4 @@
-#include "chatgpt.h"
+#include "chatapi.h"
 #include <string>
 #include <vector>
@@ -13,14 +13,15 @@
 //#define DEBUG
-ChatGPT::ChatGPT()
+ChatAPI::ChatAPI()
     : QObject(nullptr)
     , m_modelName("gpt-3.5-turbo")
+    , m_requestURL("")
     , m_responseCallback(nullptr)
 {
 }
-size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
+size_t ChatAPI::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
 {
     Q_UNUSED(modelPath);
     Q_UNUSED(n_ctx);
@@ -28,7 +29,7 @@ size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
     return 0;
 }
-bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx, int ngl)
+bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 {
     Q_UNUSED(modelPath);
     Q_UNUSED(n_ctx);
@@ -36,59 +37,59 @@ bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     return true;
 }
-void ChatGPT::setThreadCount(int32_t n_threads)
+void ChatAPI::setThreadCount(int32_t n_threads)
 {
     Q_UNUSED(n_threads);
     qt_noop();
 }
-int32_t ChatGPT::threadCount() const
+int32_t ChatAPI::threadCount() const
 {
     return 1;
 }
-ChatGPT::~ChatGPT()
+ChatAPI::~ChatAPI()
 {
 }
-bool ChatGPT::isModelLoaded() const
+bool ChatAPI::isModelLoaded() const
 {
     return true;
 }
 // All three of the state virtual functions are handled custom inside of chatllm save/restore
-size_t ChatGPT::stateSize() const
+size_t ChatAPI::stateSize() const
 {
     return 0;
 }
-size_t ChatGPT::saveState(uint8_t *dest) const
+size_t ChatAPI::saveState(uint8_t *dest) const
 {
     Q_UNUSED(dest);
     return 0;
 }
-size_t ChatGPT::restoreState(const uint8_t *src)
+size_t ChatAPI::restoreState(const uint8_t *src)
 {
     Q_UNUSED(src);
     return 0;
 }
-void ChatGPT::prompt(const std::string &prompt,
+void ChatAPI::prompt(const std::string &prompt,
         const std::string &promptTemplate,
         std::function<bool(int32_t)> promptCallback,
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &promptCtx,
         bool special,
         std::string *fakeReply) {
     Q_UNUSED(promptCallback);
     Q_UNUSED(recalculateCallback);
     Q_UNUSED(special);
     if (!isModelLoaded()) {
-        std::cerr << "ChatGPT ERROR: prompt won't work with an unloaded model!\n";
+        std::cerr << "ChatAPI ERROR: prompt won't work with an unloaded model!\n";
         return;
     }
@@ -128,7 +129,7 @@ void ChatGPT::prompt(const std::string &prompt,
     QJsonArray messages;
     for (int i = 0; i < m_context.count(); ++i) {
         QJsonObject message;
-        message.insert("role", i % 2 == 0 ? "assistant" : "user");
+        message.insert("role", i % 2 == 0 ? "user" : "assistant");
         message.insert("content", m_context.at(i));
         messages.append(message);
     }
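This hunk also corrects how the stored conversation is replayed to the server: entries at even indices of m_context are the user's prompts and are now tagged "user", odd indices "assistant" (previously the labels were inverted). For illustration only (the content strings below are placeholders, not from the commit), a two-turn history now serializes to a messages array like:

    [
        { "role": "user",      "content": "<first prompt>" },
        { "role": "assistant", "content": "<first response>" },
        { "role": "user",      "content": "<second prompt>" },
        { "role": "assistant", "content": "<second response>" }
    ]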
@@ -142,7 +143,7 @@ void ChatGPT::prompt(const std::string &prompt,
     QJsonDocument doc(root);
 #if defined(DEBUG)
-    qDebug().noquote() << "ChatGPT::prompt begin network request" << doc.toJson();
+    qDebug().noquote() << "ChatAPI::prompt begin network request" << doc.toJson();
 #endif
     m_responseCallback = responseCallback;
@@ -150,10 +151,10 @@ void ChatGPT::prompt(const std::string &prompt,
     // The following code sets up a worker thread and object to perform the actual api request to
     // chatgpt and then blocks until it is finished
     QThread workerThread;
-    ChatGPTWorker worker(this);
+    ChatAPIWorker worker(this);
     worker.moveToThread(&workerThread);
-    connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
-    connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
+    connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
+    connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
     workerThread.start();
     emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
     workerThread.wait();
@@ -164,40 +165,40 @@ void ChatGPT::prompt(const std::string &prompt,
     m_responseCallback = nullptr;
 #if defined(DEBUG)
-    qDebug() << "ChatGPT::prompt end network request";
+    qDebug() << "ChatAPI::prompt end network request";
 #endif
 }
-bool ChatGPT::callResponse(int32_t token, const std::string& string)
+bool ChatAPI::callResponse(int32_t token, const std::string& string)
 {
     Q_ASSERT(m_responseCallback);
     if (!m_responseCallback) {
-        std::cerr << "ChatGPT ERROR: no response callback!\n";
+        std::cerr << "ChatAPI ERROR: no response callback!\n";
         return false;
     }
     return m_responseCallback(token, string);
 }
-void ChatGPTWorker::request(const QString &apiKey,
+void ChatAPIWorker::request(const QString &apiKey,
         LLModel::PromptContext *promptCtx,
         const QByteArray &array)
 {
     m_ctx = promptCtx;
-    QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
+    QUrl apiUrl(m_chat->url());
     const QString authorization = QString("Bearer %1").arg(apiKey).trimmed();
-    QNetworkRequest request(openaiUrl);
+    QNetworkRequest request(apiUrl);
     request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
     request.setRawHeader("Authorization", authorization.toUtf8());
     m_networkManager = new QNetworkAccessManager(this);
     QNetworkReply *reply = m_networkManager->post(request, array);
     connect(qApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
-    connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
-    connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
-    connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
+    connect(reply, &QNetworkReply::finished, this, &ChatAPIWorker::handleFinished);
+    connect(reply, &QNetworkReply::readyRead, this, &ChatAPIWorker::handleReadyRead);
+    connect(reply, &QNetworkReply::errorOccurred, this, &ChatAPIWorker::handleErrorOccurred);
 }
-void ChatGPTWorker::handleFinished()
+void ChatAPIWorker::handleFinished()
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
     if (!reply) {
@@ -210,14 +211,14 @@ void ChatGPTWorker::handleFinished()
     bool ok;
     int code = response.toInt(&ok);
     if (!ok || code != 200) {
-        qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
-                          .arg(code).arg(reply->errorString());
+        qWarning().noquote() << "ERROR: ChatAPIWorker::handleFinished got HTTP Error" << code << "response:"
+                             << reply->errorString();
     }
     reply->deleteLater();
     emit finished();
 }
-void ChatGPTWorker::handleReadyRead()
+void ChatAPIWorker::handleReadyRead()
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
     if (!reply) {
@@ -230,8 +231,11 @@ void ChatGPTWorker::handleReadyRead()
     bool ok;
     int code = response.toInt(&ok);
     if (!ok || code != 200) {
-        m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
-                                 .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString());
+        m_chat->callResponse(
+            -1,
+            QString("ERROR: ChatAPIWorker::handleReadyRead got HTTP Error %1 %2: %3")
+                .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString()
+        );
         emit finished();
         return;
     }
@@ -251,8 +255,8 @@ void ChatGPTWorker::handleReadyRead()
     QJsonParseError err;
     const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
     if (err.error != QJsonParseError::NoError) {
-        m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
+        m_chat->callResponse(-1, QString("ERROR: ChatAPI responded with invalid json \"%1\"")
             .arg(err.errorString()).toStdString());
         continue;
     }
@@ -271,7 +275,7 @@ void ChatGPTWorker::handleReadyRead()
         }
     }
-void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
+void ChatAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
     if (!reply || reply->error() == QNetworkReply::OperationCanceledError /*when we call abort on purpose*/) {
@@ -279,7 +283,7 @@ void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
         return;
     }
-    qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
-                      .arg(code).arg(reply->errorString());
+    qWarning().noquote() << "ERROR: ChatAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:"
+                         << reply->errorString();
     emit finished();
 }

File: chatapi.h (renamed from chatgpt.h)

@@ -1,5 +1,5 @@
-#ifndef CHATGPT_H
-#define CHATGPT_H
+#ifndef CHATAPI_H
+#define CHATAPI_H
 #include <stdexcept>
@@ -13,22 +13,22 @@
 #include "../gpt4all-backend/llmodel.h"
-class ChatGPT;
+class ChatAPI;
-class ChatGPTWorker : public QObject {
+class ChatAPIWorker : public QObject {
     Q_OBJECT
 public:
-    ChatGPTWorker(ChatGPT *chatGPT)
+    ChatAPIWorker(ChatAPI *chatAPI)
         : QObject(nullptr)
         , m_ctx(nullptr)
         , m_networkManager(nullptr)
-        , m_chat(chatGPT) {}
+        , m_chat(chatAPI) {}
-    virtual ~ChatGPTWorker() {}
+    virtual ~ChatAPIWorker() {}
     QString currentResponse() const { return m_currentResponse; }
     void request(const QString &apiKey,
         LLModel::PromptContext *promptCtx,
         const QByteArray &array);
 Q_SIGNALS:
     void finished();
@@ -39,17 +39,17 @@ private Q_SLOTS:
     void handleErrorOccurred(QNetworkReply::NetworkError code);
 private:
-    ChatGPT *m_chat;
+    ChatAPI *m_chat;
     LLModel::PromptContext *m_ctx;
     QNetworkAccessManager *m_networkManager;
     QString m_currentResponse;
 };
-class ChatGPT : public QObject, public LLModel {
+class ChatAPI : public QObject, public LLModel {
     Q_OBJECT
 public:
-    ChatGPT();
-    virtual ~ChatGPT();
+    ChatAPI();
+    virtual ~ChatAPI();
     bool supportsEmbedding() const override { return false; }
     bool supportsCompletion() const override { return true; }
@@ -60,19 +60,21 @@ public:
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
     void prompt(const std::string &prompt,
         const std::string &promptTemplate,
         std::function<bool(int32_t)> promptCallback,
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx,
         bool special,
         std::string *fakeReply) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
     void setModelName(const QString &modelName) { m_modelName = modelName; }
     void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
+    void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
+    QString url() const { return m_requestURL; }
     QList<QString> context() const { return m_context; }
     void setContext(const QList<QString> &context) { m_context = context; }
@@ -81,8 +83,8 @@ public:
 Q_SIGNALS:
     void request(const QString &apiKey,
         LLModel::PromptContext *ctx,
         const QByteArray &array);
 protected:
     // We have to implement these as they are pure virtual in base class, but we don't actually use
@@ -128,8 +130,9 @@ private:
     std::function<bool(int32_t, const std::string&)> m_responseCallback;
     QString m_modelName;
     QString m_apiKey;
+    QString m_requestURL;
     QList<QString> m_context;
     QStringList m_queuedPrompts;
 };
-#endif // CHATGPT_H
+#endif // CHATAPI_H
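With the rename and the new request-URL accessors, pointing the client at any OpenAI-compatible chat/completions endpoint is a matter of three setters. A minimal sketch for orientation (not code from the commit; the key string is a placeholder):

    ChatAPI *model = new ChatAPI();
    model->setModelName("mistral-tiny");                                 // model identifier placed in the request body
    model->setRequestURL("https://api.mistral.ai/v1/chat/completions");  // endpoint read back via url() in ChatAPIWorker::request()
    model->setAPIKey("YOUR_API_KEY");                                    // placeholder; sent as an Authorization: Bearer header

ChatLLM::loadModel() performs essentially this sequence after parsing the installed .rmodel file, as shown in chatllm.cpp below.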

File: chatllm.cpp

@@ -1,6 +1,6 @@
 #include "chatllm.h"
 #include "chat.h"
-#include "chatgpt.h"
+#include "chatapi.h"
 #include "localdocs.h"
 #include "modellist.h"
 #include "network.h"
@@ -213,7 +213,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (isModelLoaded() && this->modelInfo() == modelInfo)
         return true;
-    bool isChatGPT = modelInfo.isOnline; // right now only chatgpt is offered for online chat models...
     QString filePath = modelInfo.dirpath + modelInfo.filename();
     QFileInfo fileInfo(filePath);
@@ -279,19 +278,23 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     m_llModelInfo.fileInfo = fileInfo;
     if (fileInfo.exists()) {
-        if (isChatGPT) {
+        if (modelInfo.isOnline) {
             QString apiKey;
-            QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
+            QString modelName;
             {
                 QFile file(filePath);
                 file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
                 QTextStream stream(&file);
-                apiKey = stream.readAll();
-                file.close();
+                QString text = stream.readAll();
+                QJsonDocument doc = QJsonDocument::fromJson(text.toUtf8());
+                QJsonObject obj = doc.object();
+                apiKey = obj["apiKey"].toString();
+                modelName = obj["modelName"].toString();
             }
-            m_llModelType = LLModelType::CHATGPT_;
-            ChatGPT *model = new ChatGPT();
-            model->setModelName(chatGPTModel);
+            m_llModelType = LLModelType::API_;
+            ChatAPI *model = new ChatAPI();
+            model->setModelName(modelName);
+            model->setRequestURL(modelInfo.url());
             model->setAPIKey(apiKey);
             m_llModelInfo.model = model;
         } else {
@@ -468,7 +471,7 @@ void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
     // of n_past is of the number of prompt/response pairs, rather than for total tokens.
-    if (m_llModelType == LLModelType::CHATGPT_)
+    if (m_llModelType == LLModelType::API_)
         m_ctx.n_past -= 1;
     else
         m_ctx.n_past -= m_promptResponseTokens;
@@ -958,12 +961,12 @@ void ChatLLM::saveState()
     if (!isModelLoaded())
         return;
-    if (m_llModelType == LLModelType::CHATGPT_) {
+    if (m_llModelType == LLModelType::API_) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_llModelInfo.model);
-        stream << chatGPT->context();
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
+        stream << chatAPI->context();
         return;
     }
@@ -980,13 +983,13 @@ void ChatLLM::restoreState()
     if (!isModelLoaded())
         return;
-    if (m_llModelType == LLModelType::CHATGPT_) {
+    if (m_llModelType == LLModelType::API_) {
         QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
         QList<QString> context;
         stream >> context;
-        chatGPT->setContext(context);
+        chatAPI->setContext(context);
         m_state.clear();
         m_state.squeeze();
         return;

File: chatllm.h

@@ -12,7 +12,7 @@
 enum LLModelType {
     GPTJ_,
     LLAMA_,
-    CHATGPT_,
+    API_,
 };
 struct LLModelInfo {

File: download.cpp

@@ -182,8 +182,17 @@ void Download::installModel(const QString &modelFile, const QString &apiKey)
     QString filePath = MySettings::globalInstance()->modelPath() + modelFile;
     QFile file(filePath);
     if (file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text)) {
+        QJsonObject obj;
+        QString modelName(modelFile);
+        modelName.remove(0, 8); // strip "gpt4all-" prefix
+        modelName.chop(7); // strip ".rmodel" extension
+        obj.insert("apiKey", apiKey);
+        obj.insert("modelName", modelName);
+        QJsonDocument doc(obj);
         QTextStream stream(&file);
-        stream << apiKey;
+        stream << doc.toJson();
         file.close();
         ModelList::globalInstance()->updateModelsFromDirectory();
     }
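Where the installed file previously contained nothing but the raw API key, it now contains a small JSON document. For example, installing "Mistral Tiny API" (modelFile "gpt4all-mistral-tiny.rmodel") leaves a file on disk roughly like the following, where the key value is whatever the user entered:

    {
        "apiKey": "<user's Mistral API key>",
        "modelName": "mistral-tiny"
    }

This is the same format that ChatLLM::loadModel() parses above and that updateOldRemoteModels() migrates legacy .txt key files into below.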

File: modellist.cpp

@@ -1172,6 +1172,44 @@ void ModelList::updateModelsFromDirectory()
     const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
     const QString localPath = MySettings::globalInstance()->modelPath();
+    auto updateOldRemoteModels = [&](const QString& path) {
+        QDirIterator it(path, QDirIterator::Subdirectories);
+        while (it.hasNext()) {
+            it.next();
+            if (!it.fileInfo().isDir()) {
+                QString filename = it.fileName();
+                if (filename.endsWith(".txt")) {
+                    QString apikey;
+                    QString modelname(filename);
+                    modelname.chop(4); // strip ".txt" extension
+                    if (filename.startsWith("chatgpt-")) {
+                        modelname.remove(0, 8); // strip "chatgpt-" prefix
+                    }
+                    QFile file(path + filename);
+                    if (file.open(QIODevice::ReadWrite)) {
+                        QTextStream in(&file);
+                        apikey = in.readAll();
+                        file.close();
+                    }
+                    QJsonObject obj;
+                    obj.insert("apiKey", apikey);
+                    obj.insert("modelName", modelname);
+                    QJsonDocument doc(obj);
+                    auto newfilename = QString("gpt4all-%1.rmodel").arg(modelname);
+                    QFile newfile(path + newfilename);
+                    if (newfile.open(QIODevice::ReadWrite)) {
+                        QTextStream out(&newfile);
+                        out << doc.toJson();
+                        newfile.close();
+                    }
+                    file.remove();
+                }
+            }
+        }
+    };
     auto processDirectory = [&](const QString& path) {
         QDirIterator it(path, QDirIterator::Subdirectories);
         while (it.hasNext()) {
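The new updateOldRemoteModels pass migrates pre-existing plain-text key files to the .rmodel format before the directory is scanned. As a worked example (the key value is a placeholder): a legacy "chatgpt-gpt-4.txt" whose entire contents were the user's OpenAI key is rewritten as "gpt4all-gpt-4.rmodel" containing {"apiKey": "<key>", "modelName": "gpt-4"}, and the old .txt file is then removed.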
@@ -1180,8 +1218,7 @@ void ModelList::updateModelsFromDirectory()
             if (!it.fileInfo().isDir()) {
                 QString filename = it.fileName();
-                if ((filename.endsWith(".gguf") && !filename.startsWith("incomplete"))
-                    || (filename.endsWith(".txt") && (filename.startsWith("chatgpt-") || filename.startsWith("nomic-")))) {
+                if ((filename.endsWith(".gguf") && !filename.startsWith("incomplete")) || filename.endsWith(".rmodel")) {
                     QString filePath = it.filePath();
                     QFileInfo info(filePath);
@@ -1207,8 +1244,7 @@ void ModelList::updateModelsFromDirectory()
                     QVector<QPair<int, QVariant>> data {
                         { InstalledRole, true },
                         { FilenameRole, filename },
-                        // FIXME: WE should change this to use a consistent filename for online models
-                        { OnlineRole, filename.startsWith("chatgpt-") || filename.startsWith("nomic-") },
+                        { OnlineRole, filename.endsWith(".rmodel") },
                         { DirpathRole, info.dir().absolutePath() + "/" },
                         { FilesizeRole, toFileSize(info.size()) },
                     };
@@ -1219,9 +1255,13 @@ void ModelList::updateModelsFromDirectory()
         }
     };
+    updateOldRemoteModels(exePath);
     processDirectory(exePath);
-    if (localPath != exePath)
+    if (localPath != exePath) {
+        updateOldRemoteModels(localPath);
         processDirectory(localPath);
+    }
 }
 #define MODELS_VERSION 3
@@ -1466,7 +1506,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
     {
         const QString modelName = "ChatGPT-3.5 Turbo";
         const QString id = modelName;
-        const QString modelFilename = "chatgpt-gpt-3.5-turbo.txt";
+        const QString modelFilename = "gpt4all-gpt-3.5-turbo.rmodel";
         if (contains(modelFilename))
             changeId(modelFilename, id);
         if (!contains(id))
@@ -1478,12 +1518,13 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
             { ModelList::OnlineRole, true },
             { ModelList::DescriptionRole,
                 tr("<strong>OpenAI's ChatGPT model GPT-3.5 Turbo</strong><br>") + chatGPTDesc },
-            { ModelList::RequiresVersionRole, "2.4.2" },
+            { ModelList::RequiresVersionRole, "2.7.4" },
            { ModelList::OrderRole, "ca" },
             { ModelList::RamrequiredRole, 0 },
             { ModelList::ParametersRole, "?" },
             { ModelList::QuantRole, "NA" },
             { ModelList::TypeRole, "GPT" },
+            { ModelList::UrlRole, "https://api.openai.com/v1/chat/completions"},
         };
         updateData(id, data);
     }
@@ -1493,7 +1534,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
         const QString modelName = "ChatGPT-4";
         const QString id = modelName;
-        const QString modelFilename = "chatgpt-gpt-4.txt";
+        const QString modelFilename = "gpt4all-gpt-4.rmodel";
         if (contains(modelFilename))
             changeId(modelFilename, id);
         if (!contains(id))
@@ -1505,15 +1546,99 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
             { ModelList::OnlineRole, true },
             { ModelList::DescriptionRole,
                 tr("<strong>OpenAI's ChatGPT model GPT-4</strong><br>") + chatGPTDesc + chatGPT4Warn },
-            { ModelList::RequiresVersionRole, "2.4.2" },
+            { ModelList::RequiresVersionRole, "2.7.4" },
             { ModelList::OrderRole, "cb" },
             { ModelList::RamrequiredRole, 0 },
             { ModelList::ParametersRole, "?" },
             { ModelList::QuantRole, "NA" },
             { ModelList::TypeRole, "GPT" },
+            { ModelList::UrlRole, "https://api.openai.com/v1/chat/completions"},
         };
         updateData(id, data);
     }
+    const QString mistralDesc = tr("<ul><li>Requires personal Mistral API key.</li><li>WARNING: Will send"
+        " your chats to Mistral!</li><li>Your API key will be stored on disk</li><li>Will only be used"
+        " to communicate with Mistral</li><li>You can apply for an API key"
+        " <a href=\"https://console.mistral.ai/user/api-keys\">here</a>.</li>");
+    {
+        const QString modelName = "Mistral Tiny API";
+        const QString id = modelName;
+        const QString modelFilename = "gpt4all-mistral-tiny.rmodel";
+        if (contains(modelFilename))
+            changeId(modelFilename, id);
+        if (!contains(id))
+            addModel(id);
+        QVector<QPair<int, QVariant>> data {
+            { ModelList::NameRole, modelName },
+            { ModelList::FilenameRole, modelFilename },
+            { ModelList::FilesizeRole, "minimal" },
+            { ModelList::OnlineRole, true },
+            { ModelList::DescriptionRole,
+                tr("<strong>Mistral Tiny model</strong><br>") + mistralDesc },
+            { ModelList::RequiresVersionRole, "2.7.4" },
+            { ModelList::OrderRole, "cc" },
+            { ModelList::RamrequiredRole, 0 },
+            { ModelList::ParametersRole, "?" },
+            { ModelList::QuantRole, "NA" },
+            { ModelList::TypeRole, "Mistral" },
+            { ModelList::UrlRole, "https://api.mistral.ai/v1/chat/completions"},
+        };
+        updateData(id, data);
+    }
+    {
+        const QString modelName = "Mistral Small API";
+        const QString id = modelName;
+        const QString modelFilename = "gpt4all-mistral-small.rmodel";
+        if (contains(modelFilename))
+            changeId(modelFilename, id);
+        if (!contains(id))
+            addModel(id);
+        QVector<QPair<int, QVariant>> data {
+            { ModelList::NameRole, modelName },
+            { ModelList::FilenameRole, modelFilename },
+            { ModelList::FilesizeRole, "minimal" },
+            { ModelList::OnlineRole, true },
+            { ModelList::DescriptionRole,
+                tr("<strong>Mistral Small model</strong><br>") + mistralDesc },
+            { ModelList::RequiresVersionRole, "2.7.4" },
+            { ModelList::OrderRole, "cd" },
+            { ModelList::RamrequiredRole, 0 },
+            { ModelList::ParametersRole, "?" },
+            { ModelList::QuantRole, "NA" },
+            { ModelList::TypeRole, "Mistral" },
+            { ModelList::UrlRole, "https://api.mistral.ai/v1/chat/completions"},
+        };
+        updateData(id, data);
+    }
+    {
+        const QString modelName = "Mistral Medium API";
+        const QString id = modelName;
+        const QString modelFilename = "gpt4all-mistral-medium.rmodel";
+        if (contains(modelFilename))
+            changeId(modelFilename, id);
+        if (!contains(id))
+            addModel(id);
+        QVector<QPair<int, QVariant>> data {
+            { ModelList::NameRole, modelName },
+            { ModelList::FilenameRole, modelFilename },
+            { ModelList::FilesizeRole, "minimal" },
+            { ModelList::OnlineRole, true },
+            { ModelList::DescriptionRole,
+                tr("<strong>Mistral Medium model</strong><br>") + mistralDesc },
+            { ModelList::RequiresVersionRole, "2.7.4" },
+            { ModelList::OrderRole, "ce" },
+            { ModelList::RamrequiredRole, 0 },
+            { ModelList::ParametersRole, "?" },
+            { ModelList::QuantRole, "NA" },
+            { ModelList::TypeRole, "Mistral" },
+            { ModelList::UrlRole, "https://api.mistral.ai/v1/chat/completions"},
+        };
+        updateData(id, data);
+    }
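End to end, selecting one of these Mistral entries in the download view calls Download::installModel, which writes, for example, gpt4all-mistral-tiny.rmodel with modelName "mistral-tiny"; ChatLLM::loadModel then reads that file, constructs a ChatAPI instance, and sets its request URL from the entry's UrlRole (https://api.mistral.ai/v1/chat/completions), so the same worker code that previously only talked to OpenAI now serves both providers.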
     {
         const QString nomicEmbedDesc = tr("<ul><li>For use with LocalDocs feature</li>"