From 9772027e5edc3187c9d1ffd817f3c218f3fe2535 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Wed, 19 Mar 2025 10:49:39 -0400 Subject: [PATCH] WIP: provider page in the "add models" view --- .../include/gpt4all-backend/ollama-client.h | 35 ++-- .../include/gpt4all-backend/ollama-types.h | 2 +- gpt4all-backend/src/ollama-client.cpp | 33 +-- gpt4all-backend/src/rest.cpp | 6 +- gpt4all-chat/CMakeLists.txt | 6 +- gpt4all-chat/deps/CMakeLists.txt | 1 - gpt4all-chat/qml/AddCustomProviderView.qml | 57 ++++++ gpt4all-chat/qml/AddModelView.qml | 18 +- gpt4all-chat/qml/AddRemoteModelView.qml | 24 ++- gpt4all-chat/qml/CustomProviderCard.qml | 193 ++++++++++++++++++ gpt4all-chat/qml/RemoteModelCard.qml | 101 ++++----- gpt4all-chat/src/creatable.h | 18 ++ gpt4all-chat/src/llmodel_description.h | 2 + gpt4all-chat/src/llmodel_ollama.cpp | 71 ++++++- gpt4all-chat/src/llmodel_ollama.h | 50 +++-- gpt4all-chat/src/llmodel_openai.cpp | 83 +++++--- gpt4all-chat/src/llmodel_openai.h | 55 +++-- gpt4all-chat/src/llmodel_provider.cpp | 179 +++++++++++----- gpt4all-chat/src/llmodel_provider.h | 139 +++++++++---- gpt4all-chat/src/llmodel_provider.inl | 42 +++- gpt4all-chat/src/main.cpp | 2 + gpt4all-chat/src/mysettings.cpp | 4 +- gpt4all-chat/src/mysettings.h | 2 +- gpt4all-chat/src/qmlfunctions.cpp | 41 ++++ gpt4all-chat/src/qmlfunctions.h | 31 +++ gpt4all-chat/src/qmlsharedptr.cpp | 14 ++ gpt4all-chat/src/qmlsharedptr.h | 25 +++ gpt4all-chat/src/store_base.cpp | 52 +++-- gpt4all-chat/src/store_base.h | 7 +- gpt4all-chat/src/store_base.inl | 19 +- gpt4all-chat/src/store_provider.cpp | 20 +- gpt4all-chat/src/store_provider.h | 5 +- 32 files changed, 998 insertions(+), 339 deletions(-) create mode 100644 gpt4all-chat/qml/AddCustomProviderView.qml create mode 100644 gpt4all-chat/qml/CustomProviderCard.qml create mode 100644 gpt4all-chat/src/creatable.h create mode 100644 gpt4all-chat/src/qmlfunctions.cpp create mode 100644 gpt4all-chat/src/qmlfunctions.h create mode 100644 gpt4all-chat/src/qmlsharedptr.cpp create mode 100644 gpt4all-chat/src/qmlsharedptr.h diff --git a/gpt4all-backend/include/gpt4all-backend/ollama-client.h b/gpt4all-backend/include/gpt4all-backend/ollama-client.h index c891b64f..bcd22d43 100644 --- a/gpt4all-backend/include/gpt4all-backend/ollama-client.h +++ b/gpt4all-backend/include/gpt4all-backend/ollama-client.h @@ -12,6 +12,8 @@ #include #include +#include +#include #include #include @@ -24,7 +26,7 @@ namespace gpt4all::backend { struct ResponseError { public: - struct BadStatus { int code; }; + struct BadStatus { int code; std::optional reason; }; using ErrorCode = std::variant< QNetworkReply::NetworkError, boost::system::error_code, @@ -34,8 +36,8 @@ public: ResponseError(const QRestReply *reply); ResponseError(const boost::system::system_error &e); - const ErrorCode &error () { return m_error; } - const QString &errorString() { return m_errorString; } + const ErrorCode &error () const { return m_error; } + const QString &errorString() const { return m_errorString; } private: ErrorCode m_error; @@ -47,9 +49,10 @@ using DataOrRespErr = std::expected; class OllamaClient { public: - OllamaClient(QUrl baseUrl, QString m_userAgent) - : m_baseUrl(baseUrl) + OllamaClient(QUrl baseUrl, QString m_userAgent, QNetworkAccessManager *nam) + : m_baseUrl(std::move(baseUrl)) , m_userAgent(std::move(m_userAgent)) + , m_nam(nam) {} const QUrl &baseUrl() const { return m_baseUrl; } @@ -70,28 +73,28 @@ public: private: QNetworkRequest makeRequest(const QString &path) const; - auto processResponse(QNetworkReply &reply) -> QCoro::Task>; + auto processResponse(std::unique_ptr reply) -> QCoro::Task>; template - auto get(const QString &path) -> QCoro::Task>; + auto get(QString path) -> QCoro::Task>; template - auto post(const QString &path, Req const &body) -> QCoro::Task>; + auto post(QString path, Req const &body) -> QCoro::Task>; - auto getJson(const QString &path) -> QCoro::Task>; + auto getJson(QString path) -> QCoro::Task>; auto postJson(const QString &path, const boost::json::value &body) -> QCoro::Task>; private: - QUrl m_baseUrl; - QString m_userAgent; - QNetworkAccessManager m_nam; - boost::json::stream_parser m_parser; + QUrl m_baseUrl; + QString m_userAgent; + QNetworkAccessManager *m_nam; + boost::json::stream_parser m_parser; }; -extern template auto OllamaClient::get(const QString &) -> QCoro::Task>; -extern template auto OllamaClient::get(const QString &) -> QCoro::Task>; +extern template auto OllamaClient::get(QString) -> QCoro::Task>; +extern template auto OllamaClient::get(QString) -> QCoro::Task>; -extern template auto OllamaClient::post(const QString &, const ollama::ShowRequest &) +extern template auto OllamaClient::post(QString, const ollama::ShowRequest &) -> QCoro::Task>; diff --git a/gpt4all-backend/include/gpt4all-backend/ollama-types.h b/gpt4all-backend/include/gpt4all-backend/ollama-types.h index 4cc3c2e9..e63ca736 100644 --- a/gpt4all-backend/include/gpt4all-backend/ollama-types.h +++ b/gpt4all-backend/include/gpt4all-backend/ollama-types.h @@ -51,7 +51,7 @@ struct ListModelResponse { QString digest; std::optional details; }; -BOOST_DESCRIBE_STRUCT(ListModelResponse, (), (model, modified_at, size, digest, details)) +BOOST_DESCRIBE_STRUCT(ListModelResponse, (), (name, model, modified_at, size, digest, details)) using ToolCallFunctionArguments = boost::json::object; diff --git a/gpt4all-backend/src/ollama-client.cpp b/gpt4all-backend/src/ollama-client.cpp index d32d2ae6..b183bb69 100644 --- a/gpt4all-backend/src/ollama-client.cpp +++ b/gpt4all-backend/src/ollama-client.cpp @@ -29,7 +29,8 @@ ResponseError::ResponseError(const QRestReply *reply) if (reply->hasError()) { m_error = reply->networkReply()->error(); } else if (!reply->isHttpStatusSuccess()) { - m_error = BadStatus(reply->httpStatus()); + auto reason = reply->networkReply()->attribute(QNetworkRequest::HttpReasonPhraseAttribute).toString(); + m_error = BadStatus(reply->httpStatus(), reason.isEmpty() ? std::nullopt : std::optional(reason)); } else Q_UNREACHABLE(); @@ -50,19 +51,19 @@ QNetworkRequest OllamaClient::makeRequest(const QString &path) const return req; } -auto OllamaClient::processResponse(QNetworkReply &reply) -> QCoro::Task> +auto OllamaClient::processResponse(std::unique_ptr reply) -> QCoro::Task> { - QRestReply restReply(&reply); - if (reply.error()) + QRestReply restReply(reply.get()); + if (reply->error()) co_return std::unexpected(&restReply); - auto coroReply = qCoro(reply); + auto coroReply = qCoro(*reply); for (;;) { auto chunk = co_await coroReply.readAll(); if (!restReply.isSuccess()) co_return std::unexpected(&restReply); if (chunk.isEmpty()) { - Q_ASSERT(reply.atEnd()); + Q_ASSERT(reply->atEnd()); break; } m_parser.write(chunk.data(), chunk.size()); @@ -73,7 +74,7 @@ auto OllamaClient::processResponse(QNetworkReply &reply) -> QCoro::Task -auto OllamaClient::get(const QString &path) -> QCoro::Task> +auto OllamaClient::get(QString path) -> QCoro::Task> { // get() should not throw exceptions try { @@ -86,11 +87,11 @@ auto OllamaClient::get(const QString &path) -> QCoro::Task> } } -template auto OllamaClient::get(const QString &) -> QCoro::Task>; -template auto OllamaClient::get(const QString &) -> QCoro::Task>; +template auto OllamaClient::get(QString) -> QCoro::Task>; +template auto OllamaClient::get(QString) -> QCoro::Task>; template -auto OllamaClient::post(const QString &path, const Req &body) -> QCoro::Task> +auto OllamaClient::post(QString path, const Req &body) -> QCoro::Task> { // post() should not throw exceptions try { @@ -104,12 +105,12 @@ auto OllamaClient::post(const QString &path, const Req &body) -> QCoro::Task QCoro::Task>; +template auto OllamaClient::post(QString, const ShowRequest &) -> QCoro::Task>; -auto OllamaClient::getJson(const QString &path) -> QCoro::Task> +auto OllamaClient::getJson(QString path) -> QCoro::Task> { - std::unique_ptr reply(m_nam.get(makeRequest(path))); - co_return co_await processResponse(*reply); + std::unique_ptr reply(m_nam->get(makeRequest(path))); + return processResponse(std::move(reply)); } auto OllamaClient::postJson(const QString &path, const json::value &body) -> QCoro::Task> @@ -117,8 +118,8 @@ auto OllamaClient::postJson(const QString &path, const json::value &body) -> QCo JsonStreamDevice stream(&body); auto req = makeRequest(path); req.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"_ba); - std::unique_ptr reply(m_nam.post(req, &stream)); - co_return co_await processResponse(*reply); + std::unique_ptr reply(m_nam->post(req, &stream)); + co_return co_await processResponse(std::move(reply)); } diff --git a/gpt4all-backend/src/rest.cpp b/gpt4all-backend/src/rest.cpp index 974a2bb0..d577777f 100644 --- a/gpt4all-backend/src/rest.cpp +++ b/gpt4all-backend/src/rest.cpp @@ -18,11 +18,11 @@ QString restErrorString(const QRestReply &reply) if (!reply.isHttpStatusSuccess()) { auto code = reply.httpStatus(); - auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute); + auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute).toString(); return u"HTTP %1%2%3 for URL \"%4\""_s.arg( QString::number(code), - reason.isValid() ? u" "_s : QString(), - reason.toString(), + reason.isEmpty() ? QString() : u" "_s, + reason, nr->request().url().toString() ); } diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 36c0a033..15a8160b 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -253,6 +253,8 @@ qt_add_executable(chat src/modellist.cpp src/modellist.h src/mysettings.cpp src/mysettings.h src/network.cpp src/network.h + src/qmlfunctions.cpp src/qmlfunctions.h + src/qmlsharedptr.cpp src/qmlsharedptr.h src/server.cpp src/server.h src/store_base.cpp src/store_base.h src/store_provider.cpp src/store_provider.h @@ -272,18 +274,20 @@ qt_add_qml_module(chat QML_FILES main.qml qml/AddCollectionView.qml + qml/AddCustomProviderView.qml qml/AddModelView.qml qml/AddGPT4AllModelView.qml qml/AddHFModelView.qml qml/AddRemoteModelView.qml qml/ApplicationSettings.qml - qml/ChatDrawer.qml qml/ChatCollapsibleItem.qml + qml/ChatDrawer.qml qml/ChatItemView.qml qml/ChatMessageButton.qml qml/ChatTextItem.qml qml/ChatView.qml qml/CollectionsDrawer.qml + qml/CustomProviderCard.qml qml/HomeView.qml qml/LocalDocsSettings.qml qml/LocalDocsView.qml diff --git a/gpt4all-chat/deps/CMakeLists.txt b/gpt4all-chat/deps/CMakeLists.txt index 6cec1227..305e68a1 100644 --- a/gpt4all-chat/deps/CMakeLists.txt +++ b/gpt4all-chat/deps/CMakeLists.txt @@ -15,7 +15,6 @@ add_subdirectory(QXlsx/QXlsx) add_subdirectory(json) # required by minja # TartanLlama -set(FUNCTION_REF_ENABLE_TESTS OFF) add_subdirectory(generator) if (NOT GPT4ALL_USING_QTPDF) diff --git a/gpt4all-chat/qml/AddCustomProviderView.qml b/gpt4all-chat/qml/AddCustomProviderView.qml new file mode 100644 index 00000000..cb1f5819 --- /dev/null +++ b/gpt4all-chat/qml/AddCustomProviderView.qml @@ -0,0 +1,57 @@ +import QtQuick +import QtQuick.Controls +import QtQuick.Layouts + +ColumnLayout { + Layout.fillWidth: true + Layout.alignment: Qt.AlignTop + spacing: 5 + + Label { + Layout.topMargin: 0 + Layout.bottomMargin: 25 + Layout.rightMargin: 150 * theme.fontScale + Layout.alignment: Qt.AlignTop + Layout.fillWidth: true + verticalAlignment: Text.AlignTop + text: qsTr("Add custom model providers here.") + font.pixelSize: theme.fontSizeLarger + color: theme.textColor + wrapMode: Text.WordWrap + } + + ScrollView { + id: scrollView + ScrollBar.vertical.policy: ScrollBar.AsNeeded + Layout.fillWidth: true + Layout.fillHeight: true + contentWidth: availableWidth + clip: true + Flow { + anchors.left: parent.left + anchors.right: parent.right + spacing: 20 + bottomPadding: 20 + property int childWidth: 330 * theme.fontScale + property int childHeight: 400 + 166 * theme.fontScale + CustomProviderCard { + width: parent.childWidth + height: parent.childHeight + withApiKey: true + createProvider: QmlFunctions.newCustomOpenaiProvider + providerName: qsTr("OpenAI") + providerImage: "qrc:/gpt4all/icons/antenna_3.svg" + providerDesc: qsTr("Configure a custom OpenAI provider.") + } + CustomProviderCard { + width: parent.childWidth + height: parent.childHeight + withApiKey: false + createProvider: QmlFunctions.newCustomOllamaProvider + providerName: qsTr("Ollama") + providerImage: "qrc:/gpt4all/icons/antenna_3.svg" + providerDesc: qsTr("Configure a custom Ollama provider.") + } + } + } +} diff --git a/gpt4all-chat/qml/AddModelView.qml b/gpt4all-chat/qml/AddModelView.qml index 0bb06f24..07f45f0e 100644 --- a/gpt4all-chat/qml/AddModelView.qml +++ b/gpt4all-chat/qml/AddModelView.qml @@ -96,6 +96,11 @@ Rectangle { remoteModelView.show(); } } + MyTabButton { + text: qsTr("Custom Providers") + isSelected: customProviderModelView.isShown() + onPressed: customProviderModelView.show() + } MyTabButton { text: qsTr("HuggingFace") isSelected: huggingfaceModelView.isShown() @@ -136,6 +141,15 @@ Rectangle { } } + AddCustomProviderView { + id: customProviderModelView + Layout.fillWidth: true + Layout.fillHeight: true + + function show() { stackLayout.currentIndex = 2; } + function isShown() { return stackLayout.currentIndex === 2; } + } + AddHFModelView { id: huggingfaceModelView Layout.fillWidth: true @@ -146,10 +160,10 @@ Rectangle { anchors.fill: parent function show() { - stackLayout.currentIndex = 2; + stackLayout.currentIndex = 3; } function isShown() { - return stackLayout.currentIndex === 2; + return stackLayout.currentIndex === 3; } } } diff --git a/gpt4all-chat/qml/AddRemoteModelView.qml b/gpt4all-chat/qml/AddRemoteModelView.qml index 924ddc6e..57acff58 100644 --- a/gpt4all-chat/qml/AddRemoteModelView.qml +++ b/gpt4all-chat/qml/AddRemoteModelView.qml @@ -49,15 +49,13 @@ ColumnLayout { property int childWidth: 330 * theme.fontScale property int childHeight: 400 + 166 * theme.fontScale Repeater { - model: BuiltinProviderList - delegate: RemoteModelCard { - required property var data + model: ProviderListSort + RemoteModelCard { width: parent.childWidth height: parent.childHeight - provider: data - providerBaseUrl: data.baseUrl - providerName: data.name - providerImage: data.icon + provider: modelData + providerName: provider.name + providerImage: provider.icon providerDesc: ({ '{20f963dc-1f99-441e-ad80-f30a0a06bcac}': qsTr( 'Groq offers a high-performance AI inference engine designed for low-latency and ' + @@ -78,10 +76,18 @@ ColumnLayout { 'performance, making them a solid option for applications requiring scalable AI ' + 'solutions.

Get your API key: https://mistral.ai/' ), - })[data.id.toString()] - modelWhitelist: data.modelWhitelist + })[provider.id.toString()] } } + RemoteModelCard { + width: parent.childWidth + height: parent.childHeight + providerUsesApiKey: false + providerName: qsTr("Ollama (Custom)") + providerImage: "qrc:/gpt4all/icons/antenna_3.svg" + providerDesc: qsTr("Configure a custom Ollama provider.") + } + // TODO(jared): add custom openai back to the list /* RemoteModelCard { width: parent.childWidth diff --git a/gpt4all-chat/qml/CustomProviderCard.qml b/gpt4all-chat/qml/CustomProviderCard.qml new file mode 100644 index 00000000..cdd93c7c --- /dev/null +++ b/gpt4all-chat/qml/CustomProviderCard.qml @@ -0,0 +1,193 @@ +import QtQuick +import QtQuick.Controls +import QtQuick.Layouts + +import gpt4all.ProviderRegistry + +Rectangle { + id: root + required property bool withApiKey + required property var createProvider + property alias providerName: providerNameLabel.text + property alias providerImage: myimage.source + property alias providerDesc: providerDescLabel.text + + color: theme.conversationBackground + radius: 10 + border.width: 1 + border.color: theme.controlBorder + implicitHeight: topColumn.height + bottomColumn.height + 33 * theme.fontScale + + ColumnLayout { + id: topColumn + anchors.left: parent.left + anchors.right: parent.right + anchors.top: parent.top + anchors.margins: 20 + spacing: 15 * theme.fontScale + RowLayout { + Layout.alignment: Qt.AlignTop + spacing: 10 + Item { + Layout.preferredWidth: 27 * theme.fontScale + Layout.preferredHeight: 27 * theme.fontScale + Layout.alignment: Qt.AlignLeft + + Image { + id: myimage + anchors.centerIn: parent + sourceSize.width: parent.width + sourceSize.height: parent.height + mipmap: true + fillMode: Image.PreserveAspectFit + } + } + + Label { + id: providerNameLabel + color: theme.textColor + font.pixelSize: theme.fontSizeBanner + } + } + + Label { + id: providerDescLabel + Layout.fillWidth: true + wrapMode: Text.Wrap + color: theme.settingsTitleTextColor + font.pixelSize: theme.fontSizeLarge + onLinkActivated: function(link) { Qt.openUrlExternally(link); } + + MouseArea { + anchors.fill: parent + acceptedButtons: Qt.NoButton // pass clicks to parent + cursorShape: parent.hoveredLink ? Qt.PointingHandCursor : Qt.ArrowCursor + } + } + } + + ColumnLayout { + id: bottomColumn + anchors.left: parent.left + anchors.right: parent.right + anchors.bottom: parent.bottom + anchors.margins: 20 + spacing: 30 + + ColumnLayout { + MySettingsLabel { + text: qsTr("Name") + font.bold: true + font.pixelSize: theme.fontSizeLarge + color: theme.settingsTitleTextColor + } + MyTextField { + id: nameField + Layout.fillWidth: true + font.pixelSize: theme.fontSizeLarge + wrapMode: Text.WrapAnywhere + placeholderText: qsTr("Provider Name") + Accessible.role: Accessible.EditableText + Accessible.name: placeholderText + } + } + + ColumnLayout { + MySettingsLabel { + text: qsTr("Base URL") + font.bold: true + font.pixelSize: theme.fontSizeLarge + color: theme.settingsTitleTextColor + } + MyTextField { + id: baseUrlField + property bool ok: text.trim() !== "" + Layout.fillWidth: true + font.pixelSize: theme.fontSizeLarge + wrapMode: Text.WrapAnywhere + placeholderText: qsTr("Provider Base URL") + Accessible.role: Accessible.EditableText + Accessible.name: placeholderText + } + } + + ColumnLayout { + visible: withApiKey + + MySettingsLabel { + text: qsTr("API Key") + font.bold: true + font.pixelSize: theme.fontSizeLarge + color: theme.settingsTitleTextColor + } + + MyTextField { + id: apiKeyField + Layout.fillWidth: true + font.pixelSize: theme.fontSizeLarge + wrapMode: Text.WrapAnywhere + echoMode: TextField.Password + placeholderText: qsTr("Provider API Key") + Accessible.role: Accessible.EditableText + Accessible.name: placeholderText + } + } + + ColumnLayout { + MySettingsLabel { + text: qsTr("Status") + font.bold: true + font.pixelSize: theme.fontSizeLarge + color: theme.settingsTitleTextColor + } + + RowLayout { + spacing: 10 + + MyTextField { + id: statusText + property var provider: null // owns the new provider + enabled: false + Layout.fillWidth: true + font.pixelSize: theme.fontSizeLarge + property var inputs: ({ + name : nameField .text.trim(), + baseUrl : baseUrlField.text.trim(), + apiKey : apiKeyField .text.trim(), + }) + function update() { + provider = null; + text = qsTr("..."); + if (inputs.name === "" || inputs.baseUrl === "") + return; + const args = [inputs.name, inputs.baseUrl]; + if (withApiKey) + args.push(inputs.apiKey); + let p = createProvider(...args); + if (p !== null) + p.get().statusQml().then(status => { + if (status !== null) { + if (status.ok) { provider = p; } + text = status.detail; + } + }); + } + Component.onCompleted: update() + onInputsChanged: update() + } + } + } + + MySettingsButton { + id: installButton + Layout.alignment: Qt.AlignRight + text: qsTr("Install") + font.pixelSize: theme.fontSizeLarge + enabled: statusText.provider !== null + onClicked: ProviderRegistry.addQml(statusText.provider) + Accessible.role: Accessible.Button + Accessible.name: qsTr("Install") + Accessible.description: qsTr("Install custom provider") + } + } +} diff --git a/gpt4all-chat/qml/RemoteModelCard.qml b/gpt4all-chat/qml/RemoteModelCard.qml index f8585264..f2cbe0e6 100644 --- a/gpt4all-chat/qml/RemoteModelCard.qml +++ b/gpt4all-chat/qml/RemoteModelCard.qml @@ -1,30 +1,19 @@ -import QtCore import QtQuick import QtQuick.Controls -import QtQuick.Controls.Basic import QtQuick.Layouts -import QtQuick.Dialogs -import Qt.labs.folderlistmodel -import Qt5Compat.GraphicalEffects - -import llm -import chatlistmodel -import download -import modellist -import network -import gpt4all -import mysettings -import localdocs - Rectangle { - required property var provider + id: remoteModelCard + property var provider: null property alias providerName: providerNameLabel.text property alias providerImage: myimage.source property alias providerDesc: providerDescLabel.text - property string providerBaseUrl: "" - property bool providerIsCustom: false - property var modelWhitelist: null + property bool providerUsesApiKey: true + + // for internal use + property bool apiKeyRequired: provider === null ? providerUsesApiKey : "apiKey" in provider + property bool apiKeyGood: !apiKeyRequired // (overwritten later if required) + property bool baseUrlGood: provider !== null // (overwritten later if custom) color: theme.conversationBackground radius: 10 @@ -89,6 +78,8 @@ Rectangle { spacing: 30 ColumnLayout { + visible: apiKeyRequired + MySettingsLabel { text: qsTr("API Key") font.bold: true @@ -106,29 +97,19 @@ Rectangle { messageToast.show(qsTr("ERROR: $API_KEY is empty.")); apiKeyField.placeholderTextColor = theme.textErrorColor; } - Component.onCompleted: { text = provider.apiKey; } + Component.onCompleted: { if (parent.visible && provider !== null) { text = provider.apiKey; } } onTextChanged: { apiKeyField.placeholderTextColor = theme.mutedTextColor; - if (!providerIsCustom && provider.setApiKeyQml(text)) { - provider.listModelsQml().then(modelList => { - if (modelList !== null) { - if (modelWhitelist !== null) - models = models.filter(m => modelWhitelist.includes(m)); - myModelList.model = models; - myModelList.currentIndex = -1; - } - }); - } + if (provider !== null) { apiKeyGood = provider.setApiKeyQml(text) && text !== ""; } } placeholderText: qsTr("enter $API_KEY") Accessible.role: Accessible.EditableText Accessible.name: placeholderText - Accessible.description: qsTr("Whether the file hash is being calculated") } } ColumnLayout { - visible: providerIsCustom + visible: provider === null MySettingsLabel { text: qsTr("Base Url") font.bold: true @@ -146,40 +127,16 @@ Rectangle { } onTextChanged: { baseUrlField.placeholderTextColor = theme.mutedTextColor; + baseUrlGood = text.trim() !== ""; } placeholderText: qsTr("enter $BASE_URL") Accessible.role: Accessible.EditableText Accessible.name: placeholderText } } - ColumnLayout { - visible: providerIsCustom - MySettingsLabel { - text: qsTr("Model Name") - font.bold: true - font.pixelSize: theme.fontSizeLarge - color: theme.settingsTitleTextColor - } - MyTextField { - id: modelNameField - Layout.fillWidth: true - font.pixelSize: theme.fontSizeLarge - wrapMode: Text.WrapAnywhere - function showError() { - messageToast.show(qsTr("ERROR: $MODEL_NAME is empty.")) - modelNameField.placeholderTextColor = theme.textErrorColor; - } - onTextChanged: { - modelNameField.placeholderTextColor = theme.mutedTextColor; - } - placeholderText: qsTr("enter $MODEL_NAME") - Accessible.role: Accessible.EditableText - Accessible.name: placeholderText - } - } ColumnLayout { - visible: myModelList.count > 0 && !providerIsCustom + visible: myModelList.count > 0 MySettingsLabel { text: qsTr("Models") @@ -194,7 +151,27 @@ Rectangle { MyComboBox { Layout.fillWidth: true id: myModelList - currentIndex: -1; + currentIndex: -1 + property bool ready: baseUrlGood && apiKeyGood + onReadyChanged: { + if (!ready) { return; } + let providerRef = null; // owns the new provider + let provider = remoteModelCard.provider; + if (provider === null) { + // TODO: custom OpenAI + providerRef = QmlFunctions.newCustomOllamaProvider("foo", baseUrlField.text.trim()); + if (providerRef !== null) + provider = providerRef.get(); + } + if (provider !== null) { + provider.listModelsQml().then(modelList => { + if (modelList !== null) { + model = modelList; + currentIndex = -1; + } + }); + } + } } } } @@ -206,10 +183,10 @@ Rectangle { font.pixelSize: theme.fontSizeLarge property string apiKeyText: apiKeyField.text.trim() - property string baseUrlText: providerIsCustom ? baseUrlField.text.trim() : providerBaseUrl.trim() - property string modelNameText: providerIsCustom ? modelNameField.text.trim() : myModelList.currentText.trim() + property string baseUrlText: provider === null ? baseUrlField.text.trim() : provider.baseUrl + property string modelNameText: myModelList.currentText.trim() - enabled: apiKeyText !== "" && baseUrlText !== "" && modelNameText !== "" + enabled: baseUrlGood && apiKeyGood && modelNameText !== "" onClicked: { Download.installCompatibleModel( diff --git a/gpt4all-chat/src/creatable.h b/gpt4all-chat/src/creatable.h new file mode 100644 index 00000000..5933bcc0 --- /dev/null +++ b/gpt4all-chat/src/creatable.h @@ -0,0 +1,18 @@ +#pragma once + +#include + + +namespace gpt4all::ui { + + +/// Helper mixin for classes derived from std::enable_shared_from_this. +template +struct Creatable { + template + static auto create(Ts &&...args) -> std::shared_ptr + { return std::make_shared(typename T::protected_t(), std::forward(args)...); } +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/llmodel_description.h b/gpt4all-chat/src/llmodel_description.h index fc2811fb..ef9ddefd 100644 --- a/gpt4all-chat/src/llmodel_description.h +++ b/gpt4all-chat/src/llmodel_description.h @@ -37,6 +37,8 @@ public: protected: [[nodiscard]] virtual auto newInstanceImpl(QNetworkAccessManager *nam) const -> ChatLLMInstance * = 0; + + template friend struct Creatable; }; diff --git a/gpt4all-chat/src/llmodel_ollama.cpp b/gpt4all-chat/src/llmodel_ollama.cpp index 39c9358d..42e89c6a 100644 --- a/gpt4all-chat/src/llmodel_ollama.cpp +++ b/gpt4all-chat/src/llmodel_ollama.cpp @@ -1,7 +1,14 @@ #include "llmodel_ollama.h" +#include "main.h" +#include "mysettings.h" + #include #include +#include +#include // IWYU pragma: keep + +#include using namespace Qt::Literals::StringLiterals; @@ -21,6 +28,9 @@ auto OllamaGenerationParams::toMap() const -> QMap }; } +OllamaProvider::OllamaProvider() +{ QJSEngine::setObjectOwnership(this, QJSEngine::CppOwnership); } + OllamaProvider::~OllamaProvider() noexcept = default; auto OllamaProvider::supportedGenerationParams() const -> QSet @@ -33,9 +43,54 @@ auto OllamaProvider::makeGenerationParams(const QMap -> OllamaGenerationParams * { return new OllamaGenerationParams(values); } +auto OllamaProvider::status() -> QCoro::Task +{ + auto client = makeClient(); + auto resp = co_await client.version(); + if (resp) + co_return ProviderStatus(tr("Version: %1").arg(resp->version)); + co_return ProviderStatus(resp.error()); +} + +auto OllamaProvider::listModels() -> QCoro::Task> +{ + auto client = makeClient(); + auto resp = co_await client.list(); + if (!resp) + co_return std::unexpected(resp.error()); + QStringList res; + for (auto &model : resp->models) + res << model.name; + co_return res; +} + +QCoro::QmlTask OllamaProvider::statusQml() +{ return wrapQmlTask(this, &OllamaProvider::status, u"OllamaProvider::status"_s); } + +QCoro::QmlTask OllamaProvider::listModelsQml() +{ return wrapQmlTask(this, &OllamaProvider::listModels, u"OllamaProvider::listModels"_s); } + +auto OllamaProvider::newModel(const QByteArray &modelHash) const -> std::shared_ptr +{ return std::static_pointer_cast(newModelImpl(modelHash)); } + +auto OllamaProvider::newModelImpl(const QVariant &key) const -> std::shared_ptr +{ + if (!key.canConvert()) + throw std::invalid_argument(fmt::format("expected modelHash type QByteArray, got {}", key.typeName())); + return OllamaModelDescription::create( + std::shared_ptr(shared_from_this(), this), key.toByteArray() + ); +} + +auto OllamaProvider::makeClient() -> backend::OllamaClient +{ + auto *mySettings = MySettings::globalInstance(); + return backend::OllamaClient(m_baseUrl, mySettings->userAgent(), networkAccessManager()); +} + /// load -OllamaProviderCustom::OllamaProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl) - : ModelProvider (std::move(id), std::move(name), std::move(baseUrl)) +OllamaProviderCustom::OllamaProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl) + : ModelProvider (p, std::move(id), std::move(name), std::move(baseUrl)) , ModelProviderCustom(store) { if (auto res = m_store->acquire(m_id); !res) @@ -43,14 +98,12 @@ OllamaProviderCustom::OllamaProviderCustom(ProviderStore *store, QUuid id, QStri } /// create -OllamaProviderCustom::OllamaProviderCustom(ProviderStore *store, QString name, QUrl baseUrl) - : ModelProvider (std::move(name), std::move(baseUrl)) +OllamaProviderCustom::OllamaProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl) + : ModelProvider (p, QUuid::createUuid(), std::move(name), std::move(baseUrl)) , ModelProviderCustom(store) { - auto data = m_store->create(m_name, m_baseUrl); - if (!data) - data.error().raise(); - m_id = (*data)->id; + if (auto res = m_store->acquire(m_id); !res) + res.error().raise(); } auto OllamaProviderCustom::asData() -> ModelProviderData @@ -58,7 +111,7 @@ auto OllamaProviderCustom::asData() -> ModelProviderData return { .id = m_id, .custom_details = CustomProviderDetails { m_name, m_baseUrl }, - .provider_details = {}, + .provider_details = std::monostate(), }; } diff --git a/gpt4all-chat/src/llmodel_ollama.h b/gpt4all-chat/src/llmodel_ollama.h index 6768badb..0cd3270e 100644 --- a/gpt4all-chat/src/llmodel_ollama.h +++ b/gpt4all-chat/src/llmodel_ollama.h @@ -1,9 +1,13 @@ #pragma once +#include "creatable.h" #include "llmodel_chat.h" #include "llmodel_description.h" #include "llmodel_provider.h" +#include // IWYU pragma: keep +#include + #include #include // IWYU pragma: keep #include @@ -12,6 +16,8 @@ #include #include // IWYU pragma: keep +#include + class QNetworkAccessManager; template class QMap; template class QSet; @@ -21,6 +27,7 @@ namespace gpt4all::ui { class OllamaChatModel; +class OllamaModelDescription; struct OllamaGenerationParamsData { uint n_predict; @@ -38,9 +45,12 @@ protected: }; class OllamaProvider : public QObject, public virtual ModelProvider { - Q_GADGET + Q_OBJECT Q_PROPERTY(QUuid id READ id CONSTANT) +protected: + explicit OllamaProvider(); + public: ~OllamaProvider() noexcept override = 0; @@ -49,30 +59,48 @@ public: auto supportedGenerationParams() const -> QSet override; auto makeGenerationParams(const QMap &values) const -> OllamaGenerationParams * override; + + // endpoints + auto status () -> QCoro::Task override; + auto listModels() -> QCoro::Task> override; + + // QML wrapped endpoints + Q_INVOKABLE QCoro::QmlTask statusQml (); + Q_INVOKABLE QCoro::QmlTask listModelsQml(); + + [[nodiscard]] auto newModel(const QByteArray &modelHash) const -> std::shared_ptr; + +protected: + [[nodiscard]] auto newModelImpl(const QVariant &key) const -> std::shared_ptr final; + +private: + backend::OllamaClient makeClient(); }; -class OllamaProviderBuiltin : public OllamaProvider { - Q_GADGET +class OllamaProviderBuiltin : public OllamaProvider, public Creatable { + Q_OBJECT Q_PROPERTY(QString name READ name CONSTANT) Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT) public: /// Create a new built-in Ollama provider (transient). - explicit OllamaProviderBuiltin(QUuid id, QString name, QUrl baseUrl) - : ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) {} + explicit OllamaProviderBuiltin(protected_t p, QUuid id, QString name, QUrl baseUrl) + : ModelProvider(p, std::move(id), std::move(name), std::move(baseUrl)) {} }; -class OllamaProviderCustom final : public OllamaProvider, public ModelProviderCustom { +class OllamaProviderCustom final + : public OllamaProvider, public ModelProviderCustom, public Creatable +{ Q_OBJECT Q_PROPERTY(QString name READ name NOTIFY nameChanged ) Q_PROPERTY(QUrl baseUrl READ baseUrl NOTIFY baseUrlChanged) public: /// Load an existing OllamaProvider from disk. - explicit OllamaProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl); + explicit OllamaProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl); /// Create a new OllamaProvider on disk. - explicit OllamaProviderCustom(ProviderStore *store, QString name, QUrl baseUrl); + explicit OllamaProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl); Q_SIGNALS: void nameChanged (const QString &value); @@ -82,17 +110,13 @@ protected: auto asData() -> ModelProviderData override; }; -class OllamaModelDescription : public ModelDescription { +class OllamaModelDescription : public ModelDescription, public Creatable { Q_GADGET Q_PROPERTY(QByteArray modelHash READ modelHash CONSTANT) public: explicit OllamaModelDescription(protected_t, std::shared_ptr provider, QByteArray modelHash); - static auto create(std::shared_ptr provider, QByteArray modelHash) - -> std::shared_ptr - { return std::make_shared(protected_t(), std::move(provider), std::move(modelHash)); } - // getters [[nodiscard]] auto provider () const -> const OllamaProvider * override { return m_provider.get(); } [[nodiscard]] QVariant key () const override { return m_modelHash; } diff --git a/gpt4all-chat/src/llmodel_openai.cpp b/gpt4all-chat/src/llmodel_openai.cpp index 7252654f..bfd2642b 100644 --- a/gpt4all-chat/src/llmodel_openai.cpp +++ b/gpt4all-chat/src/llmodel_openai.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -88,6 +89,13 @@ auto OpenaiGenerationParams::toMap() const -> QMap }; } +OpenaiProvider::OpenaiProvider() +{ QJSEngine::setObjectOwnership(this, QJSEngine::CppOwnership); } + +OpenaiProvider::OpenaiProvider(QString apiKey) + : m_apiKey(std::move(apiKey)) +{ QJSEngine::setObjectOwnership(this, QJSEngine::CppOwnership); } + OpenaiProvider::~OpenaiProvider() noexcept = default; Q_INVOKABLE bool OpenaiProvider::setApiKeyQml(QString value) @@ -108,13 +116,22 @@ auto OpenaiProvider::makeGenerationParams(const QMap -> OpenaiGenerationParams * { return new OpenaiGenerationParams(values); } +auto OpenaiProvider::status() -> QCoro::Task +{ + auto resp = co_await listModels(); + if (resp) + co_return ProviderStatus(tr("OK")); + co_return ProviderStatus(resp.error()); +} + auto OpenaiProvider::listModels() -> QCoro::Task> { + auto *mySettings = MySettings::globalInstance(); auto *nam = networkAccessManager(); QNetworkRequest request(m_baseUrl.resolved(u"models"_s)); - request.setHeader (QNetworkRequest::ContentTypeHeader, "application/json"_ba); - request.setRawHeader("Authorization"_ba, fmt::format("Bearer {}", m_apiKey).c_str()); + request.setHeader (QNetworkRequest::UserAgentHeader, mySettings->userAgent()); + request.setRawHeader("authorization"_ba, fmt::format("Bearer {}", m_apiKey).c_str()); std::unique_ptr reply(nam->get(request)); QRestReply restReply(reply.get()); @@ -146,20 +163,27 @@ auto OpenaiProvider::listModels() -> QCoro::Task std::shared_ptr +{ return std::static_pointer_cast(newModelImpl(modelName)); } + +auto OpenaiProvider::newModelImpl(const QVariant &key) const -> std::shared_ptr { - return [this]() -> QCoro::Task { - auto result = co_await listModels(); - if (result) - co_return *result; - qWarning().noquote() << "OpenaiProvider::listModels failed:" << result.error().errorString(); - co_return QVariant::fromValue(nullptr); - }(); + if (!key.canConvert()) + throw std::invalid_argument(fmt::format("expected modelName type QString, got {}", key.typeName())); + return OpenaiModelDescription::create( + std::shared_ptr(shared_from_this(), this), key.toString() + ); } -OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl, - QStringList modelWhitelist) - : ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) +OpenaiProviderBuiltin::OpenaiProviderBuiltin(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl icon, + QUrl baseUrl, std::unordered_set modelWhitelist) + : ModelProvider(p, std::move(id), std::move(name), std::move(baseUrl)) , ModelProviderBuiltin(std::move(icon)) , ModelProviderMutable(store) , m_modelWhitelist(std::move(modelWhitelist)) @@ -173,6 +197,15 @@ OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QSt } } +auto OpenaiProviderBuiltin::listModels() -> QCoro::Task> +{ + auto models = co_await OpenaiProvider::listModels(); + if (!models) + co_return std::unexpected(models.error()); + models->removeIf([&](auto &m) { return !m_modelWhitelist.contains(m); }); + co_return *models; +} + auto OpenaiProviderBuiltin::asData() -> ModelProviderData { return { @@ -183,22 +216,25 @@ auto OpenaiProviderBuiltin::asData() -> ModelProviderData } /// load -OpenaiProviderCustom::OpenaiProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey) - : ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) +OpenaiProviderCustom::OpenaiProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl, + QString apiKey) + : ModelProvider(p, std::move(id), std::move(name), std::move(baseUrl)) , OpenaiProvider(std::move(apiKey)) , ModelProviderCustom(store) - {} +{ + if (auto res = m_store->acquire(m_id); !res) + res.error().raise(); +} /// create -OpenaiProviderCustom::OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey) - : ModelProvider(std::move(name), std::move(baseUrl)) +OpenaiProviderCustom::OpenaiProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl, + QString apiKey) + : ModelProvider(p, QUuid::createUuid(), std::move(name), std::move(baseUrl)) , ModelProviderCustom(std::move(store)) , OpenaiProvider(std::move(apiKey)) { - auto data = m_store->create(m_name, m_baseUrl, m_apiKey); - if (!data) - data.error().raise(); - m_id = (*data)->id; + if (auto res = m_store->acquire(m_id); !res) + res.error().raise(); } auto OpenaiProviderCustom::asData() -> ModelProviderData @@ -330,8 +366,9 @@ auto OpenaiChatModel::generate(QStringView prompt, const GenerationParams *param auto &provider = *m_description->provider(); QNetworkRequest request(provider.baseUrl().resolved(QUrl("/v1/chat/completions"))); - request.setHeader(QNetworkRequest::UserAgentHeader, mySettings->userAgent()); - request.setRawHeader("authorization", u"Bearer %1"_s.arg(provider.apiKey()).toUtf8()); + request.setHeader (QNetworkRequest::UserAgentHeader, mySettings->userAgent()); + request.setHeader (QNetworkRequest::ContentTypeHeader, "application/json"_ba ); + request.setRawHeader("authorization"_ba, fmt::format("Bearer {}", provider.apiKey()).c_str()); QRestAccessManager restNam(m_nam); std::unique_ptr reply(restNam.post(request, QJsonDocument(reqBody))); diff --git a/gpt4all-chat/src/llmodel_openai.h b/gpt4all-chat/src/llmodel_openai.h index 5d5e790e..4d92c627 100644 --- a/gpt4all-chat/src/llmodel_openai.h +++ b/gpt4all-chat/src/llmodel_openai.h @@ -1,5 +1,6 @@ #pragma once +#include "creatable.h" #include "llmodel_chat.h" #include "llmodel_description.h" #include "llmodel_provider.h" @@ -16,6 +17,7 @@ #include // IWYU pragma: keep #include +#include #include class QNetworkAccessManager; @@ -28,6 +30,7 @@ namespace gpt4all::ui { class OpenaiChatModel; +class OpenaiModelDescription; struct OpenaiGenerationParamsData { uint n_predict; @@ -51,9 +54,8 @@ class OpenaiProvider : public QObject, public virtual ModelProvider { Q_PROPERTY(QString apiKey READ apiKey NOTIFY apiKeyChanged) protected: - explicit OpenaiProvider() = default; - explicit OpenaiProvider(QString apiKey) - : m_apiKey(std::move(apiKey)) {} + explicit OpenaiProvider(); + explicit OpenaiProvider(QString apiKey); public: ~OpenaiProvider() noexcept override = 0; @@ -69,50 +71,69 @@ public: auto supportedGenerationParams() const -> QSet override; auto makeGenerationParams(const QMap &values) const -> OpenaiGenerationParams * override; - auto listModels() -> QCoro::Task>; + // endpoints + auto status () -> QCoro::Task override; + auto listModels() -> QCoro::Task> override; + + // QML wrapped endpoints + Q_INVOKABLE QCoro::QmlTask statusQml (); Q_INVOKABLE QCoro::QmlTask listModelsQml(); + [[nodiscard]] auto newModel(const QString &modelName) const -> std::shared_ptr; + Q_SIGNALS: void apiKeyChanged(const QString &value); protected: + [[nodiscard]] auto newModelImpl(const QVariant &key) const -> std::shared_ptr final; + QString m_apiKey; }; -class OpenaiProviderBuiltin : public OpenaiProvider, public ModelProviderBuiltin, public ModelProviderMutable { +class OpenaiProviderBuiltin + : public OpenaiProvider + , public ModelProviderBuiltin + , public ModelProviderMutable + , public Creatable +{ Q_OBJECT Q_PROPERTY(QString name READ name CONSTANT) Q_PROPERTY(QUrl icon READ icon CONSTANT) Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT) - Q_PROPERTY(QStringList modelWhitelist READ modelWhitelist CONSTANT) public: /// Create a new built-in OpenAI provider, loading its API key from disk if known. - explicit OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl, - QStringList modelWhitelist); - - [[nodiscard]] const QStringList &modelWhitelist() { return m_modelWhitelist; } + explicit OpenaiProviderBuiltin(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl, + std::unordered_set modelWhitelist); [[nodiscard]] DataStoreResult<> setApiKey(QString value) override { return setMemberProp(&OpenaiProviderBuiltin::m_apiKey, "apiKey", std::move(value), /*createName*/ m_name); } + // override for model whitelist + auto listModels() -> QCoro::Task> override; + +Q_SIGNALS: + void apiKeyChanged(const QString &value); + protected: auto asData() -> ModelProviderData override; - QStringList m_modelWhitelist; + std::unordered_set m_modelWhitelist; }; -class OpenaiProviderCustom final : public OpenaiProvider, public ModelProviderCustom { +class OpenaiProviderCustom final + : public OpenaiProvider, public ModelProviderCustom, public Creatable +{ Q_OBJECT Q_PROPERTY(QString name READ name NOTIFY nameChanged ) Q_PROPERTY(QUrl baseUrl READ baseUrl NOTIFY baseUrlChanged) public: /// Load an existing OpenaiProvider from disk. - explicit OpenaiProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey); + explicit OpenaiProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey); /// Create a new OpenaiProvider on disk. - explicit OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey); + explicit OpenaiProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl, QString apiKey); [[nodiscard]] DataStoreResult<> setApiKey(QString value) override { return setMemberProp(&OpenaiProviderCustom::m_apiKey, "apiKey", std::move(value)); } @@ -126,17 +147,13 @@ protected: auto asData() -> ModelProviderData override; }; -class OpenaiModelDescription : public ModelDescription { +class OpenaiModelDescription : public ModelDescription, public Creatable { Q_GADGET Q_PROPERTY(QString modelName READ modelName CONSTANT) public: explicit OpenaiModelDescription(protected_t, std::shared_ptr provider, QString modelName); - static auto create(std::shared_ptr provider, QByteArray modelHash) - -> std::shared_ptr - { return std::make_shared(protected_t(), std::move(provider), std::move(modelHash)); } - // getters [[nodiscard]] auto provider () const -> const OpenaiProvider * override { return m_provider.get(); } [[nodiscard]] QVariant key () const override { return m_modelName; } diff --git a/gpt4all-chat/src/llmodel_provider.cpp b/gpt4all-chat/src/llmodel_provider.cpp index 95ab16e2..45cac7c0 100644 --- a/gpt4all-chat/src/llmodel_provider.cpp +++ b/gpt4all-chat/src/llmodel_provider.cpp @@ -8,10 +8,13 @@ #include #include // IWYU pragma: keep +#include #include // IWYU pragma: keep #include -namespace fs = std::filesystem; +#include + +namespace ranges = std::ranges; namespace gpt4all::ui { @@ -43,12 +46,48 @@ QVariant GenerationParams::tryParseValue(QMap &values return value; } +ProviderStatus::ProviderStatus(const backend::ResponseError &error) + : m_ok(false) +{ + auto &code = error.error(); + if (auto *badStatus = std::get_if(&code)) { + m_detail = QObject::tr("HTTP %1%2%3").arg( + QString::number(badStatus->code), + badStatus->reason ? u" "_s : QString(), + badStatus->reason.value_or(QString()) + ); + return; + } + if (auto *netErr = std::get_if(&code)) { + auto meta = QMetaEnum::fromType(); + m_detail = QString::fromUtf8(meta.valueToKey(*netErr)); + return; + } + m_detail = QObject::tr("(unknown error)"); +} + +ModelProvider::ModelProvider(protected_t, QUuid id, QString name, QUrl baseUrl) // create built-in or load + : m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) +{ Q_ASSERT(!m_id.isNull()); } + ModelProvider::~ModelProvider() noexcept = default; +auto ModelProvider::newModel(const QVariant &key) const -> std::shared_ptr +{ return newModelImpl(key); } + ModelProviderMutable::~ModelProviderMutable() noexcept { - if (auto res = m_store->release(m_id); !res) - res.error().raise(); // should not happen - will terminate program + if (!m_id.isNull()) // (will be null if constructor throws) + if (auto res = m_store->release(m_id); !res) + res.error().raise(); // should not happen - will terminate program +} + +auto ModelProviderCustom::persist() -> DataStoreResult<> +{ + if (auto res = m_store->create(asData()); !res) + return res; + m_persisted = true; + return {}; } ProviderRegistry::ProviderRegistry(PathSet paths) @@ -70,18 +109,27 @@ ProviderRegistry *ProviderRegistry::globalInstance() void ProviderRegistry::load() { - size_t i = 0; + auto registerListener = [this](ModelProvider *provider) { + // listen for any change in the provider so we can tell the model about it + if (auto *mut = dynamic_cast(provider)) + connect(mut->asQObject(), "apiKeyChanged", this, "onProviderChanged"); + if (auto *cust = dynamic_cast(provider)) { + connect(cust->asQObject(), "nameChanged", this, "onProviderChanged"); + connect(cust->asQObject(), "baseUrlChanged", this, "onProviderChanged"); + } + }; for (auto &p : s_builtinProviders) { // (not all builtin providers are stored) - auto provider = std::make_shared( + auto provider = OpenaiProviderBuiltin::create( &m_builtinStore, p.id, p.name, p.icon, p.base_url, - QStringList(p.model_whitelist.begin(), p.model_whitelist.end()) + std::unordered_set(p.model_whitelist.begin(), p.model_whitelist.end()) ); - auto [_, unique] = m_providers.emplace(p.id, std::move(provider)); + auto [it, unique] = m_providers.emplace(p.id, std::move(provider)); if (!unique) throw std::logic_error(fmt::format("duplicate builtin provider id: {}", p.id.toString())); - m_builtinProviders[i++] = p.id; + m_providerList.push_back(&p.id); + registerListener(it->second.get()); } - for (auto &p : m_customStore.list()) { // disk is source of truth for custom providers + for (auto p : m_customStore.list()) { // disk is source of truth for custom providers if (!p.custom_details) { qWarning() << "ignoring builtin provider in custom store:" << p.id; continue; @@ -91,54 +139,62 @@ void ProviderRegistry::load() switch (p.type()) { using enum ProviderType; case ollama: - provider = std::make_shared( + provider = OllamaProviderCustom::create( &m_customStore, p.id, cust.name, cust.base_url ); break; case openai: - provider = std::make_shared( + provider = OpenaiProviderCustom::create( &m_customStore, p.id, cust.name, cust.base_url, std::get(p.provider_details).api_key ); } - auto [_, unique] = m_providers.emplace(p.id, std::move(provider)); + auto [it, unique] = m_providers.emplace(p.id, std::move(provider)); if (!unique) qWarning() << "ignoring duplicate custom provider with id:" << p.id; - m_customProviders.push_back(std::make_unique(p.id)); + m_providerList.push_back(&it->second->id()); + registerListener(it->second.get()); } } -[[nodiscard]] -bool ProviderRegistry::add(std::shared_ptr provider) +auto ProviderRegistry::add(std::shared_ptr provider) -> DataStoreResult<> { + if (auto res = provider->persist(); !res) + return res; auto [it, unique] = m_providers.emplace(provider->id(), std::move(provider)); - if (unique) { - m_customProviders.push_back(std::make_unique(it->first)); - emit customProviderAdded(m_customProviders.size() - 1); + if (!unique) + return std::unexpected(u"custom provider already registered: %1"_s.arg(provider->id().toString())); + m_providerList.push_back(&it->second->id()); + emit customProviderAdded(m_providerList.size() - 1); + return {}; +} + +bool ProviderRegistry::addQml(QmlSharedPtr *provider) +{ + auto obj = std::dynamic_pointer_cast(provider->ptr()); + if (!obj) { + qWarning() << "ProviderRegistry::add failed: Expected ModelProviderCustom, got" + << provider->metaObject()->className(); + return false; } - return unique; + auto res = add(obj); + if (!res) + qWarning() << "ProviderRegistry::add failed:" << res.error().errorString(); + return bool(res); } -auto ProviderRegistry::customProviderAt(size_t i) const -> ModelProviderCustom * +auto ProviderRegistry::providerAt(size_t i) const -> const ModelProvider * { - auto it = m_providers.find(*m_customProviders.at(i)); + auto it = m_providers.find(*m_providerList.at(i)); Q_ASSERT(it != m_providers.end()); - return &dynamic_cast(*it->second); -} - -auto ProviderRegistry::builtinProviderAt(size_t i) const -> ModelProviderBuiltin * -{ - auto it = m_providers.find(m_builtinProviders.at(i)); - Q_ASSERT(it != m_providers.end()); - return &dynamic_cast(*it->second); - + return it->second.get(); } auto ProviderRegistry::getSubdirs() -> PathSet { auto *mysettings = MySettings::globalInstance(); auto parent = toFSPath(mysettings->modelPath()) / "providers"; - return { .builtin = parent, .custom = parent / "custom" }; + return { .builtin = parent, .custom = parent / "custom" }; } void ProviderRegistry::onModelPathChanged() @@ -147,8 +203,8 @@ void ProviderRegistry::onModelPathChanged() if (paths.builtin != m_builtinStore.path()) { emit aboutToBeCleared(); // delete providers to release store locks - m_customProviders.clear(); m_providers.clear(); + m_providerList.clear(); if (auto res = m_builtinStore.setPath(paths.builtin); !res) res.error().raise(); // should not happen if (auto res = m_customStore.setPath(paths.custom); !res) @@ -157,42 +213,58 @@ void ProviderRegistry::onModelPathChanged() } } -auto BuiltinProviderList::roleNames() const -> QHash -{ return { { Qt::DisplayRole, "data"_ba } }; } - -QVariant BuiltinProviderList::data(const QModelIndex &index, int role) const +void ProviderRegistry::onProviderChanged() { - auto *registry = ProviderRegistry::globalInstance(); - if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole) - return QVariant::fromValue(registry->builtinProviderAt(index.row())->asQObject()); - return {}; + // notify that this provider has changed + auto *obj = &dynamic_cast(*QObject::sender()); + auto it = ranges::find_if(m_providerList, [&](auto *id) { return *id == obj->id(); }); + if (it < m_providerList.end()) + emit customProviderChanged(it - m_providerList.begin()); } -CustomProviderList::CustomProviderList() - : m_size(ProviderRegistry::globalInstance()->customProviderCount()) +ProviderList::ProviderList() + : m_size(ProviderRegistry::globalInstance()->providerCount()) { auto *registry = ProviderRegistry::globalInstance(); - connect(registry, &ProviderRegistry::customProviderAdded, this, &CustomProviderList::onCustomProviderAdded); - connect(registry, &ProviderRegistry::aboutToBeCleared, this, &CustomProviderList::onAboutToBeCleared, + connect(registry, &ProviderRegistry::customProviderAdded, this, &ProviderList::onCustomProviderAdded); + connect(registry, &ProviderRegistry::customProviderRemoved, this, &ProviderList::onCustomProviderRemoved); + connect(registry, &ProviderRegistry::customProviderChanged, this, &ProviderList::onCustomProviderChanged); + connect(registry, &ProviderRegistry::aboutToBeCleared, this, &ProviderList::onAboutToBeCleared, Qt::DirectConnection); } -QVariant CustomProviderList::data(const QModelIndex &index, int role) const +auto ProviderList::roleNames() const -> QHash +{ return { { Qt::DisplayRole, "provider"_ba } }; } + +QVariant ProviderList::data(const QModelIndex &index, int role) const { auto *registry = ProviderRegistry::globalInstance(); if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole) - return QVariant::fromValue(registry->customProviderAt(index.row())->asQObject()); + return QVariant::fromValue(registry->providerAt(index.row())->asQObject()); return {}; } -void CustomProviderList::onCustomProviderAdded(size_t index) +void ProviderList::onCustomProviderAdded(size_t index) { - beginInsertRows({}, m_size, m_size); + beginInsertRows({}, index, index); m_size++; endInsertRows(); } -void CustomProviderList::onAboutToBeCleared() +void ProviderList::onCustomProviderRemoved(size_t index) +{ + beginRemoveRows({}, index, index); + m_size--; + endRemoveRows(); +} + +void ProviderList::onCustomProviderChanged(size_t index) +{ + auto i = this->index(index); + emit dataChanged(i, i); +} + +void ProviderList::onAboutToBeCleared() { beginResetModel(); m_size = 0; @@ -203,8 +275,13 @@ bool ProviderListSort::lessThan(const QModelIndex &left, const QModelIndex &righ { auto *leftData = sourceModel()->data(left ).value(); auto *rightData = sourceModel()->data(right).value(); - if (leftData && rightData) - return QString::localeAwareCompare(leftData->name(), rightData->name()) < 0; + if (leftData && rightData) { + if (leftData->isBuiltin() != rightData->isBuiltin()) + return leftData->isBuiltin() > rightData->isBuiltin(); // builtins first + if (leftData->isBuiltin()) + return left.row() < right.row(); // preserve order of builtins + return QString::localeAwareCompare(leftData->name(), rightData->name()) < 0; // sort by name + } return true; } diff --git a/gpt4all-chat/src/llmodel_provider.h b/gpt4all-chat/src/llmodel_provider.h index 1242baf8..f811d254 100644 --- a/gpt4all-chat/src/llmodel_provider.h +++ b/gpt4all-chat/src/llmodel_provider.h @@ -2,23 +2,29 @@ #include "store_provider.h" +#include "qmlsharedptr.h" // IWYU pragma: keep #include "utils.h" // IWYU pragma: keep +#include + #include #include #include // IWYU pragma: keep #include #include +#include // IWYU pragma: keep #include #include #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -26,6 +32,10 @@ class QByteArray; class QJSEngine; template class QHash; +namespace QCoro { + template class Task; + struct QmlTask; +} namespace gpt4all::ui { @@ -33,6 +43,30 @@ namespace gpt4all::ui { Q_NAMESPACE +class ModelDescription; + +namespace detail { + +template +struct is_expected_impl : std::false_type {}; + +template +struct is_expected_impl> : std::true_type {}; + +template +concept is_expected = is_expected_impl>::value; + +} // namespace detail + +/// Drop the type and error information from a QCoro::Task> so it can be used by QML. +template + requires (!detail::is_expected::value_type>) +QCoro::QmlTask wrapQmlTask(std::shared_ptr c, F f, QString prefix, Args &&...args); + +template + requires detail::is_expected::value_type> +QCoro::QmlTask wrapQmlTask(std::shared_ptr c, F f, QString prefix, Args &&...args); + enum class GenerationParam { NPredict, Temperature, @@ -61,12 +95,28 @@ protected: void tryParseValue(this S &self, QMap &values, GenerationParam key, T C::* dest); }; -class ModelProvider { +class ProviderStatus { + Q_GADGET + Q_PROPERTY(bool ok READ ok CONSTANT) + Q_PROPERTY(QString detail READ detail CONSTANT) + +public: + explicit ProviderStatus(QString okMsg): m_ok(true), m_detail(std::move(okMsg)) {} + explicit ProviderStatus(const backend::ResponseError &error); + + bool ok () const { return m_ok; } + const QString &detail() const { return m_detail; } + +private: + bool m_ok; + QString m_detail; +}; + +class ModelProvider : public std::enable_shared_from_this { protected: - explicit ModelProvider(QUuid id, QString name, QUrl baseUrl) // create built-in or load - : m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {} - explicit ModelProvider(QString name, QUrl baseUrl) // create custom - : m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {} + struct protected_t { explicit protected_t() = default; }; + + explicit ModelProvider(protected_t, QUuid id, QString name, QUrl baseUrl); public: virtual ~ModelProvider() noexcept = 0; @@ -74,6 +124,8 @@ public: virtual QObject *asQObject() = 0; virtual const QObject *asQObject() const = 0; + virtual bool isBuiltin() const = 0; + // getters [[nodiscard]] const QUuid &id () const { return m_id; } [[nodiscard]] const QString &name () const { return m_name; } @@ -82,13 +134,24 @@ public: virtual auto supportedGenerationParams() const -> QSet = 0; virtual auto makeGenerationParams(const QMap &values) const -> GenerationParams * = 0; + // endpoints + virtual auto status () -> QCoro::Task = 0; + virtual auto listModels() -> QCoro::Task> = 0; + + /// create a model using this provider + [[nodiscard]] auto newModel(const QVariant &key) const -> std::shared_ptr; + friend bool operator==(const ModelProvider &a, const ModelProvider &b) { return a.m_id == b.m_id; } protected: + [[nodiscard]] virtual auto newModelImpl(const QVariant &key) const -> std::shared_ptr = 0; + QUuid m_id; QString m_name; QUrl m_baseUrl; + + template friend struct Creatable; }; class ModelProviderBuiltin : public virtual ModelProvider { @@ -97,6 +160,8 @@ protected: : m_icon(std::move(icon)) {} public: + bool isBuiltin() const final { return true; } + [[nodiscard]] const QUrl &icon() const { return m_icon; } protected: @@ -119,6 +184,8 @@ protected: [[nodiscard]] DataStoreResult<> setMemberProp(this S &self, T C::* member, std::string_view name, T value, std::optional createName = {}); + [[nodiscard]] virtual bool persisted() const { return true; } + ProviderStore *m_store; }; @@ -128,11 +195,20 @@ protected: : ModelProviderMutable(store) {} public: + bool isBuiltin() const final { return false; } + // setters [[nodiscard]] DataStoreResult<> setName (QString value) { return setMemberProp(&ModelProviderCustom::m_name, "name", std::move(value)); } [[nodiscard]] DataStoreResult<> setBaseUrl(QUrl value) { return setMemberProp(&ModelProviderCustom::m_baseUrl, "baseUrl", std::move(value)); } + + [[nodiscard]] auto persist() -> DataStoreResult<>; + +protected: + [[nodiscard]] bool persisted() const override { return m_persisted; } + + bool m_persisted = false; }; class ProviderRegistry : public QObject { @@ -156,17 +232,21 @@ protected: public: static ProviderRegistry *globalInstance(); - [[nodiscard]] bool add(std::shared_ptr provider); + [[nodiscard]] auto add(std::shared_ptr provider) -> DataStoreResult<>; + Q_INVOKABLE bool addQml(QmlSharedPtr *provider); + + // TODO(jared): implement a way to remove custom providers via the model auto operator[](const QUuid &id) -> const ModelProvider * { return m_providers.at(id).get(); } - // TODO(jared): implement a way to remove custom providers via the model - [[nodiscard]] size_t customProviderCount () const { return m_customProviders.size(); } - [[nodiscard]] auto customProviderAt (size_t i) const -> ModelProviderCustom *; - [[nodiscard]] size_t builtinProviderCount() const { return m_builtinProviders.size(); } - [[nodiscard]] auto builtinProviderAt (size_t i) const -> ModelProviderBuiltin *; + [[nodiscard]] size_t providerCount() const { return m_providers.size(); } + [[nodiscard]] auto providerAt(size_t i) const -> const ModelProvider *; + + ProviderStore *customStore() { return &m_customStore; } Q_SIGNALS: void customProviderAdded(size_t index); + void customProviderRemoved(size_t index); // TODO: use + void customProviderChanged(size_t index); void aboutToBeCleared(); private: @@ -175,6 +255,7 @@ private: private Q_SLOTS: void onModelPathChanged(); + void onProviderChanged(); private: static constexpr size_t N_BUILTIN = 3; @@ -183,51 +264,32 @@ private: ProviderStore m_customStore; ProviderStore m_builtinStore; std::unordered_map> m_providers; - std::vector> m_customProviders; - std::array m_builtinProviders; + std::vector m_providerList; // TODO: implement }; -// TODO: api keys are allowed to change for here and also below. That should emit dataChanged. -class BuiltinProviderList : public QAbstractListModel { +class ProviderList : public QAbstractListModel { Q_OBJECT - QML_SINGLETON - QML_ELEMENT public: - explicit BuiltinProviderList() - : m_size(ProviderRegistry::globalInstance()->builtinProviderCount()) {} + explicit ProviderList(); - static BuiltinProviderList *create(QQmlEngine *, QJSEngine *) { return new BuiltinProviderList(); } + static ProviderList *create(QQmlEngine *, QJSEngine *) { return new ProviderList(); } auto roleNames() const -> QHash override; int rowCount(const QModelIndex &parent = {}) const override { Q_UNUSED(parent) return int(m_size); } QVariant data(const QModelIndex &index, int role) const override; -private: - size_t m_size; -}; - -class CustomProviderList : public QAbstractListModel { - Q_OBJECT - -public: - explicit CustomProviderList(); - - int rowCount(const QModelIndex &parent = {}) const override - { Q_UNUSED(parent) return int(m_size); } - QVariant data(const QModelIndex &index, int role) const override; - private Q_SLOTS: - void onCustomProviderAdded(size_t index); + void onCustomProviderAdded (size_t index); + void onCustomProviderRemoved(size_t index); + void onCustomProviderChanged(size_t index); void onAboutToBeCleared(); private: size_t m_size; }; -// todo: don't have singletons use singletons directly -// TODO: actually use the provider sort, here, rather than unsorted, for builtins class ProviderListSort : public QSortFilterProxyModel { Q_OBJECT QML_SINGLETON @@ -243,8 +305,7 @@ protected: bool lessThan(const QModelIndex &left, const QModelIndex &right) const override; private: - // TODO: support custom providers as well - BuiltinProviderList m_model; + ProviderList m_model; }; diff --git a/gpt4all-chat/src/llmodel_provider.inl b/gpt4all-chat/src/llmodel_provider.inl index fda1c02e..837763a3 100644 --- a/gpt4all-chat/src/llmodel_provider.inl +++ b/gpt4all-chat/src/llmodel_provider.inl @@ -1,9 +1,43 @@ #include +#include +#include + +#include +#include +#include + +#include +#include + namespace gpt4all::ui { +template + requires (!detail::is_expected::value_type>) +QCoro::QmlTask wrapQmlTask(C *obj, F f, QString prefix, Args &&...args) +{ + std::shared_ptr ptr(obj->shared_from_this(), obj); + return [](std::shared_ptr ptr, F f, QString prefix, Args &&...args) -> QCoro::Task { + co_return QVariant::fromValue(co_await std::invoke(f, ptr.get(), std::forward(args)...)); + }(std::move(ptr), std::move(f), std::move(prefix), std::forward(args)...); +} + +template + requires detail::is_expected::value_type> +QCoro::QmlTask wrapQmlTask(C *obj, F f, QString prefix, Args &&...args) +{ + std::shared_ptr ptr(obj->shared_from_this(), obj); + return [](std::shared_ptr ptr, F f, QString prefix, Args &&...args) -> QCoro::Task { + auto result = co_await std::invoke(f, ptr.get(), std::forward(args)...); + if (result) + co_return QVariant::fromValue(*result); + qWarning().noquote() << prefix << "failed:" << result.error().errorString(); + co_return QVariant::fromValue(nullptr); + }(std::move(ptr), std::move(f), std::move(prefix), std::forward(args)...); +} + template void GenerationParams::tryParseValue(this S &self, QMap &values, GenerationParam key, T C::* dest) @@ -20,9 +54,11 @@ auto ModelProviderMutable::setMemberProp(this S &self, T C::* member, std::strin auto &cur = self.*member; if (cur != value) { cur = std::move(value); - auto data = mpc.asData(); - if (auto res = mpc.m_store->setData(std::move(data), createName); !res) - return res; + if (mpc.persisted()) { + auto data = mpc.asData(); + if (auto res = mpc.m_store->setData(std::move(data), createName); !res) + return res; + } QMetaObject::invokeMethod(self.asQObject(), fmt::format("{}Changed", name).c_str(), cur); } return {}; diff --git a/gpt4all-chat/src/main.cpp b/gpt4all-chat/src/main.cpp index 9d8810be..1be7b2bc 100644 --- a/gpt4all-chat/src/main.cpp +++ b/gpt4all-chat/src/main.cpp @@ -2,6 +2,7 @@ #include "config.h" #include "download.h" #include "llm.h" +#include "llmodel_provider.h" #include "localdocs.h" #include "logger.h" #include "modellist.h" @@ -153,6 +154,7 @@ int main(int argc, char *argv[]) qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance()); qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); qmlRegisterSingletonInstance("toollist", 1, 0, "ToolList", ToolModel::globalInstance()); + qmlRegisterSingletonInstance("gpt4all.ProviderRegistry", 1, 0, "ProviderRegistry", ProviderRegistry::globalInstance()); qmlRegisterUncreatableMetaObject(ToolEnums::staticMetaObject, "toolenums", 1, 0, "ToolEnums", "Error: only enums"); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); diff --git a/gpt4all-chat/src/mysettings.cpp b/gpt4all-chat/src/mysettings.cpp index a12daaba..5527ba20 100644 --- a/gpt4all-chat/src/mysettings.cpp +++ b/gpt4all-chat/src/mysettings.cpp @@ -178,9 +178,9 @@ MySettings::MySettings() { } -const QString &MySettings::userAgent() +const QByteArray &MySettings::userAgent() { - static const QString s_userAgent = QStringLiteral("gpt4all/" APP_VERSION); + static const QByteArray s_userAgent = QByteArrayLiteral("gpt4all/" APP_VERSION); return s_userAgent; } diff --git a/gpt4all-chat/src/mysettings.h b/gpt4all-chat/src/mysettings.h index 251ec5ed..0dfd3eab 100644 --- a/gpt4all-chat/src/mysettings.h +++ b/gpt4all-chat/src/mysettings.h @@ -88,7 +88,7 @@ public Q_SLOTS: public: static MySettings *globalInstance(); - static const QString &userAgent(); + static const QByteArray &userAgent(); Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl); diff --git a/gpt4all-chat/src/qmlfunctions.cpp b/gpt4all-chat/src/qmlfunctions.cpp new file mode 100644 index 00000000..cef126d3 --- /dev/null +++ b/gpt4all-chat/src/qmlfunctions.cpp @@ -0,0 +1,41 @@ +#include "qmlfunctions.h" + +#include "llmodel_ollama.h" +#include "llmodel_openai.h" +#include "llmodel_provider.h" + +#include +#include + + +namespace gpt4all::ui { + + +QmlSharedPtr *QmlFunctions::newCustomOpenaiProvider(QString name, QUrl baseUrl, QString apiKey) const +{ + auto *store = ProviderRegistry::globalInstance()->customStore(); + std::shared_ptr ptr; + try { + ptr = OpenaiProviderCustom::create(store, std::move(name), std::move(baseUrl), std::move(apiKey)); + } catch (const std::exception &e) { + qWarning() << "newCustomOpenaiProvider failed:" << e.what(); + return nullptr; + } + return new QmlSharedPtr(std::move(ptr)); +} + +QmlSharedPtr *QmlFunctions::newCustomOllamaProvider(QString name, QUrl baseUrl) const +{ + auto *store = ProviderRegistry::globalInstance()->customStore(); + std::shared_ptr ptr; + try { + ptr = OllamaProviderCustom::create(store, std::move(name), std::move(baseUrl)); + } catch (const std::exception &e) { + qWarning() << "newCustomOllamaProvider failed:" << e.what(); + return nullptr; + } + return new QmlSharedPtr(std::move(ptr)); +} + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/qmlfunctions.h b/gpt4all-chat/src/qmlfunctions.h new file mode 100644 index 00000000..52862df0 --- /dev/null +++ b/gpt4all-chat/src/qmlfunctions.h @@ -0,0 +1,31 @@ +#pragma once + +#include "qmlsharedptr.h" // IWYU pragma: keep + +#include +#include +#include // IWYU pragma: keep +#include // IWYU pragma: keep + + +namespace gpt4all::ui { + + +// The singleton through which all static methods and free functions are called in QML. +class QmlFunctions : public QObject { + Q_OBJECT + QML_ELEMENT + QML_SINGLETON + + explicit QmlFunctions() = default; + +public: + + static QmlFunctions *create(QQmlEngine *, QJSEngine *) { return new QmlFunctions; } + + Q_INVOKABLE QmlSharedPtr *newCustomOpenaiProvider(QString name, QUrl baseUrl, QString apiKey) const; + Q_INVOKABLE QmlSharedPtr *newCustomOllamaProvider(QString name, QUrl baseUrl) const; +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/qmlsharedptr.cpp b/gpt4all-chat/src/qmlsharedptr.cpp new file mode 100644 index 00000000..b3e23230 --- /dev/null +++ b/gpt4all-chat/src/qmlsharedptr.cpp @@ -0,0 +1,14 @@ +#include "qmlsharedptr.h" + +#include + + +namespace gpt4all::ui { + + +QmlSharedPtr::QmlSharedPtr(std::shared_ptr ptr) + : m_ptr(std::move(ptr)) +{ if (m_ptr) { QJSEngine::setObjectOwnership(m_ptr.get(), QJSEngine::CppOwnership); } } + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/qmlsharedptr.h b/gpt4all-chat/src/qmlsharedptr.h new file mode 100644 index 00000000..e46fe155 --- /dev/null +++ b/gpt4all-chat/src/qmlsharedptr.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include + + +namespace gpt4all::ui { + + +class QmlSharedPtr : public QObject { + Q_OBJECT + +public: + explicit QmlSharedPtr(std::shared_ptr ptr); + + const std::shared_ptr &ptr() { return m_ptr; } + Q_INVOKABLE QObject *get() { return m_ptr.get(); } + +private: + std::shared_ptr m_ptr; +}; + + +} // namespace gpt4all::ui diff --git a/gpt4all-chat/src/store_base.cpp b/gpt4all-chat/src/store_base.cpp index 294b9bc5..d8c1b9d1 100644 --- a/gpt4all-chat/src/store_base.cpp +++ b/gpt4all-chat/src/store_base.cpp @@ -2,6 +2,7 @@ #include #include // IWYU pragma: keep +#include #include #include @@ -157,39 +158,32 @@ auto DataStoreBase::read(QFileDevice &file, json::stream_parser &parser) -> Data } }; - auto inner = [&] -> DataStoreResult<> { - bool partialRead = false; - auto chunkIt = iterChunks(); - // read JSON data + bool partialRead = false; + auto chunkIt = iterChunks(); + + // read JSON data + parser.reset(); + for (auto &chunk : chunkIt) { + if (!chunk) + return std::unexpected(chunk.error()); + size_t nRead = parser.write_some(chunk->data(), chunk->size()); + // consume trailing whitespace in chunk + if (nRead < chunk->size()) { + auto rest = QByteArrayView(*chunk).slice(nRead); + if (!rest.trimmed().isEmpty()) + return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(QByteArray(rest))); + partialRead = true; + break; + } + } + // consume trailing whitespace in file + if (partialRead) { for (auto &chunk : chunkIt) { if (!chunk) return std::unexpected(chunk.error()); - size_t nRead = parser.write_some(chunk->data(), chunk->size()); - // consume trailing whitespace in chunk - if (nRead < chunk->size()) { - auto rest = QByteArrayView(*chunk).slice(nRead); - if (!rest.trimmed().isEmpty()) - return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(QByteArray(rest))); - partialRead = true; - break; - } + if (!chunk->trimmed().isEmpty()) + return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(*chunk)); } - // consume trailing whitespace in file - if (partialRead) { - for (auto &chunk : chunkIt) { - if (!chunk) - return std::unexpected(chunk.error()); - if (!chunk->trimmed().isEmpty()) - return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(*chunk)); - } - } - return {}; - }; - - auto res = inner(); - if (!res) { - parser.reset(); - return std::unexpected(res.error()); } return parser.release(); } diff --git a/gpt4all-chat/src/store_base.h b/gpt4all-chat/src/store_base.h index 49bb9524..ae64c86c 100644 --- a/gpt4all-chat/src/store_base.h +++ b/gpt4all-chat/src/store_base.h @@ -4,7 +4,6 @@ #include // IWYU pragma: keep #include // IWYU pragma: keep -#include #include #include @@ -22,6 +21,8 @@ #include #include +#include + class QByteArray; class QSaveFile; @@ -96,7 +97,7 @@ class DataStore : public DataStoreBase { public: explicit DataStore(std::filesystem::path path); - auto list() -> tl::generator; + auto list() { return m_entries | std::views::transform([](auto &e) { return e.second; }); } auto setData(T data, std::optional createName = {}) -> DataStoreResult<>; auto remove(const QUuid &id) -> DataStoreResult<>; @@ -109,7 +110,7 @@ public: { auto it = m_entries.find(id); return it == m_entries.end() ? std::nullopt : std::optional(&it->second); } protected: - auto createImpl(T data, const QString &name) -> DataStoreResult; + auto createImpl(T data, const QString &name) -> DataStoreResult<>; auto clear() -> DataStoreResult<> final; CacheInsertResult cacheInsert(const boost::json::value &jv) override; diff --git a/gpt4all-chat/src/store_base.inl b/gpt4all-chat/src/store_base.inl index c81e2870..acb846c1 100644 --- a/gpt4all-chat/src/store_base.inl +++ b/gpt4all-chat/src/store_base.inl @@ -6,8 +6,11 @@ #include #include +#include #include +namespace views = std::views; + namespace gpt4all::ui { @@ -21,14 +24,7 @@ DataStore::DataStore(std::filesystem::path path) } template -auto DataStore::list() -> tl::generator -{ - for (auto &[_, value] : m_entries) - co_yield value; -} - -template -auto DataStore::createImpl(T data, const QString &name) -> DataStoreResult +auto DataStore::createImpl(T data, const QString &name) -> DataStoreResult<> { // acquire path auto file = openNew(name); @@ -42,12 +38,7 @@ auto DataStore::createImpl(T data, const QString &name) -> DataStoreResultsecond; + return {}; } template diff --git a/gpt4all-chat/src/store_provider.cpp b/gpt4all-chat/src/store_provider.cpp index 8597f96d..6edd32db 100644 --- a/gpt4all-chat/src/store_provider.cpp +++ b/gpt4all-chat/src/store_provider.cpp @@ -57,25 +57,9 @@ auto tag_invoke(const boost::json::value_to_tag &, const boos }; } -auto ProviderStore::create(QString name, QUrl base_url, QString api_key) - -> DataStoreResult +auto ProviderStore::create(ModelProviderData data) -> DataStoreResult<> { - ModelProviderData data { - .id = QUuid::createUuid(), - .custom_details = CustomProviderDetails { name, std::move(base_url) }, - .provider_details = OpenaiProviderDetails { std::move(api_key) }, - }; - return createImpl(std::move(data), name); -} - -auto ProviderStore::create(QString name, QUrl base_url) - -> DataStoreResult -{ - ModelProviderData data { - .id = QUuid::createUuid(), - .custom_details = CustomProviderDetails { name, std::move(base_url) }, - .provider_details = {}, - }; + auto name = data.custom_details.value().name; return createImpl(std::move(data), name); } diff --git a/gpt4all-chat/src/store_provider.h b/gpt4all-chat/src/store_provider.h index 078a3ec4..b64dc343 100644 --- a/gpt4all-chat/src/store_provider.h +++ b/gpt4all-chat/src/store_provider.h @@ -49,10 +49,7 @@ private: public: using Super::Super; - /// OpenAI - auto create(QString name, QUrl base_url, QString api_key) -> DataStoreResult; - /// Ollama - auto create(QString name, QUrl base_url) -> DataStoreResult; + auto create(ModelProviderData data) -> DataStoreResult<>; };