Jared Van Bortel 2025-03-03 11:16:36 -05:00
parent 1ba555a174
commit 1dc9f22d5b
25 changed files with 556 additions and 889 deletions

View File

@ -13,7 +13,10 @@ add_subdirectory(src)
target_sources(gpt4all-backend PUBLIC
FILE_SET public_headers TYPE HEADERS BASE_DIRS include FILES
include/gpt4all-backend/formatters.h
include/gpt4all-backend/generation-params.h
include/gpt4all-backend/json-helpers.h
include/gpt4all-backend/ollama-client.h
include/gpt4all-backend/ollama-model.h
include/gpt4all-backend/ollama-types.h
include/gpt4all-backend/rest.h
)

View File

@ -0,0 +1,22 @@
#pragma once
#include <QtTypes>
namespace gpt4all::backend {
struct GenerationParams {
uint n_predict;
float temperature;
float top_p;
// int32_t top_k = 40;
// float min_p = 0.0f;
// int32_t n_batch = 9;
// float repeat_penalty = 1.10f;
// int32_t repeat_last_n = 64; // last n tokens to penalize
// float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
} // namespace gpt4all::backend
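For context, a minimal sketch of filling these fields in by hand (not part of the commit; the numbers are placeholders, and in this changeset the real values come from MySettings via genParamsFromSettings() in chatllm.cpp):

#include <gpt4all-backend/generation-params.h>

static gpt4all::backend::GenerationParams exampleParams()
{
    // Placeholder values; the UI reads these from per-model settings.
    return {
        .n_predict   = 4096,
        .temperature = 0.7f,
        .top_p       = 0.4f,
    };
}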

View File

@ -22,6 +22,7 @@ class QRestReply;
namespace gpt4all::backend {
struct ResponseError {
public:
struct BadStatus { int code; };
@ -52,7 +53,7 @@ using DataOrRespErr = std::expected<T, ResponseError>;
class OllamaClient {
public:
OllamaClient(QUrl baseUrl, QString m_userAgent = QStringLiteral("GPT4All"))
OllamaClient(QUrl baseUrl, QString m_userAgent)
: m_baseUrl(baseUrl)
, m_userAgent(std::move(m_userAgent))
{}

View File

@ -0,0 +1,13 @@
#pragma once
namespace gpt4all::backend {
class OllamaClient;
class OllamaModel {
OllamaClient *client;
};
} // namespace gpt4all::backend

View File

@ -0,0 +1,13 @@
#pragma once
class QRestReply;
class QString;
namespace gpt4all::backend {
QString restErrorString(const QRestReply &reply);
} // namespace gpt4all::backend

View File

@ -5,6 +5,7 @@ add_library(${TARGET} STATIC
ollama-client.cpp
ollama-types.cpp
qt-json-stream.cpp
rest.cpp
)
target_compile_features(${TARGET} PUBLIC cxx_std_23)
gpt4all_add_warning_options(${TARGET})

View File

@ -2,6 +2,7 @@
#include "json-helpers.h" // IWYU pragma: keep
#include "qt-json-stream.h"
#include "rest.h"
#include <QCoro/QCoroIODevice> // IWYU pragma: keep
#include <QCoro/QCoroNetworkReply> // IWYU pragma: keep
@ -26,22 +27,14 @@ namespace gpt4all::backend {
ResponseError::ResponseError(const QRestReply *reply)
{
auto *nr = reply->networkReply();
if (reply->hasError()) {
error = nr->error();
errorString = nr->errorString();
error = reply->networkReply()->error();
} else if (!reply->isHttpStatusSuccess()) {
auto code = reply->httpStatus();
auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
error = BadStatus(code);
errorString = u"HTTP %1%2%3 for URL \"%4\""_s.arg(
QString::number(code),
reason.isValid() ? u" "_s : QString(),
reason.toString(),
nr->request().url().toString()
);
error = BadStatus(reply->httpStatus());
} else
Q_UNREACHABLE();
errorString = restErrorString(*reply);
}
QNetworkRequest OllamaClient::makeRequest(const QString &path) const

View File

@ -0,0 +1,34 @@
#include "rest.h"
#include <QNetworkReply>
#include <QRestReply>
#include <QString>
using namespace Qt::Literals::StringLiterals;
namespace gpt4all::backend {
QString restErrorString(const QRestReply &reply)
{
auto *nr = reply.networkReply();
if (reply.hasError())
return nr->errorString();
if (!reply.isHttpStatusSuccess()) {
auto code = reply.httpStatus();
auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
return u"HTTP %1%2%3 for URL \"%4\""_s.arg(
QString::number(code),
reason.isValid() ? u" "_s : QString(),
reason.toString(),
nr->request().url().toString()
);
}
Q_UNREACHABLE();
}
} // namespace gpt4all::backend
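The formatting logic above is lifted out of ResponseError so that both the Ollama client and the new OpenAI chat model can report HTTP and network failures the same way. A hypothetical caller, assuming a QNetworkReply that has already finished (the helper name is made up):

#include <gpt4all-backend/rest.h>

#include <QDebug>
#include <QNetworkReply>
#include <QRestReply>

// Hypothetical helper: log a failed reply using the shared formatter.
static void logIfFailed(QNetworkReply *reply)
{
    QRestReply restReply(reply);
    if (restReply.hasError() || !restReply.isHttpStatusSuccess())
        qWarning().noquote() << "request failed:" << gpt4all::backend::restErrorString(restReply);
}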

View File

@ -227,9 +227,10 @@ if (APPLE)
endif()
qt_add_executable(chat
src/llmodel/provider.cpp src/llmodel/provider.h
src/llmodel/openai.cpp src/llmodel/openai.h
src/main.cpp
src/chat.cpp src/chat.h
src/chatapi.cpp src/chatapi.h
src/chatlistmodel.cpp src/chatlistmodel.h
src/chatllm.cpp src/chatllm.h
src/chatmodel.h src/chatmodel.cpp
@ -456,8 +457,16 @@ else()
# Link PDFium
target_link_libraries(chat PRIVATE pdfium)
endif()
target_link_libraries(chat
PRIVATE gpt4all-backend llmodel nlohmann_json::nlohmann_json SingleApplication fmt::fmt duckx::duckx QXlsx)
target_link_libraries(chat PRIVATE
QCoro6::Core QCoro6::Network
QXlsx
SingleApplication
duckx::duckx
fmt::fmt
gpt4all-backend
llmodel
nlohmann_json::nlohmann_json
)
target_include_directories(chat PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/deps/minja/include)
if (APPLE)

View File

@ -412,21 +412,6 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
emit tokenSpeedChanged();
}
QString Chat::deviceBackend() const
{
return m_llmodel->deviceBackend();
}
QString Chat::device() const
{
return m_llmodel->device();
}
QString Chat::fallbackReason() const
{
return m_llmodel->fallbackReason();
}
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{
m_databaseResults = results;

View File

@ -39,9 +39,6 @@ class Chat : public QObject
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged)
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
// 0=no, 1=waiting, 2=working
Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
@ -122,10 +119,6 @@ public:
QString modelLoadingError() const { return m_modelLoadingError; }
QString tokenSpeed() const { return m_tokenSpeed; }
QString deviceBackend() const;
QString device() const;
// not loaded -> QString(), no fallback -> QString("")
QString fallbackReason() const;
int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
@ -159,8 +152,6 @@ Q_SIGNALS:
void isServerChanged();
void collectionListChanged(const QList<QString> &collectionList);
void tokenSpeedChanged();
void deviceChanged();
void fallbackReasonChanged();
void collectionModelChanged();
void trySwitchContextInProgressChanged();
void loadedModelInfoChanged();
@ -192,8 +183,6 @@ private:
ModelInfo m_modelInfo;
QString m_modelLoadingError;
QString m_tokenSpeed;
QString m_device;
QString m_fallbackReason;
QList<QString> m_collections;
QList<QString> m_generatedQuestions;
ChatModel *m_chatModel;

View File

@ -1,359 +0,0 @@
#include "chatapi.h"
#include "utils.h"
#include <fmt/format.h>
#include <gpt4all-backend/formatters.h>
#include <QAnyStringView>
#include <QCoreApplication>
#include <QDebug>
#include <QGuiApplication>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QLatin1String>
#include <QNetworkAccessManager>
#include <QNetworkRequest>
#include <QStringView>
#include <QThread>
#include <QUrl>
#include <QUtf8StringView> // IWYU pragma: keep
#include <QVariant>
#include <QXmlStreamReader>
#include <Qt>
#include <QtAssert>
#include <QtLogging>
#include <expected>
#include <functional>
#include <iostream>
#include <utility>
using namespace Qt::Literals::StringLiterals;
//#define DEBUG
ChatAPI::ChatAPI()
: QObject(nullptr)
, m_modelName("gpt-3.5-turbo")
, m_requestURL("")
, m_responseCallback(nullptr)
{
}
size_t ChatAPI::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
Q_UNUSED(ngl);
return 0;
}
bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
{
Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
Q_UNUSED(ngl);
return true;
}
void ChatAPI::setThreadCount(int32_t n_threads)
{
Q_UNUSED(n_threads);
}
int32_t ChatAPI::threadCount() const
{
return 1;
}
ChatAPI::~ChatAPI()
{
}
bool ChatAPI::isModelLoaded() const
{
return true;
}
static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
{
QJsonArray messages;
auto xmlError = [&xml] {
return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
};
if (xml.hasError())
return xmlError();
if (xml.atEnd())
return messages;
// skip header
bool foundElement = false;
do {
switch (xml.readNext()) {
using enum QXmlStreamReader::TokenType;
case Invalid:
return xmlError();
case EndDocument:
return messages;
default:
foundElement = true;
case StartDocument:
case Comment:
case DTD:
case ProcessingInstruction:
;
}
} while (!foundElement);
// document body loop
bool foundRoot = false;
for (;;) {
switch (xml.tokenType()) {
using enum QXmlStreamReader::TokenType;
case StartElement:
{
auto name = xml.name();
if (!foundRoot) {
if (name != "chat"_L1)
return std::unexpected(u"unexpected tag: %1"_s.arg(name));
foundRoot = true;
} else {
if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
return std::unexpected(u"unknown role: %1"_s.arg(name));
auto content = xml.readElementText();
if (xml.tokenType() != EndElement)
return xmlError();
messages << makeJsonObject({
{ "role"_L1, name.toString().trimmed() },
{ "content"_L1, content },
});
}
break;
}
case Characters:
if (!xml.isWhitespace())
return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
case Comment:
case ProcessingInstruction:
case EndElement:
break;
case EndDocument:
return messages;
case Invalid:
return xmlError();
default:
return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
}
xml.readNext();
}
}
void ChatAPI::prompt(
std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &promptCtx
) {
Q_UNUSED(promptCallback)
if (!isModelLoaded())
throw std::invalid_argument("Attempted to prompt an unloaded model.");
if (!promptCtx.n_predict)
return; // nothing requested
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
// the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
// the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
// using the REST API to count tokens in a prompt.
auto root = makeJsonObject({
{ "model"_L1, m_modelName },
{ "stream"_L1, true },
{ "temperature"_L1, promptCtx.temp },
{ "top_p"_L1, promptCtx.top_p },
});
// conversation history
{
QUtf8StringView promptUtf8(prompt);
QXmlStreamReader xml(promptUtf8);
auto messages = parsePrompt(xml);
if (!messages) {
auto error = fmt::format("Failed to parse API model prompt: {}", messages.error());
qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n';
throw std::invalid_argument(error);
}
root.insert("messages"_L1, *messages);
}
QJsonDocument doc(root);
#if defined(DEBUG)
qDebug().noquote() << "ChatAPI::prompt begin network request" << doc.toJson();
#endif
m_responseCallback = responseCallback;
// The following code sets up a worker thread and object to perform the actual api request to
// chatgpt and then blocks until it is finished
QThread workerThread;
ChatAPIWorker worker(this);
worker.moveToThread(&workerThread);
connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
workerThread.start();
emit request(m_apiKey, doc.toJson(QJsonDocument::Compact));
workerThread.wait();
m_responseCallback = nullptr;
#if defined(DEBUG)
qDebug() << "ChatAPI::prompt end network request";
#endif
}
bool ChatAPI::callResponse(int32_t token, const std::string& string)
{
Q_ASSERT(m_responseCallback);
if (!m_responseCallback) {
std::cerr << "ChatAPI ERROR: no response callback!\n";
return false;
}
return m_responseCallback(token, string);
}
void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array)
{
QUrl apiUrl(m_chat->url());
const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed();
QNetworkRequest request(apiUrl);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setRawHeader("Authorization", authorization.toUtf8());
#if defined(DEBUG)
qDebug() << "ChatAPI::request"
<< "API URL: " << apiUrl.toString()
<< "Authorization: " << authorization.toUtf8();
#endif
m_networkManager = new QNetworkAccessManager(this);
QNetworkReply *reply = m_networkManager->post(request, array);
connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
connect(reply, &QNetworkReply::finished, this, &ChatAPIWorker::handleFinished);
connect(reply, &QNetworkReply::readyRead, this, &ChatAPIWorker::handleReadyRead);
connect(reply, &QNetworkReply::errorOccurred, this, &ChatAPIWorker::handleErrorOccurred);
}
void ChatAPIWorker::handleFinished()
{
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply) {
emit finished();
return;
}
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
if (!response.isValid()) {
m_chat->callResponse(
-1,
tr("ERROR: Network error occurred while connecting to the API server")
.toStdString()
);
return;
}
bool ok;
int code = response.toInt(&ok);
if (!ok || code != 200) {
bool isReplyEmpty(reply->readAll().isEmpty());
if (isReplyEmpty)
m_chat->callResponse(
-1,
tr("ChatAPIWorker::handleFinished got HTTP Error %1 %2")
.arg(code)
.arg(reply->errorString())
.toStdString()
);
qWarning().noquote() << "ERROR: ChatAPIWorker::handleFinished got HTTP Error" << code << "response:"
<< reply->errorString();
}
reply->deleteLater();
emit finished();
}
void ChatAPIWorker::handleReadyRead()
{
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply) {
emit finished();
return;
}
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
if (!response.isValid())
return;
bool ok;
int code = response.toInt(&ok);
if (!ok || code != 200) {
m_chat->callResponse(
-1,
u"ERROR: ChatAPIWorker::handleReadyRead got HTTP Error %1 %2: %3"_s
.arg(code).arg(reply->errorString(), reply->readAll()).toStdString()
);
emit finished();
return;
}
while (reply->canReadLine()) {
QString jsonData = reply->readLine().trimmed();
if (jsonData.startsWith("data:"))
jsonData.remove(0, 5);
jsonData = jsonData.trimmed();
if (jsonData.isEmpty())
continue;
if (jsonData == "[DONE]")
continue;
#if defined(DEBUG)
qDebug().noquote() << "line" << jsonData;
#endif
QJsonParseError err;
const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
if (err.error != QJsonParseError::NoError) {
m_chat->callResponse(-1, u"ERROR: ChatAPI responded with invalid json \"%1\""_s
.arg(err.errorString()).toStdString());
continue;
}
const QJsonObject root = document.object();
const QJsonArray choices = root.value("choices").toArray();
const QJsonObject choice = choices.first().toObject();
const QJsonObject delta = choice.value("delta").toObject();
const QString content = delta.value("content").toString();
m_currentResponse += content;
if (!m_chat->callResponse(0, content.toStdString())) {
reply->abort();
emit finished();
return;
}
}
}
void ChatAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
{
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply || reply->error() == QNetworkReply::OperationCanceledError /*when we call abort on purpose*/) {
emit finished();
return;
}
qWarning().noquote() << "ERROR: ChatAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:"
<< reply->errorString();
emit finished();
}

View File

@ -1,173 +0,0 @@
#ifndef CHATAPI_H
#define CHATAPI_H
#include <gpt4all-backend/llmodel.h>
#include <QByteArray>
#include <QNetworkReply>
#include <QObject>
#include <QString>
#include <QtPreprocessorSupport>
#include <cstddef>
#include <cstdint>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
// IWYU pragma: no_forward_declare QByteArray
class ChatAPI;
class QNetworkAccessManager;
class ChatAPIWorker : public QObject {
Q_OBJECT
public:
ChatAPIWorker(ChatAPI *chatAPI)
: QObject(nullptr)
, m_networkManager(nullptr)
, m_chat(chatAPI) {}
virtual ~ChatAPIWorker() {}
QString currentResponse() const { return m_currentResponse; }
void request(const QString &apiKey, const QByteArray &array);
Q_SIGNALS:
void finished();
private Q_SLOTS:
void handleFinished();
void handleReadyRead();
void handleErrorOccurred(QNetworkReply::NetworkError code);
private:
ChatAPI *m_chat;
QNetworkAccessManager *m_networkManager;
QString m_currentResponse;
};
class ChatAPI : public QObject, public LLModel {
Q_OBJECT
public:
ChatAPI();
virtual ~ChatAPI();
bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t stateSize() const override
{ throwNotImplemented(); }
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
{ Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
void prompt(std::string_view prompt,
const PromptCallback &promptCallback,
const ResponseCallback &responseCallback,
const PromptContext &ctx) override;
[[noreturn]]
int32_t countPromptTokens(std::string_view prompt) const override
{ Q_UNUSED(prompt); throwNotImplemented(); }
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
void setModelName(const QString &modelName) { m_modelName = modelName; }
void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
QString url() const { return m_requestURL; }
bool callResponse(int32_t token, const std::string &string);
[[noreturn]]
int32_t contextLength() const override
{ throwNotImplemented(); }
auto specialTokens() -> std::unordered_map<std::string, std::string> const override
{ return {}; }
Q_SIGNALS:
void request(const QString &apiKey, const QByteArray &array);
protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
[[noreturn]]
static void throwNotImplemented() { throw std::logic_error("not implemented"); }
[[noreturn]]
std::vector<Token> tokenize(std::string_view str) const override
{ Q_UNUSED(str); throwNotImplemented(); }
[[noreturn]]
bool isSpecialToken(Token id) const override
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
std::string tokenToString(Token id) const override
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
void initSampler(const PromptContext &ctx) override
{ Q_UNUSED(ctx); throwNotImplemented(); }
[[noreturn]]
Token sampleToken() const override
{ throwNotImplemented(); }
[[noreturn]]
bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override
{ Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); }
[[noreturn]]
void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override
{ Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); }
[[noreturn]]
int32_t inputLength() const override
{ throwNotImplemented(); }
[[noreturn]]
int32_t computeModelInputPosition(std::span<const Token> input) const override
{ Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(int32_t pos) override
{ Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
void appendInputToken(Token tok) override
{ Q_UNUSED(tok); throwNotImplemented(); }
[[noreturn]]
const std::vector<Token> &endTokens() const override
{ throwNotImplemented(); }
[[noreturn]]
bool shouldAddBOS() const override
{ throwNotImplemented(); }
[[noreturn]]
std::span<const Token> inputTokens() const override
{ throwNotImplemented(); }
private:
ResponseCallback m_responseCallback;
QString m_modelName;
QString m_apiKey;
QString m_requestURL;
};
#endif // CHATAPI_H

View File

@ -1,17 +1,19 @@
#include "chatllm.h"
#include "chat.h"
#include "chatapi.h"
#include "chatmodel.h"
#include "jinja_helpers.h"
#include "llmodel/chat.h"
#include "llmodel/openai.h"
#include "localdocs.h"
#include "mysettings.h"
#include "network.h"
#include "tool.h"
#include "toolmodel.h"
#include "toolcallparser.h"
#include "toolmodel.h"
#include <fmt/format.h>
#include <gpt4all-backend/generation-params.h>
#include <minja/minja.hpp>
#include <nlohmann/json.hpp>
@ -64,6 +66,8 @@
using namespace Qt::Literals::StringLiterals;
using namespace ToolEnums;
using namespace gpt4all;
using namespace gpt4all::ui;
namespace ranges = std::ranges;
using json = nlohmann::ordered_json;
@ -115,14 +119,12 @@ public:
};
static auto promptModelWithTools(
LLModel *model, const LLModel::PromptCallback &promptCallback, BaseResponseHandler &respHandler,
const LLModel::PromptContext &ctx, const QByteArray &prompt, const QStringList &toolNames
ChatLLModel *model, BaseResponseHandler &respHandler, const backend::GenerationParams &params,
const QByteArray &prompt, const QStringList &toolNames
) -> std::pair<QStringList, bool>
{
ToolCallParser toolCallParser(toolNames);
auto handleResponse = [&toolCallParser, &respHandler](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
auto handleResponse = [&toolCallParser, &respHandler](std::string_view piece) -> bool {
toolCallParser.update(piece.data());
// Split the response into two if needed
@ -157,7 +159,7 @@ static auto promptModelWithTools(
return !shouldExecuteToolCall && !respHandler.getStopGenerating();
};
model->prompt(std::string_view(prompt), promptCallback, handleResponse, ctx);
model->prompt(std::string_view(prompt), promptCallback, handleResponse, params);
const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
&& toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;
@ -217,9 +219,8 @@ void LLModelStore::destroy()
m_availableModel.reset();
}
void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) {
void LLModelInfo::resetModel(ChatLLM *cllm, ChatLLModel *model) {
this->model.reset(model);
fallbackReason.reset();
emit cllm->loadedModelInfoChanged();
}
@ -232,8 +233,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_stopGenerating(false)
, m_timer(nullptr)
, m_isServer(isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
, m_chatModel(parent->chatModel())
{
moveToThread(&m_llmThread);
@ -243,8 +242,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@ -284,31 +281,6 @@ void ChatLLM::handleThreadStarted()
emit threadStarted();
}
void ChatLLM::handleForceMetalChanged(bool forceMetal)
{
#if defined(Q_OS_MAC) && defined(__aarch64__)
m_forceMetal = forceMetal;
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
m_reloadingToChangeVariant = false;
}
#else
Q_UNUSED(forceMetal);
#endif
}
void ChatLLM::handleDeviceChanged()
{
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
m_reloadingToChangeVariant = false;
}
}
bool ChatLLM::loadDefaultModel()
{
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
@ -325,10 +297,9 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
// and if so we just acquire it from the store and switch the context and return true. If the
// store doesn't have it or we're already loaded or in any other case just return false.
// If we're already loaded or a server or we're reloading to change the variant/device or the
// modelInfo is empty, then this should fail
// If we're already loaded or a server or the modelInfo is empty, then this should fail
if (
isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
isModelLoaded() || m_isServer || modelInfo.name().isEmpty() || !m_shouldBeLoaded
) {
emit trySwitchContextOfLoadedModelCompleted(0);
return;
@ -409,7 +380,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
}
// Check if the store just gave us exactly the model we were looking for
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
@ -482,7 +453,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);
emit loadedModelInfoChanged();
modelLoadProps.insert("requestedDevice", MySettings::globalInstance()->device());
modelLoadProps.insert("model", modelInfo.filename());
Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
} else {
@ -504,43 +474,17 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
QElapsedTimer modelLoadTimer;
modelLoadTimer.start();
QString requestedDevice = MySettings::globalInstance()->device();
int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
std::string backend = "auto";
#ifdef Q_OS_MAC
if (requestedDevice == "CPU") {
backend = "cpu";
} else if (m_forceMetal) {
#ifdef __aarch64__
backend = "metal";
#endif
}
#else // !defined(Q_OS_MAC)
if (requestedDevice.startsWith("CUDA: "))
backend = "cuda";
#endif
QString filePath = modelInfo.dirpath + modelInfo.filename();
auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx](std::string const &backend) {
auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx]() {
QString constructError;
m_llModelInfo.resetModel(this);
try {
auto *model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx);
auto *model = LLModel::Implementation::construct(filePath.toStdString(), "", n_ctx);
m_llModelInfo.resetModel(this, model);
} catch (const LLModel::MissingImplementationError &e) {
modelLoadProps.insert("error", "missing_model_impl");
constructError = e.what();
} catch (const LLModel::UnsupportedModelError &e) {
modelLoadProps.insert("error", "unsupported_model_file");
constructError = e.what();
} catch (const LLModel::BadArchError &e) {
constructError = e.what();
modelLoadProps.insert("error", "unsupported_model_arch");
modelLoadProps.insert("model_arch", QString::fromStdString(e.arch()));
}
if (!m_llModelInfo.model) {
if (!m_isServer)
@ -558,7 +502,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
return true;
};
if (!construct(backend))
if (!construct())
return true;
if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
@ -572,58 +516,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
}
}
auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
float memGB = dev->heapSize / float(1024 * 1024 * 1024);
return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
};
std::vector<LLModel::GPUDevice> availableDevices;
const LLModel::GPUDevice *defaultDevice = nullptr;
{
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx, ngl);
availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
// Pick the best device
// NB: relies on the fact that Kompute devices are listed first
if (!availableDevices.empty() && availableDevices.front().type == 2 /*a discrete gpu*/) {
defaultDevice = &availableDevices.front();
float memGB = defaultDevice->heapSize / float(1024 * 1024 * 1024);
memGB = std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
modelLoadProps.insert("default_device", QString::fromStdString(defaultDevice->name));
modelLoadProps.insert("default_device_mem", approxDeviceMemGB(defaultDevice));
modelLoadProps.insert("default_device_backend", QString::fromStdString(defaultDevice->backendName()));
}
}
bool actualDeviceIsCPU = true;
#if defined(Q_OS_MAC) && defined(__aarch64__)
if (m_llModelInfo.model->implementation().buildVariant() == "metal")
actualDeviceIsCPU = false;
#else
if (requestedDevice != "CPU") {
const auto *device = defaultDevice;
if (requestedDevice != "Auto") {
// Use the selected device
for (const LLModel::GPUDevice &d : availableDevices) {
if (QString::fromStdString(d.selectionName()) == requestedDevice) {
device = &d;
break;
}
}
}
std::string unavail_reason;
if (!device) {
// GPU not available
} else if (!m_llModelInfo.model->initializeGPUDevice(device->index, &unavail_reason)) {
m_llModelInfo.fallbackReason = QString::fromStdString(unavail_reason);
} else {
actualDeviceIsCPU = false;
modelLoadProps.insert("requested_device_mem", approxDeviceMemGB(device));
}
}
#endif
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
if (!m_shouldBeLoaded) {
@ -635,35 +527,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
return false;
}
if (actualDeviceIsCPU) {
// we asked llama.cpp to use the CPU
} else if (!success) {
// llama_init_from_file returned nullptr
m_llModelInfo.fallbackReason = "GPU loading failed (out of VRAM?)";
modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
// For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls
if (backend == "cuda" && !construct("auto"))
return true;
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
if (!m_shouldBeLoaded) {
m_llModelInfo.resetModel(this);
if (!m_isServer)
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
resetModel();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate
// for instance if the quantization method is not supported on Vulkan yet
m_llModelInfo.fallbackReason = "model or quant has no GPU support";
modelLoadProps.insert("cpu_fallback_reason", "gpu_unsupported_model");
}
if (!success) {
m_llModelInfo.resetModel(this);
if (!m_isServer)
@ -756,7 +619,7 @@ void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
}
}
static LLModel::PromptContext promptContextFromSettings(const ModelInfo &modelInfo)
static backend::GenerationParams genParamsFromSettings(const ModelInfo &modelInfo)
{
auto *mySettings = MySettings::globalInstance();
return {
@ -779,7 +642,7 @@ void ChatLLM::prompt(const QStringList &enabledCollections)
}
try {
promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo));
promptInternalChat(enabledCollections, genParamsFromSettings(m_modelInfo));
} catch (const std::exception &e) {
// FIXME(jared): this is neither translated nor serialized
m_chatModel->setResponseValue(u"Error: %1"_s.arg(QString::fromUtf8(e.what())));
@ -906,7 +769,7 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) cons
Q_UNREACHABLE();
}
auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const backend::GenerationParams &params,
qsizetype startOffset) -> ChatPromptResult
{
Q_ASSERT(isModelLoaded());
@ -944,7 +807,7 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL
auto messageItems = getChat();
messageItems.pop_back(); // exclude new response
auto result = promptInternal(messageItems, ctx, !databaseResults.isEmpty());
auto result = promptInternal(messageItems, params, !databaseResults.isEmpty());
return {
/*PromptResult*/ {
.response = std::move(result.response),
@ -1014,7 +877,7 @@ private:
auto ChatLLM::promptInternal(
const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
const LLModel::PromptContext &ctx,
const backend::GenerationParams params,
bool usedLocalDocs
) -> PromptResult
{
@ -1052,13 +915,6 @@ auto ChatLLM::promptInternal(
PromptResult result {};
auto handlePrompt = [this, &result](std::span<const LLModel::Token> batch, bool cached) -> bool {
Q_UNUSED(cached)
result.promptTokens += batch.size();
m_timer->start();
return !m_stopGenerating;
};
QElapsedTimer totalTime;
totalTime.start();
ChatViewResponseHandler respHandler(this, &totalTime, &result);
@ -1070,8 +926,10 @@ auto ChatLLM::promptInternal(
emit promptProcessing();
m_llModelInfo.model->setThreadCount(mySettings->threadCount());
m_stopGenerating = false;
// TODO: set result.promptTokens based on ollama prompt_eval_count
// TODO: support interruption via m_stopGenerating
std::tie(finalBuffers, shouldExecuteTool) = promptModelWithTools(
m_llModelInfo.model.get(), handlePrompt, respHandler, ctx,
m_llModelInfo.model.get(), handlePrompt, respHandler, params,
QByteArray::fromRawData(conversation.data(), conversation.size()),
ToolCallConstants::AllTagNames
);
@ -1251,10 +1109,10 @@ void ChatLLM::generateName()
NameResponseHandler respHandler(this);
try {
// TODO: support interruption via m_stopGenerating
promptModelWithTools(
m_llModelInfo.model.get(),
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
respHandler, promptContextFromSettings(m_modelInfo),
respHandler, genParamsFromSettings(m_modelInfo),
applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(),
{ ToolCallConstants::ThinkTagName }
);
@ -1327,10 +1185,10 @@ void ChatLLM::generateQuestions(qint64 elapsed)
QElapsedTimer totalTime;
totalTime.start();
try {
// TODO: support interruption via m_stopGenerating
promptModelWithTools(
m_llModelInfo.model.get(),
/*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
respHandler, promptContextFromSettings(m_modelInfo),
respHandler, genParamsFromSettings(m_modelInfo),
applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
{ ToolCallConstants::ThinkTagName }
);

View File

@ -3,10 +3,9 @@
#include "chatmodel.h"
#include "database.h"
#include "llmodel/chat.h"
#include "modellist.h"
#include <gpt4all-backend/llmodel.h>
#include <QByteArray>
#include <QElapsedTimer>
#include <QFileInfo>
@ -91,14 +90,9 @@ inline LLModelTypeV1 parseLLModelTypeV0(int v0)
}
struct LLModelInfo {
std::unique_ptr<LLModel> model;
std::unique_ptr<gpt4all::ui::ChatLLModel> model;
QFileInfo fileInfo;
std::optional<QString> fallbackReason;
// NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
// must be able to serialize the information even if it is in the unloaded state
void resetModel(ChatLLM *cllm, LLModel *model = nullptr);
void resetModel(ChatLLM *cllm, gpt4all::ui::ChatLLModel *model = nullptr);
};
class TokenTimer : public QObject {
@ -145,9 +139,6 @@ class Chat;
class ChatLLM : public QObject
{
Q_OBJECT
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
public:
ChatLLM(Chat *parent, bool isServer = false);
virtual ~ChatLLM();
@ -175,27 +166,6 @@ public:
void acquireModel();
void resetModel();
QString deviceBackend() const
{
if (!isModelLoaded()) return QString();
std::string name = LLModel::GPUDevice::backendIdToName(m_llModelInfo.model->backendName());
return QString::fromStdString(name);
}
QString device() const
{
if (!isModelLoaded()) return QString();
const char *name = m_llModelInfo.model->gpuDeviceName();
return name ? QString(name) : u"CPU"_s;
}
// not loaded -> QString(), no fallback -> QString("")
QString fallbackReason() const
{
if (!isModelLoaded()) return QString();
return m_llModelInfo.fallbackReason.value_or(u""_s);
}
bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
@ -211,8 +181,6 @@ public Q_SLOTS:
void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged();
void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
Q_SIGNALS:
void loadedModelInfoChanged();
@ -233,8 +201,6 @@ Q_SIGNALS:
void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
void reportFallbackReason(const QString &fallbackReason);
void databaseResultsChanged(const QList<ResultInfo>&);
void modelInfoChanged(const ModelInfo &modelInfo);
@ -249,12 +215,11 @@ protected:
QList<ResultInfo> databaseResults;
};
ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
qsizetype startOffset = 0);
auto promptInternalChat(const QStringList &enabledCollections, const gpt4all::backend::GenerationParams &params,
qsizetype startOffset = 0) -> ChatPromptResult;
// passing a string_view directly skips templating and uses the raw string
PromptResult promptInternal(const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
const LLModel::PromptContext &ctx,
bool usedLocalDocs);
auto promptInternal(const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
const gpt4all::backend::GenerationParams &params, bool usedLocalDocs) -> PromptResult;
private:
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
@ -282,8 +247,6 @@ private:
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
bool m_isServer;
bool m_forceMetal;
bool m_reloadingToChangeVariant;
friend class ChatViewResponseHandler;
friend class SimpleResponseHandler;
};

View File

@ -88,73 +88,14 @@ bool EmbeddingLLMWorker::loadModel()
return false;
}
QString requestedDevice = MySettings::globalInstance()->localDocsEmbedDevice();
std::string backend = "auto";
#ifdef Q_OS_MAC
if (requestedDevice == "Auto" || requestedDevice == "CPU")
backend = "cpu";
#else
if (requestedDevice.startsWith("CUDA: "))
backend = "cuda";
#endif
try {
m_model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx);
m_model = LLModel::Implementation::construct(filePath.toStdString(), "", n_ctx);
} catch (const std::exception &e) {
qWarning() << "embllm WARNING: Could not load embedding model:" << e.what();
return false;
}
bool actualDeviceIsCPU = true;
#if defined(Q_OS_MAC) && defined(__aarch64__)
if (m_model->implementation().buildVariant() == "metal")
actualDeviceIsCPU = false;
#else
if (requestedDevice != "CPU") {
const LLModel::GPUDevice *device = nullptr;
std::vector<LLModel::GPUDevice> availableDevices = m_model->availableGPUDevices(0);
if (requestedDevice != "Auto") {
// Use the selected device
for (const LLModel::GPUDevice &d : availableDevices) {
if (QString::fromStdString(d.selectionName()) == requestedDevice) {
device = &d;
break;
}
}
}
std::string unavail_reason;
if (!device) {
// GPU not available
} else if (!m_model->initializeGPUDevice(device->index, &unavail_reason)) {
qWarning().noquote() << "embllm WARNING: Did not use GPU:" << QString::fromStdString(unavail_reason);
} else {
actualDeviceIsCPU = false;
}
}
#endif
bool success = m_model->loadModel(filePath.toStdString(), n_ctx, 100);
// CPU fallback
if (!actualDeviceIsCPU && !success) {
// llama_init_from_file returned nullptr
qWarning() << "embllm WARNING: Did not use GPU: GPU loading failed (out of VRAM?)";
if (backend == "cuda") {
// For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls
try {
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto", n_ctx);
} catch (const std::exception &e) {
qWarning() << "embllm WARNING: Could not load embedding model:" << e.what();
return false;
}
}
success = m_model->loadModel(filePath.toStdString(), n_ctx, 0);
}
if (!success) {
qWarning() << "embllm WARNING: Could not load embedding model";
delete m_model;

View File

@ -0,0 +1,32 @@
#pragma once
#include <QStringView>
class QString;
namespace QCoro { template <typename T> class AsyncGenerator; }
namespace gpt4all::backend { struct GenerationParams; }
namespace gpt4all::ui {
struct ChatResponseMetadata {
int nPromptTokens;
int nResponseTokens;
};
// TODO: implement two of these; one based on Ollama (TBD) and the other based on OpenAI (chatapi.h)
class ChatLLModel {
public:
virtual ~ChatLLModel() = 0;
[[nodiscard]]
virtual QString name() = 0;
virtual void preload() = 0;
virtual auto chat(QStringView prompt, const backend::GenerationParams &params,
/*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator<QString> = 0;
};
} // namespace gpt4all::ui
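A minimal consumer sketch for this interface (the helper and variable names are hypothetical; the prompt string uses the XML conversation format expected by the OpenAI implementation below, and iteration follows QCoro's documented AsyncGenerator pattern):

#include "llmodel/chat.h"

#include <gpt4all-backend/generation-params.h>

#include <QCoro/QCoroAsyncGenerator>
#include <QCoro/QCoroTask>
#include <QString>

QCoro::Task<QString> collectResponse(gpt4all::ui::ChatLLModel *model,
                                     const gpt4all::backend::GenerationParams &params)
{
    gpt4all::ui::ChatResponseMetadata metadata {};
    QString response;
    auto chunks = model->chat(u"<chat><user>Hello!</user></chat>", params, metadata);
    auto it = co_await chunks.begin();   // start the generator
    while (it != chunks.end()) {
        response += *it;                 // consume one streamed piece
        co_await ++it;                   // resume until the next piece arrives
    }
    co_return response;
}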

View File

@ -0,0 +1,217 @@
#include "openai.h"
#include "mysettings.h"
#include "utils.h"
#include <QCoro/QCoroAsyncGenerator> // IWYU pragma: keep
#include <QCoro/QCoroNetworkReply> // IWYU pragma: keep
#include <fmt/format.h>
#include <gpt4all-backend/formatters.h>
#include <gpt4all-backend/generation-params.h>
#include <gpt4all-backend/rest.h>
#include <QByteArray>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QLatin1String>
#include <QNetworkAccessManager>
#include <QNetworkRequest>
#include <QRestAccessManager>
#include <QRestReply>
#include <QStringView>
#include <QUrl>
#include <QUtf8StringView> // IWYU pragma: keep
#include <QVariant>
#include <QXmlStreamReader>
#include <Qt>
#include <expected>
#include <optional>
#include <utility>
using namespace Qt::Literals::StringLiterals;
//#define DEBUG
static auto processRespLine(const QByteArray &line) -> std::optional<QString>
{
auto jsonData = line.trimmed();
if (jsonData.startsWith("data:"_ba))
jsonData.remove(0, 5);
jsonData = jsonData.trimmed();
if (jsonData.isEmpty())
return std::nullopt;
if (jsonData == "[DONE]")
return std::nullopt;
QJsonParseError err;
auto document = QJsonDocument::fromJson(jsonData, &err);
if (document.isNull())
throw std::runtime_error(fmt::format("OpenAI chat response parsing failed: {}", err.errorString()));
auto root = document.object();
auto choices = root.value("choices").toArray();
auto choice = choices.first().toObject();
auto delta = choice.value("delta").toObject();
return delta.value("content").toString();
}
namespace gpt4all::ui {
void OpenaiModelDescription::setDisplayName(QString value)
{
if (m_displayName != value) {
m_displayName = std::move(value);
emit displayNameChanged(m_displayName);
}
}
void OpenaiModelDescription::setModelName(QString value)
{
if (m_modelName != value) {
m_modelName = std::move(value);
emit modelNameChanged(m_modelName);
}
}
OpenaiLLModel::OpenaiLLModel(OpenaiConnectionDetails connDetails, QNetworkAccessManager *nam)
: m_connDetails(std::move(connDetails))
, m_nam(nam)
{}
static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
{
QJsonArray messages;
auto xmlError = [&xml] {
return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
};
if (xml.hasError())
return xmlError();
if (xml.atEnd())
return messages;
// skip header
bool foundElement = false;
do {
switch (xml.readNext()) {
using enum QXmlStreamReader::TokenType;
case Invalid:
return xmlError();
case EndDocument:
return messages;
default:
foundElement = true;
case StartDocument:
case Comment:
case DTD:
case ProcessingInstruction:
;
}
} while (!foundElement);
// document body loop
bool foundRoot = false;
for (;;) {
switch (xml.tokenType()) {
using enum QXmlStreamReader::TokenType;
case StartElement:
{
auto name = xml.name();
if (!foundRoot) {
if (name != "chat"_L1)
return std::unexpected(u"unexpected tag: %1"_s.arg(name));
foundRoot = true;
} else {
if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
return std::unexpected(u"unknown role: %1"_s.arg(name));
auto content = xml.readElementText();
if (xml.tokenType() != EndElement)
return xmlError();
messages << makeJsonObject({
{ "role"_L1, name.toString().trimmed() },
{ "content"_L1, content },
});
}
break;
}
case Characters:
if (!xml.isWhitespace())
return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
case Comment:
case ProcessingInstruction:
case EndElement:
break;
case EndDocument:
return messages;
case Invalid:
return xmlError();
default:
return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
}
xml.readNext();
}
}
auto OpenaiLLModel::chat(QStringView prompt, const backend::GenerationParams &params,
/*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator<QString>
{
auto *mySettings = MySettings::globalInstance();
if (!params.n_predict)
co_return; // nothing requested
auto reqBody = makeJsonObject({
{ "model"_L1, m_connDetails.modelName },
{ "max_completion_tokens"_L1, qint64(params.n_predict) },
{ "stream"_L1, true },
{ "temperature"_L1, params.temperature },
{ "top_p"_L1, params.top_p },
});
// conversation history
{
QXmlStreamReader xml(prompt);
auto messages = parsePrompt(xml);
if (!messages)
throw std::invalid_argument(fmt::format("Failed to parse OpenAI prompt: {}", messages.error()));
reqBody.insert("messages"_L1, *messages);
}
QNetworkRequest request(m_connDetails.baseUrl.resolved(QUrl("/v1/chat/completions")));
request.setHeader(QNetworkRequest::UserAgentHeader, mySettings->userAgent());
request.setRawHeader("authorization", u"Bearer %1"_s.arg(m_connDetails.apiKey).toUtf8());
QRestAccessManager restNam(m_nam);
std::unique_ptr<QNetworkReply> reply(restNam.post(request, QJsonDocument(reqBody)));
auto makeError = [](const QRestReply &reply) {
return std::runtime_error(fmt::format("OpenAI chat request failed: {}", backend::restErrorString(reply)));
};
QRestReply restReply(reply.get());
if (reply->error())
throw makeError(restReply);
auto coroReply = qCoro(reply.get());
for (;;) {
auto line = co_await coroReply.readLine();
if (!restReply.isSuccess())
throw makeError(restReply);
if (line.isEmpty()) {
Q_ASSERT(reply->atEnd());
break;
}
if (auto chunk = processRespLine(line))
co_yield *chunk;
}
}
} // namespace gpt4all::ui
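For reference, parsePrompt() above accepts the conversation serialized in the same XML form the deleted chatapi.cpp used: a <chat> root containing system/user/assistant elements, with nothing but whitespace between them. An illustrative prompt (content made up):

// Illustrative conversation in the format parsePrompt() accepts.
static const char16_t kExamplePrompt[] = uR"xml(
<chat>
  <system>You are a helpful assistant.</system>
  <user>What is the capital of France?</user>
</chat>
)xml";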

View File

@ -0,0 +1,75 @@
#pragma once
#include "chat.h"
#include "provider.h"
#include <QObject>
#include <QQmlEngine>
#include <QString>
#include <QUrl>
class QNetworkAccessManager;
namespace gpt4all::ui {
class OpenaiModelDescription : public QObject {
Q_OBJECT
QML_ELEMENT
public:
explicit OpenaiModelDescription(OpenaiProvider *provider, QString displayName, QString modelName)
: QObject(provider)
, m_provider(provider)
, m_displayName(std::move(displayName))
, m_modelName(std::move(modelName))
{}
// getters
[[nodiscard]] OpenaiProvider *provider () const { return m_provider; }
[[nodiscard]] const QString &displayName() const { return m_displayName; }
[[nodiscard]] const QString &modelName () const { return m_modelName; }
// setters
void setDisplayName(QString value);
void setModelName (QString value);
Q_SIGNALS:
void displayNameChanged(const QString &value);
void modelNameChanged (const QString &value);
private:
OpenaiProvider *m_provider;
QString m_displayName;
QString m_modelName;
};
struct OpenaiConnectionDetails {
QUrl baseUrl;
QString modelName;
QString apiKey;
OpenaiConnectionDetails(const OpenaiModelDescription *desc)
: baseUrl(desc->provider()->baseUrl())
, apiKey(desc->provider()->apiKey())
, modelName(desc->modelName())
{}
};
class OpenaiLLModel : public ChatLLModel {
public:
explicit OpenaiLLModel(OpenaiConnectionDetails connDetails, QNetworkAccessManager *nam);
void preload() override { /* not supported -> no-op */ }
auto chat(QStringView prompt, const backend::GenerationParams &params, /*out*/ ChatResponseMetadata &metadata)
-> QCoro::AsyncGenerator<QString> override;
private:
OpenaiConnectionDetails m_connDetails;
QNetworkAccessManager *m_nam;
};
} // namespace gpt4all::ui

View File

@ -0,0 +1,26 @@
#include "provider.h"
#include <utility>
namespace gpt4all::ui {
void OpenaiProvider::setBaseUrl(QUrl value)
{
if (m_baseUrl != value) {
m_baseUrl = std::move(value);
emit baseUrlChanged(m_baseUrl);
}
}
void OpenaiProvider::setApiKey(QString value)
{
if (m_apiKey != value) {
m_apiKey = std::move(value);
emit apiKeyChanged(m_apiKey);
}
}
} // namespace gpt4all::ui

View File

@ -0,0 +1,47 @@
#pragma once
#include <QObject>
#include <QQmlEngine>
#include <QString>
#include <QUrl>
namespace gpt4all::ui {
class ModelProvider : public QObject {
Q_OBJECT
Q_PROPERTY(QString name READ name CONSTANT)
public:
[[nodiscard]] virtual QString name() = 0;
};
class OpenaiProvider : public ModelProvider {
Q_OBJECT
QML_ELEMENT
Q_PROPERTY(QUrl baseUrl READ baseUrl WRITE setBaseUrl NOTIFY baseUrlChanged)
Q_PROPERTY(QString apiKey READ apiKey WRITE setApiKey NOTIFY apiKeyChanged)
public:
[[nodiscard]] QString name() override { return m_name; }
[[nodiscard]] const QUrl &baseUrl() { return m_baseUrl; }
[[nodiscard]] const QString &apiKey () { return m_apiKey; }
void setBaseUrl(QUrl value);
void setApiKey (QString value);
Q_SIGNALS:
void baseUrlChanged(const QUrl &value);
void apiKeyChanged (const QString &value);
private:
QString m_name;
QUrl m_baseUrl;
QString m_apiKey;
};
} // namespace gpt4all::ui
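A hypothetical sketch of how the pieces introduced here fit together, building the connection details that OpenaiLLModel's constructor takes (all names and values are placeholders; OpenaiProvider is assumed to be default-constructible as declared, and the objects are deliberately leaked for brevity):

#include "llmodel/openai.h"

#include <QString>
#include <QUrl>

static gpt4all::ui::OpenaiConnectionDetails makeExampleConnection()
{
    using namespace gpt4all::ui;
    auto *provider = new OpenaiProvider;                                 // unparented in this toy sketch
    provider->setBaseUrl(QUrl(QStringLiteral("https://api.openai.com")));
    provider->setApiKey(QStringLiteral("sk-placeholder"));               // not a real key
    auto *desc = new OpenaiModelDescription(provider,
                                            QStringLiteral("Example model"),  // display name
                                            QStringLiteral("gpt-4o-mini"));   // API model name
    return OpenaiConnectionDetails(desc);  // snapshots baseUrl, apiKey, and modelName
}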

View File

@ -1,6 +1,7 @@
#include "mysettings.h"
#include "chatllm.h"
#include "config.h"
#include "modellist.h"
#include <gpt4all-backend/llmodel.h>
@ -48,7 +49,6 @@ namespace ModelSettingsKey { namespace {
namespace defaults {
static const int threadCount = std::min(4, (int32_t) std::thread::hardware_concurrency());
static const bool forceMetal = false;
static const bool networkIsActive = false;
static const bool networkUsageStatsActive = false;
static const QString device = "Auto";
@ -71,7 +71,6 @@ static const QVariantMap basicDefaults {
{ "localdocs/fileExtensions", QStringList { "docx", "pdf", "txt", "md", "rst" } },
{ "localdocs/useRemoteEmbed", false },
{ "localdocs/nomicAPIKey", "" },
{ "localdocs/embedDevice", "Auto" },
{ "network/attribution", "" },
};
@ -174,11 +173,16 @@ MySettings *MySettings::globalInstance()
MySettings::MySettings()
: QObject(nullptr)
, m_deviceList(getDevices())
, m_embeddingsDeviceList(getDevices(/*skipKompute*/ true))
, m_uiLanguages(getUiLanguages(modelPath()))
{
}
const QString &MySettings::userAgent()
{
static const QString s_userAgent = QStringLiteral("gpt4all/" APP_VERSION);
return s_userAgent;
}
QVariant MySettings::checkJinjaTemplateError(const QString &tmpl)
{
if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString()))
@ -256,7 +260,6 @@ void MySettings::restoreApplicationDefaults()
setNetworkPort(basicDefaults.value("networkPort").toInt());
setModelPath(defaultLocalModelsPath());
setUserDefaultModel(basicDefaults.value("userDefaultModel").toString());
setForceMetal(defaults::forceMetal);
setSuggestionMode(basicDefaults.value("suggestionMode").value<SuggestionMode>());
setLanguageAndLocale(defaults::languageAndLocale);
}
@ -269,7 +272,6 @@ void MySettings::restoreLocalDocsDefaults()
setLocalDocsFileExtensions(basicDefaults.value("localdocs/fileExtensions").toStringList());
setLocalDocsUseRemoteEmbed(basicDefaults.value("localdocs/useRemoteEmbed").toBool());
setLocalDocsNomicAPIKey(basicDefaults.value("localdocs/nomicAPIKey").toString());
setLocalDocsEmbedDevice(basicDefaults.value("localdocs/embedDevice").toString());
}
void MySettings::eraseModel(const ModelInfo &info)
@ -628,7 +630,6 @@ bool MySettings::localDocsShowReferences() const { return getBasicSetting
QStringList MySettings::localDocsFileExtensions() const { return getBasicSetting("localdocs/fileExtensions").toStringList(); }
bool MySettings::localDocsUseRemoteEmbed() const { return getBasicSetting("localdocs/useRemoteEmbed").toBool(); }
QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting("localdocs/nomicAPIKey" ).toString(); }
QString MySettings::localDocsEmbedDevice() const { return getBasicSetting("localdocs/embedDevice" ).toString(); }
QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); }
ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); }
@ -646,7 +647,6 @@ void MySettings::setLocalDocsShowReferences(bool value) { setBasic
void MySettings::setLocalDocsFileExtensions(const QStringList &value) { setBasicSetting("localdocs/fileExtensions", value, "localDocsFileExtensions"); }
void MySettings::setLocalDocsUseRemoteEmbed(bool value) { setBasicSetting("localdocs/useRemoteEmbed", value, "localDocsUseRemoteEmbed"); }
void MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasicSetting("localdocs/nomicAPIKey", value, "localDocsNomicAPIKey"); }
void MySettings::setLocalDocsEmbedDevice(const QString &value) { setBasicSetting("localdocs/embedDevice", value, "localDocsEmbedDevice"); }
void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); }
void MySettings::setChatTheme(ChatTheme value) { setBasicSetting("chatTheme", chatThemeNames .value(int(value))); }
@ -706,19 +706,6 @@ void MySettings::setDevice(const QString &value)
}
}
bool MySettings::forceMetal() const
{
return m_forceMetal;
}
void MySettings::setForceMetal(bool value)
{
if (m_forceMetal != value) {
m_forceMetal = value;
emit forceMetalChanged(value);
}
}
bool MySettings::networkIsActive() const
{
return m_settings.value("network/isActive", defaults::networkIsActive).toBool();

View File

@ -62,7 +62,6 @@ class MySettings : public QObject
Q_PROPERTY(ChatTheme chatTheme READ chatTheme WRITE setChatTheme NOTIFY chatThemeChanged)
Q_PROPERTY(FontSize fontSize READ fontSize WRITE setFontSize NOTIFY fontSizeChanged)
Q_PROPERTY(QString languageAndLocale READ languageAndLocale WRITE setLanguageAndLocale NOTIFY languageAndLocaleChanged)
Q_PROPERTY(bool forceMetal READ forceMetal WRITE setForceMetal NOTIFY forceMetalChanged)
Q_PROPERTY(QString lastVersionStarted READ lastVersionStarted WRITE setLastVersionStarted NOTIFY lastVersionStartedChanged)
Q_PROPERTY(int localDocsChunkSize READ localDocsChunkSize WRITE setLocalDocsChunkSize NOTIFY localDocsChunkSizeChanged)
Q_PROPERTY(int localDocsRetrievalSize READ localDocsRetrievalSize WRITE setLocalDocsRetrievalSize NOTIFY localDocsRetrievalSizeChanged)
@ -70,13 +69,11 @@ class MySettings : public QObject
Q_PROPERTY(QStringList localDocsFileExtensions READ localDocsFileExtensions WRITE setLocalDocsFileExtensions NOTIFY localDocsFileExtensionsChanged)
Q_PROPERTY(bool localDocsUseRemoteEmbed READ localDocsUseRemoteEmbed WRITE setLocalDocsUseRemoteEmbed NOTIFY localDocsUseRemoteEmbedChanged)
Q_PROPERTY(QString localDocsNomicAPIKey READ localDocsNomicAPIKey WRITE setLocalDocsNomicAPIKey NOTIFY localDocsNomicAPIKeyChanged)
Q_PROPERTY(QString localDocsEmbedDevice READ localDocsEmbedDevice WRITE setLocalDocsEmbedDevice NOTIFY localDocsEmbedDeviceChanged)
Q_PROPERTY(QString networkAttribution READ networkAttribution WRITE setNetworkAttribution NOTIFY networkAttributionChanged)
Q_PROPERTY(bool networkIsActive READ networkIsActive WRITE setNetworkIsActive NOTIFY networkIsActiveChanged)
Q_PROPERTY(bool networkUsageStatsActive READ networkUsageStatsActive WRITE setNetworkUsageStatsActive NOTIFY networkUsageStatsActiveChanged)
Q_PROPERTY(QString device READ device WRITE setDevice NOTIFY deviceChanged)
Q_PROPERTY(QStringList deviceList MEMBER m_deviceList CONSTANT)
Q_PROPERTY(QStringList embeddingsDeviceList MEMBER m_embeddingsDeviceList CONSTANT)
Q_PROPERTY(int networkPort READ networkPort WRITE setNetworkPort NOTIFY networkPortChanged)
Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
@ -91,6 +88,8 @@ public Q_SLOTS:
public:
static MySettings *globalInstance();
static const QString &userAgent();
Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);
// Restore methods
@ -172,8 +171,6 @@ public:
void setChatTheme(ChatTheme value);
FontSize fontSize() const;
void setFontSize(FontSize value);
bool forceMetal() const;
void setForceMetal(bool value);
QString device();
void setDevice(const QString &value);
int32_t contextLength() const;
@ -203,8 +200,6 @@ public:
void setLocalDocsUseRemoteEmbed(bool value);
QString localDocsNomicAPIKey() const;
void setLocalDocsNomicAPIKey(const QString &value);
QString localDocsEmbedDevice() const;
void setLocalDocsEmbedDevice(const QString &value);
// Network settings
QString networkAttribution() const;
@ -243,7 +238,6 @@ Q_SIGNALS:
void userDefaultModelChanged();
void chatThemeChanged();
void fontSizeChanged();
void forceMetalChanged(bool);
void lastVersionStartedChanged();
void localDocsChunkSizeChanged();
void localDocsRetrievalSizeChanged();
@ -251,7 +245,6 @@ Q_SIGNALS:
void localDocsFileExtensionsChanged();
void localDocsUseRemoteEmbedChanged();
void localDocsNomicAPIKeyChanged();
void localDocsEmbedDeviceChanged();
void networkAttributionChanged();
void networkIsActiveChanged();
void networkPortChanged();
@ -287,9 +280,7 @@ private:
private:
QSettings m_settings;
bool m_forceMetal;
const QStringList m_deviceList;
const QStringList m_embeddingsDeviceList;
const QStringList m_uiLanguages;
std::unique_ptr<QTranslator> m_translator;

View File

@ -372,8 +372,6 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props)
Q_ASSERT(curChat);
if (!props.contains("model"))
props.insert("model", curChat->modelInfo().filename());
props.insert("device_backend", curChat->deviceBackend());
props.insert("actualDevice", curChat->device());
props.insert("doc_collections_enabled", curChat->collectionList().count());
props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());

View File

@ -7,6 +7,7 @@
#include <fmt/format.h>
#include <gpt4all-backend/formatters.h>
#include <gpt4all-backend/generation-params.h>
#include <gpt4all-backend/llmodel.h>
#include <QByteArray>
@ -126,7 +127,7 @@ class BaseCompletionRequest {
public:
QString model; // required
// NB: some parameters are not supported yet
int32_t max_tokens = 16;
uint max_tokens = 16;
qint64 n = 1;
float temperature = 1.f;
float top_p = 1.f;
@ -161,7 +162,7 @@ protected:
value = reqValue("max_tokens", Integer, false, /*min*/ 1);
if (!value.isNull())
this->max_tokens = int32_t(qMin(value.toInteger(), INT32_MAX));
this->max_tokens = uint(qMin(value.toInteger(), UINT32_MAX));
value = reqValue("n", Integer, false, /*min*/ 1);
if (!value.isNull())
@ -666,7 +667,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
m_chatModel->appendResponse();
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
LLModel::PromptContext promptCtx {
backend::GenerationParams genParams {
.n_predict = request.max_tokens,
.top_k = mySettings->modelTopK(modelInfo),
.top_p = request.top_p,
@ -685,7 +686,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
PromptResult result;
try {
result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
promptCtx,
genParams,
/*usedLocalDocs*/ false);
} catch (const std::exception &e) {
m_chatModel->setResponseValue(e.what());
@ -779,7 +780,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
auto startOffset = m_chatModel->appendResponseWithHistory(messages);
// FIXME(jared): taking parameters from the UI inhibits reproducibility of results
LLModel::PromptContext promptCtx {
backend::GenerationParams genParams {
.n_predict = request.max_tokens,
.top_k = mySettings->modelTopK(modelInfo),
.top_p = request.top_p,
@ -796,7 +797,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
for (int i = 0; i < request.n; ++i) {
ChatPromptResult result;
try {
result = promptInternalChat(m_collections, promptCtx, startOffset);
result = promptInternalChat(m_collections, genParams, startOffset);
} catch (const std::exception &e) {
m_chatModel->setResponseValue(e.what());
m_chatModel->setError();