feat: Add support for Mistral API models (#2053)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Signed-off-by: Cédric Sazos <cedric.sazos@tutanota.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Author: Olyxz16
Date: 2024-03-13 23:23:57 +01:00
Committed by: GitHub
Parent: 406e88b59a
Commit: 2c0a660e6e
7 changed files with 242 additions and 98 deletions

chatllm.cpp

@@ -1,6 +1,6 @@
#include "chatllm.h"
#include "chat.h"
#include "chatgpt.h"
#include "chatapi.h"
#include "localdocs.h"
#include "modellist.h"
#include "network.h"
@@ -213,7 +213,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (isModelLoaded() && this->modelInfo() == modelInfo)
         return true;
-    bool isChatGPT = modelInfo.isOnline; // right now only chatgpt is offered for online chat models...
     QString filePath = modelInfo.dirpath + modelInfo.filename();
     QFileInfo fileInfo(filePath);
@@ -279,19 +278,23 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     m_llModelInfo.fileInfo = fileInfo;
     if (fileInfo.exists()) {
-        if (isChatGPT) {
+        if (modelInfo.isOnline) {
             QString apiKey;
-            QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
+            QString modelName;
             {
                 QFile file(filePath);
                 file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
                 QTextStream stream(&file);
-                apiKey = stream.readAll();
-                file.close();
+                QString text = stream.readAll();
+                QJsonDocument doc = QJsonDocument::fromJson(text.toUtf8());
+                QJsonObject obj = doc.object();
+                apiKey = obj["apiKey"].toString();
+                modelName = obj["modelName"].toString();
             }
-            m_llModelType = LLModelType::CHATGPT_;
-            ChatGPT *model = new ChatGPT();
-            model->setModelName(chatGPTModel);
+            m_llModelType = LLModelType::API_;
+            ChatAPI *model = new ChatAPI();
+            model->setModelName(modelName);
+            model->setRequestURL(modelInfo.url());
             model->setAPIKey(apiKey);
             m_llModelInfo.model = model;
         } else {
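
For context (not part of the diff): based on the keys the new code reads, obj["apiKey"] and obj["modelName"], the on-disk file for an online model presumably holds a small JSON object along these lines; the key value and model name below are illustrative placeholders only.

{
    "apiKey": "sk-...",
    "modelName": "mistral-small-latest"
}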
@@ -468,7 +471,7 @@ void ChatLLM::regenerateResponse()
 {
     // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
     // of n_past is of the number of prompt/response pairs, rather than for total tokens.
-    if (m_llModelType == LLModelType::CHATGPT_)
+    if (m_llModelType == LLModelType::API_)
         m_ctx.n_past -= 1;
     else
         m_ctx.n_past -= m_promptResponseTokens;
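
Aside (not in the diff): per the comment above, an API-backed model counts prompt/response pairs in n_past rather than tokens, so after three exchanges n_past is 3; regenerating the last response therefore rewinds by exactly one pair, hence the -= 1, whereas a local model rewinds by the token count of the last prompt and response.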
@@ -958,12 +961,12 @@ void ChatLLM::saveState()
     if (!isModelLoaded())
         return;
-    if (m_llModelType == LLModelType::CHATGPT_) {
+    if (m_llModelType == LLModelType::API_) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_llModelInfo.model);
-        stream << chatGPT->context();
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
+        stream << chatAPI->context();
         return;
     }
@@ -980,13 +983,13 @@ void ChatLLM::restoreState()
     if (!isModelLoaded())
         return;
-    if (m_llModelType == LLModelType::CHATGPT_) {
+    if (m_llModelType == LLModelType::API_) {
         QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
         QList<QString> context;
         stream >> context;
-        chatGPT->setContext(context);
+        chatAPI->setContext(context);
         m_state.clear();
         m_state.squeeze();
         return;
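
Editor's sketch (not part of the commit): a minimal standalone illustration of the QDataStream round trip that saveState()/restoreState() perform on the ChatAPI context. The strings and the standalone main() are hypothetical; ChatAPI itself is not involved.

#include <QByteArray>
#include <QDataStream>
#include <QIODevice>
#include <QList>
#include <QString>

int main()
{
    // Hypothetical context: one prompt/response pair stored as a QList<QString>.
    QList<QString> context = { "Hello", "Hi there!" };

    // Serialize into a byte array, mirroring saveState().
    QByteArray state;
    {
        QDataStream stream(&state, QIODeviceBase::WriteOnly);
        stream.setVersion(QDataStream::Qt_6_4);
        stream << context;
    }

    // Read it back, mirroring restoreState().
    QList<QString> restored;
    {
        QDataStream stream(&state, QIODeviceBase::ReadOnly);
        stream.setVersion(QDataStream::Qt_6_4);
        stream >> restored;
    }

    return restored == context ? 0 : 1;
}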