diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt
index eba241b4..10077efd 100644
--- a/gpt4all-backend/CMakeLists.txt
+++ b/gpt4all-backend/CMakeLists.txt
@@ -13,7 +13,10 @@ add_subdirectory(src)
 
 target_sources(gpt4all-backend PUBLIC FILE_SET public_headers TYPE HEADERS BASE_DIRS include FILES
     include/gpt4all-backend/formatters.h
+    include/gpt4all-backend/generation-params.h
     include/gpt4all-backend/json-helpers.h
     include/gpt4all-backend/ollama-client.h
+    include/gpt4all-backend/ollama-model.h
     include/gpt4all-backend/ollama-types.h
+    include/gpt4all-backend/rest.h
 )
diff --git a/gpt4all-backend/include/gpt4all-backend/generation-params.h b/gpt4all-backend/include/gpt4all-backend/generation-params.h
new file mode 100644
index 00000000..09cec5ef
--- /dev/null
+++ b/gpt4all-backend/include/gpt4all-backend/generation-params.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include
+
+
+namespace gpt4all::backend {
+
+
+struct GenerationParams {
+    uint n_predict;
+    float temperature;
+    float top_p;
+    // int32_t top_k = 40;
+    // float min_p = 0.0f;
+    // int32_t n_batch = 9;
+    // float repeat_penalty = 1.10f;
+    // int32_t repeat_last_n = 64; // last n tokens to penalize
+    // float contextErase = 0.5f; // percent of context to erase if we exceed the context window
+};
+
+
+} // namespace gpt4all::backend
diff --git a/gpt4all-backend/include/gpt4all-backend/ollama-client.h b/gpt4all-backend/include/gpt4all-backend/ollama-client.h
index 4f19e1fe..4c251024 100644
--- a/gpt4all-backend/include/gpt4all-backend/ollama-client.h
+++ b/gpt4all-backend/include/gpt4all-backend/ollama-client.h
@@ -22,6 +22,7 @@ class QRestReply;
 
 namespace gpt4all::backend {
 
+
 struct ResponseError {
 public:
     struct BadStatus { int code; };
@@ -52,7 +53,7 @@ using DataOrRespErr = std::expected;
 
 class OllamaClient {
 public:
-    OllamaClient(QUrl baseUrl, QString m_userAgent = QStringLiteral("GPT4All"))
+    OllamaClient(QUrl baseUrl, QString m_userAgent)
         : m_baseUrl(baseUrl)
         , m_userAgent(std::move(m_userAgent))
         {}
diff --git a/gpt4all-backend/include/gpt4all-backend/ollama-model.h b/gpt4all-backend/include/gpt4all-backend/ollama-model.h
new file mode 100644
index 00000000..bd163f05
--- /dev/null
+++ b/gpt4all-backend/include/gpt4all-backend/ollama-model.h
@@ -0,0 +1,13 @@
+#pragma once
+
+namespace gpt4all::backend {
+
+
+class OllamaClient;
+
+class OllamaModel {
+    OllamaClient *client;
+};
+
+
+} // namespace gpt4all::backend
diff --git a/gpt4all-backend/include/gpt4all-backend/rest.h b/gpt4all-backend/include/gpt4all-backend/rest.h
new file mode 100644
index 00000000..63e79941
--- /dev/null
+++ b/gpt4all-backend/include/gpt4all-backend/rest.h
@@ -0,0 +1,13 @@
+#pragma once
+
+class QRestReply;
+class QString;
+
+
+namespace gpt4all::backend {
+
+
+QString restErrorString(const QRestReply &reply);
+
+
+} // namespace gpt4all::backend
diff --git a/gpt4all-backend/src/CMakeLists.txt b/gpt4all-backend/src/CMakeLists.txt
index c6e618db..212264c3 100644
--- a/gpt4all-backend/src/CMakeLists.txt
+++ b/gpt4all-backend/src/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library(${TARGET} STATIC
     ollama-client.cpp
     ollama-types.cpp
     qt-json-stream.cpp
+    rest.cpp
 )
 target_compile_features(${TARGET} PUBLIC cxx_std_23)
 gpt4all_add_warning_options(${TARGET})
diff --git a/gpt4all-backend/src/ollama-client.cpp b/gpt4all-backend/src/ollama-client.cpp
index f7f0e2e0..eda623d8 100644
--- a/gpt4all-backend/src/ollama-client.cpp
+++ b/gpt4all-backend/src/ollama-client.cpp
@@ -2,6 +2,7 @@
 
 #include "json-helpers.h" // IWYU pragma: keep
 #include "qt-json-stream.h"
+#include "rest.h"
 
 #include // IWYU pragma: keep
 #include // IWYU pragma: keep
@@ -26,22 +27,14 @@ namespace gpt4all::backend {
 
 ResponseError::ResponseError(const QRestReply *reply)
 {
-    auto *nr = reply->networkReply();
     if (reply->hasError()) {
-        error = nr->error();
-        errorString = nr->errorString();
+        error = reply->networkReply()->error();
     } else if (!reply->isHttpStatusSuccess()) {
-        auto code = reply->httpStatus();
-        auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
-        error = BadStatus(code);
-        errorString = u"HTTP %1%2%3 for URL \"%4\""_s.arg(
-            QString::number(code),
-            reason.isValid() ? u" "_s : QString(),
-            reason.toString(),
-            nr->request().url().toString()
-        );
+        error = BadStatus(reply->httpStatus());
     } else
         Q_UNREACHABLE();
+
+    errorString = restErrorString(*reply);
 }
 
 QNetworkRequest OllamaClient::makeRequest(const QString &path) const
diff --git a/gpt4all-backend/src/rest.cpp b/gpt4all-backend/src/rest.cpp
new file mode 100644
index 00000000..974a2bb0
--- /dev/null
+++ b/gpt4all-backend/src/rest.cpp
@@ -0,0 +1,34 @@
+#include "rest.h"
+
+#include
+#include
+#include
+
+using namespace Qt::Literals::StringLiterals;
+
+
+namespace gpt4all::backend {
+
+
+QString restErrorString(const QRestReply &reply)
+{
+    auto *nr = reply.networkReply();
+    if (reply.hasError())
+        return nr->errorString();
+
+    if (!reply.isHttpStatusSuccess()) {
+        auto code = reply.httpStatus();
+        auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
+        return u"HTTP %1%2%3 for URL \"%4\""_s.arg(
+            QString::number(code),
+            reason.isValid() ? u" "_s : QString(),
+            reason.toString(),
+            nr->request().url().toString()
+        );
+    }
+
+    Q_UNREACHABLE();
+}
+
+
+} // namespace gpt4all::backend
diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt
index 366904de..07485bf7 100644
--- a/gpt4all-chat/CMakeLists.txt
+++ b/gpt4all-chat/CMakeLists.txt
@@ -227,9 +227,10 @@ if (APPLE)
 endif()
 
 qt_add_executable(chat
+    src/llmodel/provider.cpp src/llmodel/provider.h
+    src/llmodel/openai.cpp src/llmodel/openai.h
     src/main.cpp
     src/chat.cpp src/chat.h
-    src/chatapi.cpp src/chatapi.h
     src/chatlistmodel.cpp src/chatlistmodel.h
     src/chatllm.cpp src/chatllm.h
     src/chatmodel.h src/chatmodel.cpp
@@ -456,8 +457,16 @@ else()
     # Link PDFium
     target_link_libraries(chat PRIVATE pdfium)
 endif()
-target_link_libraries(chat
-    PRIVATE gpt4all-backend llmodel nlohmann_json::nlohmann_json SingleApplication fmt::fmt duckx::duckx QXlsx)
+target_link_libraries(chat PRIVATE
+    QCoro6::Core QCoro6::Network
+    QXlsx
+    SingleApplication
+    duckx::duckx
+    fmt::fmt
+    gpt4all-backend
+    llmodel
+    nlohmann_json::nlohmann_json
+)
 target_include_directories(chat PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/deps/minja/include)
 
 if (APPLE)
diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp
index 52e8e629..e5c4ea58 100644
--- a/gpt4all-chat/src/chat.cpp
+++ b/gpt4all-chat/src/chat.cpp
@@ -412,21 +412,6 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
     emit tokenSpeedChanged();
 }
 
-QString Chat::deviceBackend() const
-{
-    return m_llmodel->deviceBackend();
-}
-
-QString Chat::device() const
-{
-    return m_llmodel->device();
-}
-
-QString Chat::fallbackReason() const
-{
-    return m_llmodel->fallbackReason();
-}
-
 void Chat::handleDatabaseResultsChanged(const QList &results)
 {
     m_databaseResults = results;
diff --git a/gpt4all-chat/src/chat.h b/gpt4all-chat/src/chat.h
index f4ac654f..452f3970 100644
--- a/gpt4all-chat/src/chat.h
+++ b/gpt4all-chat/src/chat.h
@@ -39,9 +39,6 @@ class Chat
: public QObject Q_PROPERTY(QList collectionList READ collectionList NOTIFY collectionListChanged) Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged) Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged) - Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) - Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) - Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged) // 0=no, 1=waiting, 2=working Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged) @@ -122,10 +119,6 @@ public: QString modelLoadingError() const { return m_modelLoadingError; } QString tokenSpeed() const { return m_tokenSpeed; } - QString deviceBackend() const; - QString device() const; - // not loaded -> QString(), no fallback -> QString("") - QString fallbackReason() const; int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; } @@ -159,8 +152,6 @@ Q_SIGNALS: void isServerChanged(); void collectionListChanged(const QList &collectionList); void tokenSpeedChanged(); - void deviceChanged(); - void fallbackReasonChanged(); void collectionModelChanged(); void trySwitchContextInProgressChanged(); void loadedModelInfoChanged(); @@ -192,8 +183,6 @@ private: ModelInfo m_modelInfo; QString m_modelLoadingError; QString m_tokenSpeed; - QString m_device; - QString m_fallbackReason; QList m_collections; QList m_generatedQuestions; ChatModel *m_chatModel; diff --git a/gpt4all-chat/src/chatapi.cpp b/gpt4all-chat/src/chatapi.cpp deleted file mode 100644 index 272b11df..00000000 --- a/gpt4all-chat/src/chatapi.cpp +++ /dev/null @@ -1,359 +0,0 @@ -#include "chatapi.h" - -#include "utils.h" - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // IWYU pragma: keep -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -using namespace Qt::Literals::StringLiterals; - -//#define DEBUG - - -ChatAPI::ChatAPI() - : QObject(nullptr) - , m_modelName("gpt-3.5-turbo") - , m_requestURL("") - , m_responseCallback(nullptr) -{ -} - -size_t ChatAPI::requiredMem(const std::string &modelPath, int n_ctx, int ngl) -{ - Q_UNUSED(modelPath); - Q_UNUSED(n_ctx); - Q_UNUSED(ngl); - return 0; -} - -bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl) -{ - Q_UNUSED(modelPath); - Q_UNUSED(n_ctx); - Q_UNUSED(ngl); - return true; -} - -void ChatAPI::setThreadCount(int32_t n_threads) -{ - Q_UNUSED(n_threads); -} - -int32_t ChatAPI::threadCount() const -{ - return 1; -} - -ChatAPI::~ChatAPI() -{ -} - -bool ChatAPI::isModelLoaded() const -{ - return true; -} - -static auto parsePrompt(QXmlStreamReader &xml) -> std::expected -{ - QJsonArray messages; - - auto xmlError = [&xml] { - return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString())); - }; - - if (xml.hasError()) - return xmlError(); - if (xml.atEnd()) - return messages; - - // skip header - bool foundElement = false; - do { - switch (xml.readNext()) { - using enum QXmlStreamReader::TokenType; - case Invalid: - return xmlError(); - case EndDocument: - return messages; - default: - foundElement = true; - case StartDocument: - case Comment: - case DTD: - case 
ProcessingInstruction: - ; - } - } while (!foundElement); - - // document body loop - bool foundRoot = false; - for (;;) { - switch (xml.tokenType()) { - using enum QXmlStreamReader::TokenType; - case StartElement: - { - auto name = xml.name(); - if (!foundRoot) { - if (name != "chat"_L1) - return std::unexpected(u"unexpected tag: %1"_s.arg(name)); - foundRoot = true; - } else { - if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1) - return std::unexpected(u"unknown role: %1"_s.arg(name)); - auto content = xml.readElementText(); - if (xml.tokenType() != EndElement) - return xmlError(); - messages << makeJsonObject({ - { "role"_L1, name.toString().trimmed() }, - { "content"_L1, content }, - }); - } - break; - } - case Characters: - if (!xml.isWhitespace()) - return std::unexpected(u"unexpected text: %1"_s.arg(xml.text())); - case Comment: - case ProcessingInstruction: - case EndElement: - break; - case EndDocument: - return messages; - case Invalid: - return xmlError(); - default: - return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString())); - } - xml.readNext(); - } -} - -void ChatAPI::prompt( - std::string_view prompt, - const PromptCallback &promptCallback, - const ResponseCallback &responseCallback, - const PromptContext &promptCtx -) { - Q_UNUSED(promptCallback) - - if (!isModelLoaded()) - throw std::invalid_argument("Attempted to prompt an unloaded model."); - if (!promptCtx.n_predict) - return; // nothing requested - - // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering - // an error we need to be able to count the tokens in our prompt. The only way to do this is to use - // the OpenAI tiktoken library or to implement our own tokenization function that matches precisely - // the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of - // using the REST API to count tokens in a prompt. 
- auto root = makeJsonObject({ - { "model"_L1, m_modelName }, - { "stream"_L1, true }, - { "temperature"_L1, promptCtx.temp }, - { "top_p"_L1, promptCtx.top_p }, - }); - - // conversation history - { - QUtf8StringView promptUtf8(prompt); - QXmlStreamReader xml(promptUtf8); - auto messages = parsePrompt(xml); - if (!messages) { - auto error = fmt::format("Failed to parse API model prompt: {}", messages.error()); - qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n'; - throw std::invalid_argument(error); - } - root.insert("messages"_L1, *messages); - } - - QJsonDocument doc(root); - -#if defined(DEBUG) - qDebug().noquote() << "ChatAPI::prompt begin network request" << doc.toJson(); -#endif - - m_responseCallback = responseCallback; - - // The following code sets up a worker thread and object to perform the actual api request to - // chatgpt and then blocks until it is finished - QThread workerThread; - ChatAPIWorker worker(this); - worker.moveToThread(&workerThread); - connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); - connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection); - workerThread.start(); - emit request(m_apiKey, doc.toJson(QJsonDocument::Compact)); - workerThread.wait(); - - m_responseCallback = nullptr; - -#if defined(DEBUG) - qDebug() << "ChatAPI::prompt end network request"; -#endif -} - -bool ChatAPI::callResponse(int32_t token, const std::string& string) -{ - Q_ASSERT(m_responseCallback); - if (!m_responseCallback) { - std::cerr << "ChatAPI ERROR: no response callback!\n"; - return false; - } - return m_responseCallback(token, string); -} - -void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array) -{ - QUrl apiUrl(m_chat->url()); - const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed(); - QNetworkRequest request(apiUrl); - request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - request.setRawHeader("Authorization", authorization.toUtf8()); -#if defined(DEBUG) - qDebug() << "ChatAPI::request" - << "API URL: " << apiUrl.toString() - << "Authorization: " << authorization.toUtf8(); -#endif - m_networkManager = new QNetworkAccessManager(this); - QNetworkReply *reply = m_networkManager->post(request, array); - connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort); - connect(reply, &QNetworkReply::finished, this, &ChatAPIWorker::handleFinished); - connect(reply, &QNetworkReply::readyRead, this, &ChatAPIWorker::handleReadyRead); - connect(reply, &QNetworkReply::errorOccurred, this, &ChatAPIWorker::handleErrorOccurred); -} - -void ChatAPIWorker::handleFinished() -{ - QNetworkReply *reply = qobject_cast(sender()); - if (!reply) { - emit finished(); - return; - } - - QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); - - if (!response.isValid()) { - m_chat->callResponse( - -1, - tr("ERROR: Network error occurred while connecting to the API server") - .toStdString() - ); - return; - } - - bool ok; - int code = response.toInt(&ok); - if (!ok || code != 200) { - bool isReplyEmpty(reply->readAll().isEmpty()); - if (isReplyEmpty) - m_chat->callResponse( - -1, - tr("ChatAPIWorker::handleFinished got HTTP Error %1 %2") - .arg(code) - .arg(reply->errorString()) - .toStdString() - ); - qWarning().noquote() << "ERROR: ChatAPIWorker::handleFinished got HTTP Error" << code << "response:" - << reply->errorString(); - } - reply->deleteLater(); - emit finished(); -} - -void 
ChatAPIWorker::handleReadyRead() -{ - QNetworkReply *reply = qobject_cast(sender()); - if (!reply) { - emit finished(); - return; - } - - QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); - - if (!response.isValid()) - return; - - bool ok; - int code = response.toInt(&ok); - if (!ok || code != 200) { - m_chat->callResponse( - -1, - u"ERROR: ChatAPIWorker::handleReadyRead got HTTP Error %1 %2: %3"_s - .arg(code).arg(reply->errorString(), reply->readAll()).toStdString() - ); - emit finished(); - return; - } - - while (reply->canReadLine()) { - QString jsonData = reply->readLine().trimmed(); - if (jsonData.startsWith("data:")) - jsonData.remove(0, 5); - jsonData = jsonData.trimmed(); - if (jsonData.isEmpty()) - continue; - if (jsonData == "[DONE]") - continue; -#if defined(DEBUG) - qDebug().noquote() << "line" << jsonData; -#endif - QJsonParseError err; - const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err); - if (err.error != QJsonParseError::NoError) { - m_chat->callResponse(-1, u"ERROR: ChatAPI responded with invalid json \"%1\""_s - .arg(err.errorString()).toStdString()); - continue; - } - - const QJsonObject root = document.object(); - const QJsonArray choices = root.value("choices").toArray(); - const QJsonObject choice = choices.first().toObject(); - const QJsonObject delta = choice.value("delta").toObject(); - const QString content = delta.value("content").toString(); - m_currentResponse += content; - if (!m_chat->callResponse(0, content.toStdString())) { - reply->abort(); - emit finished(); - return; - } - } -} - -void ChatAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code) -{ - QNetworkReply *reply = qobject_cast(sender()); - if (!reply || reply->error() == QNetworkReply::OperationCanceledError /*when we call abort on purpose*/) { - emit finished(); - return; - } - - qWarning().noquote() << "ERROR: ChatAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:" - << reply->errorString(); - emit finished(); -} diff --git a/gpt4all-chat/src/chatapi.h b/gpt4all-chat/src/chatapi.h deleted file mode 100644 index f937a20d..00000000 --- a/gpt4all-chat/src/chatapi.h +++ /dev/null @@ -1,173 +0,0 @@ -#ifndef CHATAPI_H -#define CHATAPI_H - -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -// IWYU pragma: no_forward_declare QByteArray -class ChatAPI; -class QNetworkAccessManager; - - -class ChatAPIWorker : public QObject { - Q_OBJECT -public: - ChatAPIWorker(ChatAPI *chatAPI) - : QObject(nullptr) - , m_networkManager(nullptr) - , m_chat(chatAPI) {} - virtual ~ChatAPIWorker() {} - - QString currentResponse() const { return m_currentResponse; } - - void request(const QString &apiKey, const QByteArray &array); - -Q_SIGNALS: - void finished(); - -private Q_SLOTS: - void handleFinished(); - void handleReadyRead(); - void handleErrorOccurred(QNetworkReply::NetworkError code); - -private: - ChatAPI *m_chat; - QNetworkAccessManager *m_networkManager; - QString m_currentResponse; -}; - -class ChatAPI : public QObject, public LLModel { - Q_OBJECT -public: - ChatAPI(); - virtual ~ChatAPI(); - - bool supportsEmbedding() const override { return false; } - bool supportsCompletion() const override { return true; } - bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override; - bool isModelLoaded() const override; - size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override; - - // All three of the 
state virtual functions are handled custom inside of chatllm save/restore - size_t stateSize() const override - { throwNotImplemented(); } - size_t saveState(std::span stateOut, std::vector &inputTokensOut) const override - { Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); } - size_t restoreState(std::span state, std::span inputTokens) override - { Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); } - - void prompt(std::string_view prompt, - const PromptCallback &promptCallback, - const ResponseCallback &responseCallback, - const PromptContext &ctx) override; - - [[noreturn]] - int32_t countPromptTokens(std::string_view prompt) const override - { Q_UNUSED(prompt); throwNotImplemented(); } - - void setThreadCount(int32_t n_threads) override; - int32_t threadCount() const override; - - void setModelName(const QString &modelName) { m_modelName = modelName; } - void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; } - void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; } - QString url() const { return m_requestURL; } - - bool callResponse(int32_t token, const std::string &string); - - [[noreturn]] - int32_t contextLength() const override - { throwNotImplemented(); } - - auto specialTokens() -> std::unordered_map const override - { return {}; } - -Q_SIGNALS: - void request(const QString &apiKey, const QByteArray &array); - -protected: - // We have to implement these as they are pure virtual in base class, but we don't actually use - // them as they are only called from the default implementation of 'prompt' which we override and - // completely replace - - [[noreturn]] - static void throwNotImplemented() { throw std::logic_error("not implemented"); } - - [[noreturn]] - std::vector tokenize(std::string_view str) const override - { Q_UNUSED(str); throwNotImplemented(); } - - [[noreturn]] - bool isSpecialToken(Token id) const override - { Q_UNUSED(id); throwNotImplemented(); } - - [[noreturn]] - std::string tokenToString(Token id) const override - { Q_UNUSED(id); throwNotImplemented(); } - - [[noreturn]] - void initSampler(const PromptContext &ctx) override - { Q_UNUSED(ctx); throwNotImplemented(); } - - [[noreturn]] - Token sampleToken() const override - { throwNotImplemented(); } - - [[noreturn]] - bool evalTokens(int32_t nPast, std::span tokens) const override - { Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); } - - [[noreturn]] - void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override - { Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); } - - [[noreturn]] - int32_t inputLength() const override - { throwNotImplemented(); } - - [[noreturn]] - int32_t computeModelInputPosition(std::span input) const override - { Q_UNUSED(input); throwNotImplemented(); } - - [[noreturn]] - void setModelInputPosition(int32_t pos) override - { Q_UNUSED(pos); throwNotImplemented(); } - - [[noreturn]] - void appendInputToken(Token tok) override - { Q_UNUSED(tok); throwNotImplemented(); } - - [[noreturn]] - const std::vector &endTokens() const override - { throwNotImplemented(); } - - [[noreturn]] - bool shouldAddBOS() const override - { throwNotImplemented(); } - - [[noreturn]] - std::span inputTokens() const override - { throwNotImplemented(); } - -private: - ResponseCallback m_responseCallback; - QString m_modelName; - QString m_apiKey; - QString m_requestURL; -}; - -#endif // CHATAPI_H diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index 1fbcce8c..5186bc4e 100644 --- 
a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -1,17 +1,19 @@ #include "chatllm.h" #include "chat.h" -#include "chatapi.h" #include "chatmodel.h" #include "jinja_helpers.h" +#include "llmodel/chat.h" +#include "llmodel/openai.h" #include "localdocs.h" #include "mysettings.h" #include "network.h" #include "tool.h" -#include "toolmodel.h" #include "toolcallparser.h" +#include "toolmodel.h" #include +#include #include #include @@ -64,6 +66,8 @@ using namespace Qt::Literals::StringLiterals; using namespace ToolEnums; +using namespace gpt4all; +using namespace gpt4all::ui; namespace ranges = std::ranges; using json = nlohmann::ordered_json; @@ -115,14 +119,12 @@ public: }; static auto promptModelWithTools( - LLModel *model, const LLModel::PromptCallback &promptCallback, BaseResponseHandler &respHandler, - const LLModel::PromptContext &ctx, const QByteArray &prompt, const QStringList &toolNames + ChatLLModel *model, BaseResponseHandler &respHandler, const backend::GenerationParams ¶ms, + const QByteArray &prompt, const QStringList &toolNames ) -> std::pair { ToolCallParser toolCallParser(toolNames); - auto handleResponse = [&toolCallParser, &respHandler](LLModel::Token token, std::string_view piece) -> bool { - Q_UNUSED(token) - + auto handleResponse = [&toolCallParser, &respHandler](std::string_view piece) -> bool { toolCallParser.update(piece.data()); // Split the response into two if needed @@ -157,7 +159,7 @@ static auto promptModelWithTools( return !shouldExecuteToolCall && !respHandler.getStopGenerating(); }; - model->prompt(std::string_view(prompt), promptCallback, handleResponse, ctx); + model->prompt(std::string_view(prompt), promptCallback, handleResponse, params); const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete && toolCallParser.startTag() != ToolCallConstants::ThinkStartTag; @@ -217,9 +219,8 @@ void LLModelStore::destroy() m_availableModel.reset(); } -void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) { +void LLModelInfo::resetModel(ChatLLM *cllm, ChatLLModel *model) { this->model.reset(model); - fallbackReason.reset(); emit cllm->loadedModelInfoChanged(); } @@ -232,8 +233,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_stopGenerating(false) , m_timer(nullptr) , m_isServer(isServer) - , m_forceMetal(MySettings::globalInstance()->forceMetal()) - , m_reloadingToChangeVariant(false) , m_chatModel(parent->chatModel()) { moveToThread(&m_llmThread); @@ -243,8 +242,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) Qt::QueuedConnection); // explicitly queued connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); - connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); - connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged); // The following are blocking operations and will block the llm thread connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, @@ -284,31 +281,6 @@ void ChatLLM::handleThreadStarted() emit threadStarted(); } -void ChatLLM::handleForceMetalChanged(bool forceMetal) -{ -#if defined(Q_OS_MAC) && defined(__aarch64__) - m_forceMetal = forceMetal; - if (isModelLoaded() && m_shouldBeLoaded) { - m_reloadingToChangeVariant = true; - unloadModel(); - reloadModel(); - m_reloadingToChangeVariant = false; - } -#else - Q_UNUSED(forceMetal); -#endif 
-} - -void ChatLLM::handleDeviceChanged() -{ - if (isModelLoaded() && m_shouldBeLoaded) { - m_reloadingToChangeVariant = true; - unloadModel(); - reloadModel(); - m_reloadingToChangeVariant = false; - } -} - bool ChatLLM::loadDefaultModel() { ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); @@ -325,10 +297,9 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) // and if so we just acquire it from the store and switch the context and return true. If the // store doesn't have it or we're already loaded or in any other case just return false. - // If we're already loaded or a server or we're reloading to change the variant/device or the - // modelInfo is empty, then this should fail + // If we're already loaded or a server or the modelInfo is empty, then this should fail if ( - isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded + isModelLoaded() || m_isServer || modelInfo.name().isEmpty() || !m_shouldBeLoaded ) { emit trySwitchContextOfLoadedModelCompleted(0); return; @@ -409,7 +380,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) } // Check if the store just gave us exactly the model we were looking for - if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) { + if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) { #if defined(DEBUG_MODEL_LOADING) qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get(); #endif @@ -482,7 +453,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f); emit loadedModelInfoChanged(); - modelLoadProps.insert("requestedDevice", MySettings::globalInstance()->device()); modelLoadProps.insert("model", modelInfo.filename()); Network::globalInstance()->trackChatEvent("model_load", modelLoadProps); } else { @@ -504,43 +474,17 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro QElapsedTimer modelLoadTimer; modelLoadTimer.start(); - QString requestedDevice = MySettings::globalInstance()->device(); int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo); int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo); std::string backend = "auto"; -#ifdef Q_OS_MAC - if (requestedDevice == "CPU") { - backend = "cpu"; - } else if (m_forceMetal) { -#ifdef __aarch64__ - backend = "metal"; -#endif - } -#else // !defined(Q_OS_MAC) - if (requestedDevice.startsWith("CUDA: ")) - backend = "cuda"; -#endif - QString filePath = modelInfo.dirpath + modelInfo.filename(); - auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx](std::string const &backend) { + auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx]() { QString constructError; m_llModelInfo.resetModel(this); - try { - auto *model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx); - m_llModelInfo.resetModel(this, model); - } catch (const LLModel::MissingImplementationError &e) { - modelLoadProps.insert("error", "missing_model_impl"); - constructError = e.what(); - } catch (const LLModel::UnsupportedModelError &e) { - modelLoadProps.insert("error", "unsupported_model_file"); - constructError = e.what(); - } catch (const LLModel::BadArchError &e) { - constructError = e.what(); - modelLoadProps.insert("error", "unsupported_model_arch"); - modelLoadProps.insert("model_arch", QString::fromStdString(e.arch())); - } + auto *model = 
LLModel::Implementation::construct(filePath.toStdString(), "", n_ctx); + m_llModelInfo.resetModel(this, model); if (!m_llModelInfo.model) { if (!m_isServer) @@ -558,7 +502,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro return true; }; - if (!construct(backend)) + if (!construct()) return true; if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) { @@ -572,58 +516,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro } } - auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) { - float memGB = dev->heapSize / float(1024 * 1024 * 1024); - return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place - }; - - std::vector availableDevices; - const LLModel::GPUDevice *defaultDevice = nullptr; - { - const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx, ngl); - availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory); - // Pick the best device - // NB: relies on the fact that Kompute devices are listed first - if (!availableDevices.empty() && availableDevices.front().type == 2 /*a discrete gpu*/) { - defaultDevice = &availableDevices.front(); - float memGB = defaultDevice->heapSize / float(1024 * 1024 * 1024); - memGB = std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place - modelLoadProps.insert("default_device", QString::fromStdString(defaultDevice->name)); - modelLoadProps.insert("default_device_mem", approxDeviceMemGB(defaultDevice)); - modelLoadProps.insert("default_device_backend", QString::fromStdString(defaultDevice->backendName())); - } - } - - bool actualDeviceIsCPU = true; - -#if defined(Q_OS_MAC) && defined(__aarch64__) - if (m_llModelInfo.model->implementation().buildVariant() == "metal") - actualDeviceIsCPU = false; -#else - if (requestedDevice != "CPU") { - const auto *device = defaultDevice; - if (requestedDevice != "Auto") { - // Use the selected device - for (const LLModel::GPUDevice &d : availableDevices) { - if (QString::fromStdString(d.selectionName()) == requestedDevice) { - device = &d; - break; - } - } - } - - std::string unavail_reason; - if (!device) { - // GPU not available - } else if (!m_llModelInfo.model->initializeGPUDevice(device->index, &unavail_reason)) { - m_llModelInfo.fallbackReason = QString::fromStdString(unavail_reason); - } else { - actualDeviceIsCPU = false; - modelLoadProps.insert("requested_device_mem", approxDeviceMemGB(device)); - } - } -#endif - bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl); if (!m_shouldBeLoaded) { @@ -635,35 +527,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro return false; } - if (actualDeviceIsCPU) { - // we asked llama.cpp to use the CPU - } else if (!success) { - // llama_init_from_file returned nullptr - m_llModelInfo.fallbackReason = "GPU loading failed (out of VRAM?)"; - modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed"); - - // For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls - if (backend == "cuda" && !construct("auto")) - return true; - - success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0); - - if (!m_shouldBeLoaded) { - m_llModelInfo.resetModel(this); - if (!m_isServer) - LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); - resetModel(); - emit modelLoadingPercentageChanged(0.0f); - return false; - } - } else if (!m_llModelInfo.model->usingGPUDevice()) { - // ggml_vk_init was not called in 
llama.cpp - // We might have had to fallback to CPU after load if the model is not possible to accelerate - // for instance if the quantization method is not supported on Vulkan yet - m_llModelInfo.fallbackReason = "model or quant has no GPU support"; - modelLoadProps.insert("cpu_fallback_reason", "gpu_unsupported_model"); - } - if (!success) { m_llModelInfo.resetModel(this); if (!m_isServer) @@ -756,7 +619,7 @@ void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo) } } -static LLModel::PromptContext promptContextFromSettings(const ModelInfo &modelInfo) +static backend::GenerationParams genParamsFromSettings(const ModelInfo &modelInfo) { auto *mySettings = MySettings::globalInstance(); return { @@ -779,7 +642,7 @@ void ChatLLM::prompt(const QStringList &enabledCollections) } try { - promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo)); + promptInternalChat(enabledCollections, genParamsFromSettings(m_modelInfo)); } catch (const std::exception &e) { // FIXME(jared): this is neither translated nor serialized m_chatModel->setResponseValue(u"Error: %1"_s.arg(QString::fromUtf8(e.what()))); @@ -906,7 +769,7 @@ std::string ChatLLM::applyJinjaTemplate(std::span items) cons Q_UNREACHABLE(); } -auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx, +auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const backend::GenerationParams ¶ms, qsizetype startOffset) -> ChatPromptResult { Q_ASSERT(isModelLoaded()); @@ -944,7 +807,7 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL auto messageItems = getChat(); messageItems.pop_back(); // exclude new response - auto result = promptInternal(messageItems, ctx, !databaseResults.isEmpty()); + auto result = promptInternal(messageItems, params, !databaseResults.isEmpty()); return { /*PromptResult*/ { .response = std::move(result.response), @@ -1014,7 +877,7 @@ private: auto ChatLLM::promptInternal( const std::variant, std::string_view> &prompt, - const LLModel::PromptContext &ctx, + const backend::GenerationParams params, bool usedLocalDocs ) -> PromptResult { @@ -1052,13 +915,6 @@ auto ChatLLM::promptInternal( PromptResult result {}; - auto handlePrompt = [this, &result](std::span batch, bool cached) -> bool { - Q_UNUSED(cached) - result.promptTokens += batch.size(); - m_timer->start(); - return !m_stopGenerating; - }; - QElapsedTimer totalTime; totalTime.start(); ChatViewResponseHandler respHandler(this, &totalTime, &result); @@ -1070,8 +926,10 @@ auto ChatLLM::promptInternal( emit promptProcessing(); m_llModelInfo.model->setThreadCount(mySettings->threadCount()); m_stopGenerating = false; + // TODO: set result.promptTokens based on ollama prompt_eval_count + // TODO: support interruption via m_stopGenerating std::tie(finalBuffers, shouldExecuteTool) = promptModelWithTools( - m_llModelInfo.model.get(), handlePrompt, respHandler, ctx, + m_llModelInfo.model.get(), handlePrompt, respHandler, params, QByteArray::fromRawData(conversation.data(), conversation.size()), ToolCallConstants::AllTagNames ); @@ -1251,10 +1109,10 @@ void ChatLLM::generateName() NameResponseHandler respHandler(this); try { + // TODO: support interruption via m_stopGenerating promptModelWithTools( m_llModelInfo.model.get(), - /*promptCallback*/ [this](auto &&...) 
{ return !m_stopGenerating; }, - respHandler, promptContextFromSettings(m_modelInfo), + respHandler, genParamsFromSettings(m_modelInfo), applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(), { ToolCallConstants::ThinkTagName } ); @@ -1327,10 +1185,10 @@ void ChatLLM::generateQuestions(qint64 elapsed) QElapsedTimer totalTime; totalTime.start(); try { + // TODO: support interruption via m_stopGenerating promptModelWithTools( m_llModelInfo.model.get(), - /*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; }, - respHandler, promptContextFromSettings(m_modelInfo), + respHandler, genParamsFromSettings(m_modelInfo), applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(), { ToolCallConstants::ThinkTagName } ); diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h index c9ec4c21..46adf231 100644 --- a/gpt4all-chat/src/chatllm.h +++ b/gpt4all-chat/src/chatllm.h @@ -3,10 +3,9 @@ #include "chatmodel.h" #include "database.h" +#include "llmodel/chat.h" #include "modellist.h" -#include - #include #include #include @@ -91,14 +90,9 @@ inline LLModelTypeV1 parseLLModelTypeV0(int v0) } struct LLModelInfo { - std::unique_ptr model; + std::unique_ptr model; QFileInfo fileInfo; - std::optional fallbackReason; - - // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which - // must be able to serialize the information even if it is in the unloaded state - - void resetModel(ChatLLM *cllm, LLModel *model = nullptr); + void resetModel(ChatLLM *cllm, gpt4all::ui::ChatLLModel *model = nullptr); }; class TokenTimer : public QObject { @@ -145,9 +139,6 @@ class Chat; class ChatLLM : public QObject { Q_OBJECT - Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) - Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) - Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) public: ChatLLM(Chat *parent, bool isServer = false); virtual ~ChatLLM(); @@ -175,27 +166,6 @@ public: void acquireModel(); void resetModel(); - QString deviceBackend() const - { - if (!isModelLoaded()) return QString(); - std::string name = LLModel::GPUDevice::backendIdToName(m_llModelInfo.model->backendName()); - return QString::fromStdString(name); - } - - QString device() const - { - if (!isModelLoaded()) return QString(); - const char *name = m_llModelInfo.model->gpuDeviceName(); - return name ? 
QString(name) : u"CPU"_s; - } - - // not loaded -> QString(), no fallback -> QString("") - QString fallbackReason() const - { - if (!isModelLoaded()) return QString(); - return m_llModelInfo.fallbackReason.value_or(u""_s); - } - bool serialize(QDataStream &stream, int version); bool deserialize(QDataStream &stream, int version); @@ -211,8 +181,6 @@ public Q_SLOTS: void handleChatIdChanged(const QString &id); void handleShouldBeLoadedChanged(); void handleThreadStarted(); - void handleForceMetalChanged(bool forceMetal); - void handleDeviceChanged(); Q_SIGNALS: void loadedModelInfoChanged(); @@ -233,8 +201,6 @@ Q_SIGNALS: void trySwitchContextOfLoadedModelCompleted(int value); void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); void reportSpeed(const QString &speed); - void reportDevice(const QString &device); - void reportFallbackReason(const QString &fallbackReason); void databaseResultsChanged(const QList&); void modelInfoChanged(const ModelInfo &modelInfo); @@ -249,12 +215,11 @@ protected: QList databaseResults; }; - ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx, - qsizetype startOffset = 0); + auto promptInternalChat(const QStringList &enabledCollections, const gpt4all::backend::GenerationParams ¶ms, + qsizetype startOffset = 0) -> ChatPromptResult; // passing a string_view directly skips templating and uses the raw string - PromptResult promptInternal(const std::variant, std::string_view> &prompt, - const LLModel::PromptContext &ctx, - bool usedLocalDocs); + auto promptInternal(const std::variant, std::string_view> &prompt, + const gpt4all::backend::GenerationParams ¶ms, bool usedLocalDocs) -> PromptResult; private: bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); @@ -282,8 +247,6 @@ private: std::atomic m_forceUnloadModel; std::atomic m_markedForDeletion; bool m_isServer; - bool m_forceMetal; - bool m_reloadingToChangeVariant; friend class ChatViewResponseHandler; friend class SimpleResponseHandler; }; diff --git a/gpt4all-chat/src/embllm.cpp b/gpt4all-chat/src/embllm.cpp index 964ab7b7..e4e53820 100644 --- a/gpt4all-chat/src/embllm.cpp +++ b/gpt4all-chat/src/embllm.cpp @@ -88,73 +88,14 @@ bool EmbeddingLLMWorker::loadModel() return false; } - QString requestedDevice = MySettings::globalInstance()->localDocsEmbedDevice(); - std::string backend = "auto"; -#ifdef Q_OS_MAC - if (requestedDevice == "Auto" || requestedDevice == "CPU") - backend = "cpu"; -#else - if (requestedDevice.startsWith("CUDA: ")) - backend = "cuda"; -#endif - try { - m_model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx); + m_model = LLModel::Implementation::construct(filePath.toStdString(), "", n_ctx); } catch (const std::exception &e) { qWarning() << "embllm WARNING: Could not load embedding model:" << e.what(); return false; } - bool actualDeviceIsCPU = true; - -#if defined(Q_OS_MAC) && defined(__aarch64__) - if (m_model->implementation().buildVariant() == "metal") - actualDeviceIsCPU = false; -#else - if (requestedDevice != "CPU") { - const LLModel::GPUDevice *device = nullptr; - std::vector availableDevices = m_model->availableGPUDevices(0); - if (requestedDevice != "Auto") { - // Use the selected device - for (const LLModel::GPUDevice &d : availableDevices) { - if (QString::fromStdString(d.selectionName()) == requestedDevice) { - device = &d; - break; - } - } - } - - std::string unavail_reason; - if (!device) { - // GPU not available - 
} else if (!m_model->initializeGPUDevice(device->index, &unavail_reason)) { - qWarning().noquote() << "embllm WARNING: Did not use GPU:" << QString::fromStdString(unavail_reason); - } else { - actualDeviceIsCPU = false; - } - } -#endif - bool success = m_model->loadModel(filePath.toStdString(), n_ctx, 100); - - // CPU fallback - if (!actualDeviceIsCPU && !success) { - // llama_init_from_file returned nullptr - qWarning() << "embllm WARNING: Did not use GPU: GPU loading failed (out of VRAM?)"; - - if (backend == "cuda") { - // For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls - try { - m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto", n_ctx); - } catch (const std::exception &e) { - qWarning() << "embllm WARNING: Could not load embedding model:" << e.what(); - return false; - } - } - - success = m_model->loadModel(filePath.toStdString(), n_ctx, 0); - } - if (!success) { qWarning() << "embllm WARNING: Could not load embedding model"; delete m_model; diff --git a/gpt4all-chat/src/llmodel/chat.h b/gpt4all-chat/src/llmodel/chat.h new file mode 100644 index 00000000..e13f4065 --- /dev/null +++ b/gpt4all-chat/src/llmodel/chat.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +class QString; +namespace QCoro { template class AsyncGenerator; } +namespace gpt4all::backend { struct GenerationParams; } + + +namespace gpt4all::ui { + + +struct ChatResponseMetadata { + int nPromptTokens; + int nResponseTokens; +}; + +// TODO: implement two of these; one based on Ollama (TBD) and the other based on OpenAI (chatapi.h) +class ChatLLModel { +public: + virtual ~ChatLLModel() = 0; + + [[nodiscard]] + virtual QString name() = 0; + + virtual void preload() = 0; + virtual auto chat(QStringView prompt, const backend::GenerationParams ¶ms, + /*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator = 0; +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/llmodel/openai.cpp b/gpt4all-chat/src/llmodel/openai.cpp new file mode 100644 index 00000000..32a0b58f --- /dev/null +++ b/gpt4all-chat/src/llmodel/openai.cpp @@ -0,0 +1,217 @@ +#include "openai.h" + +#include "mysettings.h" +#include "utils.h" + +#include // IWYU pragma: keep +#include // IWYU pragma: keep +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: keep +#include +#include +#include + +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + +//#define DEBUG + + +static auto processRespLine(const QByteArray &line) -> std::optional +{ + auto jsonData = line.trimmed(); + if (jsonData.startsWith("data:"_ba)) + jsonData.remove(0, 5); + jsonData = jsonData.trimmed(); + if (jsonData.isEmpty()) + return std::nullopt; + if (jsonData == "[DONE]") + return std::nullopt; + + QJsonParseError err; + auto document = QJsonDocument::fromJson(jsonData, &err); + if (document.isNull()) + throw std::runtime_error(fmt::format("OpenAI chat response parsing failed: {}", err.errorString())); + + auto root = document.object(); + auto choices = root.value("choices").toArray(); + auto choice = choices.first().toObject(); + auto delta = choice.value("delta").toObject(); + return delta.value("content").toString(); +} + + +namespace gpt4all::ui { + + +void OpenaiModelDescription::setDisplayName(QString value) +{ + if (m_displayName != value) { + m_displayName = std::move(value); + emit displayNameChanged(m_displayName); + } +} + +void 
OpenaiModelDescription::setModelName(QString value) +{ + if (m_modelName != value) { + m_modelName = std::move(value); + emit modelNameChanged(m_modelName); + } +} + +OpenaiLLModel::OpenaiLLModel(OpenaiConnectionDetails connDetails, QNetworkAccessManager *nam) + : m_connDetails(std::move(connDetails)) + , m_nam(nam) + {} + +static auto parsePrompt(QXmlStreamReader &xml) -> std::expected +{ + QJsonArray messages; + + auto xmlError = [&xml] { + return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString())); + }; + + if (xml.hasError()) + return xmlError(); + if (xml.atEnd()) + return messages; + + // skip header + bool foundElement = false; + do { + switch (xml.readNext()) { + using enum QXmlStreamReader::TokenType; + case Invalid: + return xmlError(); + case EndDocument: + return messages; + default: + foundElement = true; + case StartDocument: + case Comment: + case DTD: + case ProcessingInstruction: + ; + } + } while (!foundElement); + + // document body loop + bool foundRoot = false; + for (;;) { + switch (xml.tokenType()) { + using enum QXmlStreamReader::TokenType; + case StartElement: + { + auto name = xml.name(); + if (!foundRoot) { + if (name != "chat"_L1) + return std::unexpected(u"unexpected tag: %1"_s.arg(name)); + foundRoot = true; + } else { + if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1) + return std::unexpected(u"unknown role: %1"_s.arg(name)); + auto content = xml.readElementText(); + if (xml.tokenType() != EndElement) + return xmlError(); + messages << makeJsonObject({ + { "role"_L1, name.toString().trimmed() }, + { "content"_L1, content }, + }); + } + break; + } + case Characters: + if (!xml.isWhitespace()) + return std::unexpected(u"unexpected text: %1"_s.arg(xml.text())); + case Comment: + case ProcessingInstruction: + case EndElement: + break; + case EndDocument: + return messages; + case Invalid: + return xmlError(); + default: + return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString())); + } + xml.readNext(); + } +} + +auto OpenaiLLModel::chat(QStringView prompt, const backend::GenerationParams ¶ms, + /*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator +{ + auto *mySettings = MySettings::globalInstance(); + + if (!params.n_predict) + co_return; // nothing requested + + auto reqBody = makeJsonObject({ + { "model"_L1, m_connDetails.modelName }, + { "max_completion_tokens"_L1, qint64(params.n_predict) }, + { "stream"_L1, true }, + { "temperature"_L1, params.temperature }, + { "top_p"_L1, params.top_p }, + }); + + // conversation history + { + QXmlStreamReader xml(prompt); + auto messages = parsePrompt(xml); + if (!messages) + throw std::invalid_argument(fmt::format("Failed to parse OpenAI prompt: {}", messages.error())); + reqBody.insert("messages"_L1, *messages); + } + + QNetworkRequest request(m_connDetails.baseUrl.resolved(QUrl("/v1/chat/completions"))); + request.setHeader(QNetworkRequest::UserAgentHeader, mySettings->userAgent()); + request.setRawHeader("authorization", u"Bearer %1"_s.arg(m_connDetails.apiKey).toUtf8()); + + QRestAccessManager restNam(m_nam); + std::unique_ptr reply(restNam.post(request, QJsonDocument(reqBody))); + + auto makeError = [](const QRestReply &reply) { + return std::runtime_error(fmt::format("OpenAI chat request failed: {}", backend::restErrorString(reply))); + }; + + QRestReply restReply(reply.get()); + if (reply->error()) + throw makeError(restReply); + + auto coroReply = qCoro(reply.get()); + for (;;) { + auto line = co_await 
coroReply.readLine(); + if (!restReply.isSuccess()) + throw makeError(restReply); + if (line.isEmpty()) { + Q_ASSERT(reply->atEnd()); + break; + } + if (auto chunk = processRespLine(line)) + co_yield *chunk; + } +} + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/llmodel/openai.h b/gpt4all-chat/src/llmodel/openai.h new file mode 100644 index 00000000..3046a5c0 --- /dev/null +++ b/gpt4all-chat/src/llmodel/openai.h @@ -0,0 +1,75 @@ +#pragma once + +#include "chat.h" +#include "provider.h" + +#include +#include +#include +#include + +class QNetworkAccessManager; + + +namespace gpt4all::ui { + + +class OpenaiModelDescription : public QObject { + Q_OBJECT + QML_ELEMENT + +public: + explicit OpenaiModelDescription(OpenaiProvider *provider, QString displayName, QString modelName) + : QObject(provider) + , m_provider(provider) + , m_displayName(std::move(displayName)) + , m_modelName(std::move(modelName)) + {} + + // getters + [[nodiscard]] OpenaiProvider *provider () const { return m_provider; } + [[nodiscard]] const QString &displayName() const { return m_displayName; } + [[nodiscard]] const QString &modelName () const { return m_modelName; } + + // setters + void setDisplayName(QString value); + void setModelName (QString value); + +Q_SIGNALS: + void displayNameChanged(const QString &value); + void modelNameChanged (const QString &value); + +private: + OpenaiProvider *m_provider; + QString m_displayName; + QString m_modelName; +}; + +struct OpenaiConnectionDetails { + QUrl baseUrl; + QString modelName; + QString apiKey; + + OpenaiConnectionDetails(const OpenaiModelDescription *desc) + : baseUrl(desc->provider()->baseUrl()) + , apiKey(desc->provider()->apiKey()) + , modelName(desc->modelName()) + {} +}; + +class OpenaiLLModel : public ChatLLModel { +public: + explicit OpenaiLLModel(OpenaiConnectionDetails connDetails, QNetworkAccessManager *nam); + + void preload() override { /* not supported -> no-op */ } + + auto chat(QStringView prompt, const backend::GenerationParams ¶ms, /*out*/ ChatResponseMetadata &metadata) + -> QCoro::AsyncGenerator override; + +private: + OpenaiConnectionDetails m_connDetails; + QNetworkAccessManager *m_nam; +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/llmodel/provider.cpp b/gpt4all-chat/src/llmodel/provider.cpp new file mode 100644 index 00000000..779925a1 --- /dev/null +++ b/gpt4all-chat/src/llmodel/provider.cpp @@ -0,0 +1,26 @@ +#include "provider.h" + +#include + + +namespace gpt4all::ui { + + +void OpenaiProvider::setBaseUrl(QUrl value) +{ + if (m_baseUrl != value) { + m_baseUrl = std::move(value); + emit baseUrlChanged(m_baseUrl); + } +} + +void OpenaiProvider::setApiKey(QString value) +{ + if (m_apiKey != value) { + m_apiKey = std::move(value); + emit apiKeyChanged(m_apiKey); + } +} + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/llmodel/provider.h b/gpt4all-chat/src/llmodel/provider.h new file mode 100644 index 00000000..7189e726 --- /dev/null +++ b/gpt4all-chat/src/llmodel/provider.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include +#include +#include + + +namespace gpt4all::ui { + + +class ModelProvider : public QObject { + Q_OBJECT + + Q_PROPERTY(QString name READ name CONSTANT) + +public: + [[nodiscard]] virtual QString name() = 0; +}; + +class OpenaiProvider : public ModelProvider { + Q_OBJECT + QML_ELEMENT + + Q_PROPERTY(QUrl baseUrl READ baseUrl WRITE setBaseUrl NOTIFY baseUrlChanged) + Q_PROPERTY(QString apiKey READ apiKey WRITE setApiKey NOTIFY apiKeyChanged) + +public: + [[nodiscard]] QString name() 
override { return m_name; } + [[nodiscard]] const QUrl &baseUrl() { return m_baseUrl; } + [[nodiscard]] const QString &apiKey () { return m_apiKey; } + + void setBaseUrl(QUrl value); + void setApiKey (QString value); + +Q_SIGNALS: + void baseUrlChanged(const QUrl &value); + void apiKeyChanged (const QString &value); + +private: + QString m_name; + QUrl m_baseUrl; + QString m_apiKey; +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/mysettings.cpp b/gpt4all-chat/src/mysettings.cpp index 4bc1595f..06a77f26 100644 --- a/gpt4all-chat/src/mysettings.cpp +++ b/gpt4all-chat/src/mysettings.cpp @@ -1,6 +1,7 @@ #include "mysettings.h" #include "chatllm.h" +#include "config.h" #include "modellist.h" #include @@ -48,7 +49,6 @@ namespace ModelSettingsKey { namespace { namespace defaults { static const int threadCount = std::min(4, (int32_t) std::thread::hardware_concurrency()); -static const bool forceMetal = false; static const bool networkIsActive = false; static const bool networkUsageStatsActive = false; static const QString device = "Auto"; @@ -71,7 +71,6 @@ static const QVariantMap basicDefaults { { "localdocs/fileExtensions", QStringList { "docx", "pdf", "txt", "md", "rst" } }, { "localdocs/useRemoteEmbed", false }, { "localdocs/nomicAPIKey", "" }, - { "localdocs/embedDevice", "Auto" }, { "network/attribution", "" }, }; @@ -174,11 +173,16 @@ MySettings *MySettings::globalInstance() MySettings::MySettings() : QObject(nullptr) , m_deviceList(getDevices()) - , m_embeddingsDeviceList(getDevices(/*skipKompute*/ true)) , m_uiLanguages(getUiLanguages(modelPath())) { } +const QString &MySettings::userAgent() +{ + static const QString s_userAgent = QStringLiteral("gpt4all/" APP_VERSION); + return s_userAgent; +} + QVariant MySettings::checkJinjaTemplateError(const QString &tmpl) { if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString())) @@ -256,7 +260,6 @@ void MySettings::restoreApplicationDefaults() setNetworkPort(basicDefaults.value("networkPort").toInt()); setModelPath(defaultLocalModelsPath()); setUserDefaultModel(basicDefaults.value("userDefaultModel").toString()); - setForceMetal(defaults::forceMetal); setSuggestionMode(basicDefaults.value("suggestionMode").value()); setLanguageAndLocale(defaults::languageAndLocale); } @@ -269,7 +272,6 @@ void MySettings::restoreLocalDocsDefaults() setLocalDocsFileExtensions(basicDefaults.value("localdocs/fileExtensions").toStringList()); setLocalDocsUseRemoteEmbed(basicDefaults.value("localdocs/useRemoteEmbed").toBool()); setLocalDocsNomicAPIKey(basicDefaults.value("localdocs/nomicAPIKey").toString()); - setLocalDocsEmbedDevice(basicDefaults.value("localdocs/embedDevice").toString()); } void MySettings::eraseModel(const ModelInfo &info) @@ -628,7 +630,6 @@ bool MySettings::localDocsShowReferences() const { return getBasicSetting QStringList MySettings::localDocsFileExtensions() const { return getBasicSetting("localdocs/fileExtensions").toStringList(); } bool MySettings::localDocsUseRemoteEmbed() const { return getBasicSetting("localdocs/useRemoteEmbed").toBool(); } QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting("localdocs/nomicAPIKey" ).toString(); } -QString MySettings::localDocsEmbedDevice() const { return getBasicSetting("localdocs/embedDevice" ).toString(); } QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); } ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); } @@ -646,7 +647,6 @@ void 
diff --git a/gpt4all-chat/src/mysettings.cpp b/gpt4all-chat/src/mysettings.cpp
index 4bc1595f..06a77f26 100644
--- a/gpt4all-chat/src/mysettings.cpp
+++ b/gpt4all-chat/src/mysettings.cpp
@@ -1,6 +1,7 @@
 #include "mysettings.h"
 
 #include "chatllm.h"
+#include "config.h"
 #include "modellist.h"
 
 #include
@@ -48,7 +49,6 @@ namespace ModelSettingsKey { namespace {
 namespace defaults {
 
 static const int threadCount = std::min(4, (int32_t) std::thread::hardware_concurrency());
-static const bool forceMetal = false;
 static const bool networkIsActive = false;
 static const bool networkUsageStatsActive = false;
 static const QString device = "Auto";
@@ -71,7 +71,6 @@ static const QVariantMap basicDefaults {
     { "localdocs/fileExtensions", QStringList { "docx", "pdf", "txt", "md", "rst" } },
     { "localdocs/useRemoteEmbed", false },
     { "localdocs/nomicAPIKey", "" },
-    { "localdocs/embedDevice", "Auto" },
     { "network/attribution", "" },
 };
 
@@ -174,11 +173,16 @@ MySettings *MySettings::globalInstance()
 MySettings::MySettings()
     : QObject(nullptr)
     , m_deviceList(getDevices())
-    , m_embeddingsDeviceList(getDevices(/*skipKompute*/ true))
     , m_uiLanguages(getUiLanguages(modelPath()))
 {
 }
 
+const QString &MySettings::userAgent()
+{
+    static const QString s_userAgent = QStringLiteral("gpt4all/" APP_VERSION);
+    return s_userAgent;
+}
+
 QVariant MySettings::checkJinjaTemplateError(const QString &tmpl)
 {
     if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString()))
@@ -256,7 +260,6 @@ void MySettings::restoreApplicationDefaults()
     setNetworkPort(basicDefaults.value("networkPort").toInt());
     setModelPath(defaultLocalModelsPath());
     setUserDefaultModel(basicDefaults.value("userDefaultModel").toString());
-    setForceMetal(defaults::forceMetal);
     setSuggestionMode(basicDefaults.value("suggestionMode").value());
     setLanguageAndLocale(defaults::languageAndLocale);
 }
@@ -269,7 +272,6 @@ void MySettings::restoreLocalDocsDefaults()
     setLocalDocsFileExtensions(basicDefaults.value("localdocs/fileExtensions").toStringList());
     setLocalDocsUseRemoteEmbed(basicDefaults.value("localdocs/useRemoteEmbed").toBool());
     setLocalDocsNomicAPIKey(basicDefaults.value("localdocs/nomicAPIKey").toString());
-    setLocalDocsEmbedDevice(basicDefaults.value("localdocs/embedDevice").toString());
 }
 
 void MySettings::eraseModel(const ModelInfo &info)
@@ -628,7 +630,6 @@ bool MySettings::localDocsShowReferences() const { return getBasicSetting
 QStringList MySettings::localDocsFileExtensions() const { return getBasicSetting("localdocs/fileExtensions").toStringList(); }
 bool MySettings::localDocsUseRemoteEmbed() const { return getBasicSetting("localdocs/useRemoteEmbed").toBool(); }
 QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting("localdocs/nomicAPIKey" ).toString(); }
-QString MySettings::localDocsEmbedDevice() const { return getBasicSetting("localdocs/embedDevice" ).toString(); }
 QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); }
 
 ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); }
@@ -646,7 +647,6 @@ void MySettings::setLocalDocsShowReferences(bool value) { setBasic
 void MySettings::setLocalDocsFileExtensions(const QStringList &value) { setBasicSetting("localdocs/fileExtensions", value, "localDocsFileExtensions"); }
 void MySettings::setLocalDocsUseRemoteEmbed(bool value) { setBasicSetting("localdocs/useRemoteEmbed", value, "localDocsUseRemoteEmbed"); }
 void MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasicSetting("localdocs/nomicAPIKey", value, "localDocsNomicAPIKey"); }
-void MySettings::setLocalDocsEmbedDevice(const QString &value) { setBasicSetting("localdocs/embedDevice", value, "localDocsEmbedDevice"); }
 void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); }
 
 void MySettings::setChatTheme(ChatTheme value) { setBasicSetting("chatTheme", chatThemeNames .value(int(value))); }
@@ -706,19 +706,6 @@ void MySettings::setDevice(const QString &value)
     }
 }
 
-bool MySettings::forceMetal() const
-{
-    return m_forceMetal;
-}
-
-void MySettings::setForceMetal(bool value)
-{
-    if (m_forceMetal != value) {
-        m_forceMetal = value;
-        emit forceMetalChanged(value);
-    }
-}
-
 bool MySettings::networkIsActive() const
 {
     return m_settings.value("network/isActive", defaults::networkIsActive).toBool();
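
[Reviewer sketch: not part of the patch] A plausible consumer of the new MySettings::userAgent() helper when constructing the backend's OllamaClient. The factory function and base URL below are hypothetical.

#include "mysettings.h"

#include <gpt4all-backend/ollama-client.h>

#include <QUrl>

#include <memory>

// Hypothetical factory: construct a backend client identified as this application.
std::unique_ptr<gpt4all::backend::OllamaClient> makeLocalOllamaClient()
{
    return std::make_unique<gpt4all::backend::OllamaClient>(
        QUrl(QStringLiteral("http://localhost:11434/")), // placeholder endpoint
        MySettings::userAgent()                          // "gpt4all/<version>"
    );
}
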
diff --git a/gpt4all-chat/src/mysettings.h b/gpt4all-chat/src/mysettings.h
index 59cf1fa0..ffb7113a 100644
--- a/gpt4all-chat/src/mysettings.h
+++ b/gpt4all-chat/src/mysettings.h
@@ -62,7 +62,6 @@ class MySettings : public QObject
     Q_PROPERTY(ChatTheme chatTheme READ chatTheme WRITE setChatTheme NOTIFY chatThemeChanged)
     Q_PROPERTY(FontSize fontSize READ fontSize WRITE setFontSize NOTIFY fontSizeChanged)
     Q_PROPERTY(QString languageAndLocale READ languageAndLocale WRITE setLanguageAndLocale NOTIFY languageAndLocaleChanged)
-    Q_PROPERTY(bool forceMetal READ forceMetal WRITE setForceMetal NOTIFY forceMetalChanged)
     Q_PROPERTY(QString lastVersionStarted READ lastVersionStarted WRITE setLastVersionStarted NOTIFY lastVersionStartedChanged)
     Q_PROPERTY(int localDocsChunkSize READ localDocsChunkSize WRITE setLocalDocsChunkSize NOTIFY localDocsChunkSizeChanged)
     Q_PROPERTY(int localDocsRetrievalSize READ localDocsRetrievalSize WRITE setLocalDocsRetrievalSize NOTIFY localDocsRetrievalSizeChanged)
@@ -70,13 +69,11 @@
     Q_PROPERTY(QStringList localDocsFileExtensions READ localDocsFileExtensions WRITE setLocalDocsFileExtensions NOTIFY localDocsFileExtensionsChanged)
     Q_PROPERTY(bool localDocsUseRemoteEmbed READ localDocsUseRemoteEmbed WRITE setLocalDocsUseRemoteEmbed NOTIFY localDocsUseRemoteEmbedChanged)
     Q_PROPERTY(QString localDocsNomicAPIKey READ localDocsNomicAPIKey WRITE setLocalDocsNomicAPIKey NOTIFY localDocsNomicAPIKeyChanged)
-    Q_PROPERTY(QString localDocsEmbedDevice READ localDocsEmbedDevice WRITE setLocalDocsEmbedDevice NOTIFY localDocsEmbedDeviceChanged)
     Q_PROPERTY(QString networkAttribution READ networkAttribution WRITE setNetworkAttribution NOTIFY networkAttributionChanged)
     Q_PROPERTY(bool networkIsActive READ networkIsActive WRITE setNetworkIsActive NOTIFY networkIsActiveChanged)
     Q_PROPERTY(bool networkUsageStatsActive READ networkUsageStatsActive WRITE setNetworkUsageStatsActive NOTIFY networkUsageStatsActiveChanged)
     Q_PROPERTY(QString device READ device WRITE setDevice NOTIFY deviceChanged)
     Q_PROPERTY(QStringList deviceList MEMBER m_deviceList CONSTANT)
-    Q_PROPERTY(QStringList embeddingsDeviceList MEMBER m_embeddingsDeviceList CONSTANT)
     Q_PROPERTY(int networkPort READ networkPort WRITE setNetworkPort NOTIFY networkPortChanged)
     Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
     Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
@@ -91,6 +88,8 @@ public Q_SLOTS:
 public:
     static MySettings *globalInstance();
 
+    static const QString &userAgent();
+
     Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);
 
     // Restore methods
@@ -172,8 +171,6 @@ public:
     void setChatTheme(ChatTheme value);
     FontSize fontSize() const;
     void setFontSize(FontSize value);
-    bool forceMetal() const;
-    void setForceMetal(bool value);
     QString device();
     void setDevice(const QString &value);
     int32_t contextLength() const;
@@ -203,8 +200,6 @@ public:
     void setLocalDocsUseRemoteEmbed(bool value);
     QString localDocsNomicAPIKey() const;
     void setLocalDocsNomicAPIKey(const QString &value);
-    QString localDocsEmbedDevice() const;
-    void setLocalDocsEmbedDevice(const QString &value);
 
     // Network settings
     QString networkAttribution() const;
@@ -243,7 +238,6 @@ Q_SIGNALS:
     void userDefaultModelChanged();
     void chatThemeChanged();
     void fontSizeChanged();
-    void forceMetalChanged(bool);
     void lastVersionStartedChanged();
     void localDocsChunkSizeChanged();
     void localDocsRetrievalSizeChanged();
@@ -251,7 +245,6 @@ Q_SIGNALS:
     void localDocsFileExtensionsChanged();
     void localDocsUseRemoteEmbedChanged();
     void localDocsNomicAPIKeyChanged();
-    void localDocsEmbedDeviceChanged();
     void networkAttributionChanged();
     void networkIsActiveChanged();
     void networkPortChanged();
@@ -287,9 +280,7 @@ private:
 private:
     QSettings m_settings;
 
-    bool m_forceMetal;
     const QStringList m_deviceList;
-    const QStringList m_embeddingsDeviceList;
     const QStringList m_uiLanguages;
 
     std::unique_ptr m_translator;
diff --git a/gpt4all-chat/src/network.cpp b/gpt4all-chat/src/network.cpp
index e0c04f1d..015115b9 100644
--- a/gpt4all-chat/src/network.cpp
+++ b/gpt4all-chat/src/network.cpp
@@ -372,8 +372,6 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props)
     Q_ASSERT(curChat);
     if (!props.contains("model"))
         props.insert("model", curChat->modelInfo().filename());
-    props.insert("device_backend", curChat->deviceBackend());
-    props.insert("actualDevice", curChat->device());
     props.insert("doc_collections_enabled", curChat->collectionList().count());
     props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
     props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());
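
[Reviewer sketch: not part of the patch] In the server.cpp hunks below, max_tokens becomes unsigned and the request parser clamps the JSON-supplied integer before storing it. The same clamp in isolation, for reference; the function name is hypothetical.

#include <QtGlobal>

#include <cstdint>

// Clamp a JSON-sourced 64-bit integer (already validated to be >= 1 by the
// request parser) into an unsigned 32-bit token budget.
uint clampMaxTokens(qint64 requested)
{
    return uint(qMin<qint64>(requested, UINT32_MAX));
}

Spelling the template argument as qMin<qint64> keeps both operands signed before the final narrowing cast.
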
diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp
index 940b5895..45db49f0 100644
--- a/gpt4all-chat/src/server.cpp
+++ b/gpt4all-chat/src/server.cpp
@@ -7,6 +7,7 @@
 #include
 #include
+#include
 #include
 
 #include
 
@@ -126,7 +127,7 @@ class BaseCompletionRequest {
 public:
     QString model; // required
     // NB: some parameters are not supported yet
-    int32_t max_tokens = 16;
+    uint max_tokens = 16;
     qint64 n = 1;
     float temperature = 1.f;
     float top_p = 1.f;
@@ -161,7 +162,7 @@ protected:
         value = reqValue("max_tokens", Integer, false, /*min*/ 1);
         if (!value.isNull())
-            this->max_tokens = int32_t(qMin(value.toInteger(), INT32_MAX));
+            this->max_tokens = uint(qMin(value.toInteger(), UINT32_MAX));
 
         value = reqValue("n", Integer, false, /*min*/ 1);
         if (!value.isNull())
@@ -666,7 +667,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
     m_chatModel->appendResponse();
 
     // FIXME(jared): taking parameters from the UI inhibits reproducibility of results
-    LLModel::PromptContext promptCtx {
+    backend::GenerationParams genParams {
         .n_predict = request.max_tokens,
         .top_k = mySettings->modelTopK(modelInfo),
         .top_p = request.top_p,
@@ -685,7 +686,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
         PromptResult result;
         try {
             result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
-                                    promptCtx,
+                                    genParams,
                                     /*usedLocalDocs*/ false);
         } catch (const std::exception &e) {
             m_chatModel->setResponseValue(e.what());
@@ -779,7 +780,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
     auto startOffset = m_chatModel->appendResponseWithHistory(messages);
 
     // FIXME(jared): taking parameters from the UI inhibits reproducibility of results
-    LLModel::PromptContext promptCtx {
+    backend::GenerationParams genParams {
         .n_predict = request.max_tokens,
         .top_k = mySettings->modelTopK(modelInfo),
         .top_p = request.top_p,
@@ -796,7 +797,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
     for (int i = 0; i < request.n; ++i) {
         ChatPromptResult result;
         try {
-            result = promptInternalChat(m_collections, promptCtx, startOffset);
+            result = promptInternalChat(m_collections, genParams, startOffset);
         } catch (const std::exception &e) {
             m_chatModel->setResponseValue(e.what());
             m_chatModel->setError();