WIP: provider page in the "add models" view

This commit is contained in:
Jared Van Bortel
2025-03-19 10:49:39 -04:00
parent 8294a5cd58
commit 9772027e5e
32 changed files with 998 additions and 339 deletions

View File

@@ -12,6 +12,8 @@
#include <QUrl>
#include <expected>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
@@ -24,7 +26,7 @@ namespace gpt4all::backend {
struct ResponseError {
public:
struct BadStatus { int code; };
struct BadStatus { int code; std::optional<QString> reason; };
using ErrorCode = std::variant<
QNetworkReply::NetworkError,
boost::system::error_code,
@@ -34,8 +36,8 @@ public:
ResponseError(const QRestReply *reply);
ResponseError(const boost::system::system_error &e);
const ErrorCode &error () { return m_error; }
const QString &errorString() { return m_errorString; }
const ErrorCode &error () const { return m_error; }
const QString &errorString() const { return m_errorString; }
private:
ErrorCode m_error;
@@ -47,9 +49,10 @@ using DataOrRespErr = std::expected<T, ResponseError>;
class OllamaClient {
public:
OllamaClient(QUrl baseUrl, QString m_userAgent)
: m_baseUrl(baseUrl)
OllamaClient(QUrl baseUrl, QString m_userAgent, QNetworkAccessManager *nam)
: m_baseUrl(std::move(baseUrl))
, m_userAgent(std::move(m_userAgent))
, m_nam(nam)
{}
const QUrl &baseUrl() const { return m_baseUrl; }
@@ -70,28 +73,28 @@ public:
private:
QNetworkRequest makeRequest(const QString &path) const;
auto processResponse(QNetworkReply &reply) -> QCoro::Task<DataOrRespErr<boost::json::value>>;
auto processResponse(std::unique_ptr<QNetworkReply> reply) -> QCoro::Task<DataOrRespErr<boost::json::value>>;
template <typename Resp>
auto get(const QString &path) -> QCoro::Task<DataOrRespErr<Resp>>;
auto get(QString path) -> QCoro::Task<DataOrRespErr<Resp>>;
template <typename Resp, typename Req>
auto post(const QString &path, Req const &body) -> QCoro::Task<DataOrRespErr<Resp>>;
auto post(QString path, Req const &body) -> QCoro::Task<DataOrRespErr<Resp>>;
auto getJson(const QString &path) -> QCoro::Task<DataOrRespErr<boost::json::value>>;
auto getJson(QString path) -> QCoro::Task<DataOrRespErr<boost::json::value>>;
auto postJson(const QString &path, const boost::json::value &body)
-> QCoro::Task<DataOrRespErr<boost::json::value>>;
private:
QUrl m_baseUrl;
QString m_userAgent;
QNetworkAccessManager m_nam;
QNetworkAccessManager *m_nam;
boost::json::stream_parser m_parser;
};
extern template auto OllamaClient::get(const QString &) -> QCoro::Task<DataOrRespErr<ollama::VersionResponse>>;
extern template auto OllamaClient::get(const QString &) -> QCoro::Task<DataOrRespErr<ollama::ListResponse>>;
extern template auto OllamaClient::get(QString) -> QCoro::Task<DataOrRespErr<ollama::VersionResponse>>;
extern template auto OllamaClient::get(QString) -> QCoro::Task<DataOrRespErr<ollama::ListResponse>>;
extern template auto OllamaClient::post(const QString &, const ollama::ShowRequest &)
extern template auto OllamaClient::post(QString, const ollama::ShowRequest &)
-> QCoro::Task<DataOrRespErr<ollama::ShowResponse>>;

View File

@@ -51,7 +51,7 @@ struct ListModelResponse {
QString digest;
std::optional<ModelDetails> details;
};
BOOST_DESCRIBE_STRUCT(ListModelResponse, (), (model, modified_at, size, digest, details))
BOOST_DESCRIBE_STRUCT(ListModelResponse, (), (name, model, modified_at, size, digest, details))
using ToolCallFunctionArguments = boost::json::object;

View File

@@ -29,7 +29,8 @@ ResponseError::ResponseError(const QRestReply *reply)
if (reply->hasError()) {
m_error = reply->networkReply()->error();
} else if (!reply->isHttpStatusSuccess()) {
m_error = BadStatus(reply->httpStatus());
auto reason = reply->networkReply()->attribute(QNetworkRequest::HttpReasonPhraseAttribute).toString();
m_error = BadStatus(reply->httpStatus(), reason.isEmpty() ? std::nullopt : std::optional(reason));
} else
Q_UNREACHABLE();
@@ -50,19 +51,19 @@ QNetworkRequest OllamaClient::makeRequest(const QString &path) const
return req;
}
auto OllamaClient::processResponse(QNetworkReply &reply) -> QCoro::Task<DataOrRespErr<json::value>>
auto OllamaClient::processResponse(std::unique_ptr<QNetworkReply> reply) -> QCoro::Task<DataOrRespErr<json::value>>
{
QRestReply restReply(&reply);
if (reply.error())
QRestReply restReply(reply.get());
if (reply->error())
co_return std::unexpected(&restReply);
auto coroReply = qCoro(reply);
auto coroReply = qCoro(*reply);
for (;;) {
auto chunk = co_await coroReply.readAll();
if (!restReply.isSuccess())
co_return std::unexpected(&restReply);
if (chunk.isEmpty()) {
Q_ASSERT(reply.atEnd());
Q_ASSERT(reply->atEnd());
break;
}
m_parser.write(chunk.data(), chunk.size());
@@ -73,7 +74,7 @@ auto OllamaClient::processResponse(QNetworkReply &reply) -> QCoro::Task<DataOrRe
}
template <typename Resp>
auto OllamaClient::get(const QString &path) -> QCoro::Task<DataOrRespErr<Resp>>
auto OllamaClient::get(QString path) -> QCoro::Task<DataOrRespErr<Resp>>
{
// get() should not throw exceptions
try {
@@ -86,11 +87,11 @@ auto OllamaClient::get(const QString &path) -> QCoro::Task<DataOrRespErr<Resp>>
}
}
template auto OllamaClient::get(const QString &) -> QCoro::Task<DataOrRespErr<VersionResponse>>;
template auto OllamaClient::get(const QString &) -> QCoro::Task<DataOrRespErr<ListResponse>>;
template auto OllamaClient::get(QString) -> QCoro::Task<DataOrRespErr<VersionResponse>>;
template auto OllamaClient::get(QString) -> QCoro::Task<DataOrRespErr<ListResponse>>;
template <typename Resp, typename Req>
auto OllamaClient::post(const QString &path, const Req &body) -> QCoro::Task<DataOrRespErr<Resp>>
auto OllamaClient::post(QString path, const Req &body) -> QCoro::Task<DataOrRespErr<Resp>>
{
// post() should not throw exceptions
try {
@@ -104,12 +105,12 @@ auto OllamaClient::post(const QString &path, const Req &body) -> QCoro::Task<Dat
}
}
template auto OllamaClient::post(const QString &, const ShowRequest &) -> QCoro::Task<DataOrRespErr<ShowResponse>>;
template auto OllamaClient::post(QString, const ShowRequest &) -> QCoro::Task<DataOrRespErr<ShowResponse>>;
auto OllamaClient::getJson(const QString &path) -> QCoro::Task<DataOrRespErr<json::value>>
auto OllamaClient::getJson(QString path) -> QCoro::Task<DataOrRespErr<json::value>>
{
std::unique_ptr<QNetworkReply> reply(m_nam.get(makeRequest(path)));
co_return co_await processResponse(*reply);
std::unique_ptr<QNetworkReply> reply(m_nam->get(makeRequest(path)));
return processResponse(std::move(reply));
}
auto OllamaClient::postJson(const QString &path, const json::value &body) -> QCoro::Task<DataOrRespErr<json::value>>
@@ -117,8 +118,8 @@ auto OllamaClient::postJson(const QString &path, const json::value &body) -> QCo
JsonStreamDevice stream(&body);
auto req = makeRequest(path);
req.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"_ba);
std::unique_ptr<QNetworkReply> reply(m_nam.post(req, &stream));
co_return co_await processResponse(*reply);
std::unique_ptr<QNetworkReply> reply(m_nam->post(req, &stream));
co_return co_await processResponse(std::move(reply));
}

View File

@@ -18,11 +18,11 @@ QString restErrorString(const QRestReply &reply)
if (!reply.isHttpStatusSuccess()) {
auto code = reply.httpStatus();
auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute).toString();
return u"HTTP %1%2%3 for URL \"%4\""_s.arg(
QString::number(code),
reason.isValid() ? u" "_s : QString(),
reason.toString(),
reason.isEmpty() ? QString() : u" "_s,
reason,
nr->request().url().toString()
);
}

View File

@@ -253,6 +253,8 @@ qt_add_executable(chat
src/modellist.cpp src/modellist.h
src/mysettings.cpp src/mysettings.h
src/network.cpp src/network.h
src/qmlfunctions.cpp src/qmlfunctions.h
src/qmlsharedptr.cpp src/qmlsharedptr.h
src/server.cpp src/server.h
src/store_base.cpp src/store_base.h
src/store_provider.cpp src/store_provider.h
@@ -272,18 +274,20 @@ qt_add_qml_module(chat
QML_FILES
main.qml
qml/AddCollectionView.qml
qml/AddCustomProviderView.qml
qml/AddModelView.qml
qml/AddGPT4AllModelView.qml
qml/AddHFModelView.qml
qml/AddRemoteModelView.qml
qml/ApplicationSettings.qml
qml/ChatDrawer.qml
qml/ChatCollapsibleItem.qml
qml/ChatDrawer.qml
qml/ChatItemView.qml
qml/ChatMessageButton.qml
qml/ChatTextItem.qml
qml/ChatView.qml
qml/CollectionsDrawer.qml
qml/CustomProviderCard.qml
qml/HomeView.qml
qml/LocalDocsSettings.qml
qml/LocalDocsView.qml

View File

@@ -15,7 +15,6 @@ add_subdirectory(QXlsx/QXlsx)
add_subdirectory(json) # required by minja
# TartanLlama
set(FUNCTION_REF_ENABLE_TESTS OFF)
add_subdirectory(generator)
if (NOT GPT4ALL_USING_QTPDF)

View File

@@ -0,0 +1,57 @@
import QtQuick
import QtQuick.Controls
import QtQuick.Layouts
ColumnLayout {
Layout.fillWidth: true
Layout.alignment: Qt.AlignTop
spacing: 5
Label {
Layout.topMargin: 0
Layout.bottomMargin: 25
Layout.rightMargin: 150 * theme.fontScale
Layout.alignment: Qt.AlignTop
Layout.fillWidth: true
verticalAlignment: Text.AlignTop
text: qsTr("Add custom model providers here.")
font.pixelSize: theme.fontSizeLarger
color: theme.textColor
wrapMode: Text.WordWrap
}
ScrollView {
id: scrollView
ScrollBar.vertical.policy: ScrollBar.AsNeeded
Layout.fillWidth: true
Layout.fillHeight: true
contentWidth: availableWidth
clip: true
Flow {
anchors.left: parent.left
anchors.right: parent.right
spacing: 20
bottomPadding: 20
property int childWidth: 330 * theme.fontScale
property int childHeight: 400 + 166 * theme.fontScale
CustomProviderCard {
width: parent.childWidth
height: parent.childHeight
withApiKey: true
createProvider: QmlFunctions.newCustomOpenaiProvider
providerName: qsTr("OpenAI")
providerImage: "qrc:/gpt4all/icons/antenna_3.svg"
providerDesc: qsTr("Configure a custom OpenAI provider.")
}
CustomProviderCard {
width: parent.childWidth
height: parent.childHeight
withApiKey: false
createProvider: QmlFunctions.newCustomOllamaProvider
providerName: qsTr("Ollama")
providerImage: "qrc:/gpt4all/icons/antenna_3.svg"
providerDesc: qsTr("Configure a custom Ollama provider.")
}
}
}
}

View File

@@ -96,6 +96,11 @@ Rectangle {
remoteModelView.show();
}
}
MyTabButton {
text: qsTr("Custom Providers")
isSelected: customProviderModelView.isShown()
onPressed: customProviderModelView.show()
}
MyTabButton {
text: qsTr("HuggingFace")
isSelected: huggingfaceModelView.isShown()
@@ -136,6 +141,15 @@ Rectangle {
}
}
AddCustomProviderView {
id: customProviderModelView
Layout.fillWidth: true
Layout.fillHeight: true
function show() { stackLayout.currentIndex = 2; }
function isShown() { return stackLayout.currentIndex === 2; }
}
AddHFModelView {
id: huggingfaceModelView
Layout.fillWidth: true
@@ -146,10 +160,10 @@ Rectangle {
anchors.fill: parent
function show() {
stackLayout.currentIndex = 2;
stackLayout.currentIndex = 3;
}
function isShown() {
return stackLayout.currentIndex === 2;
return stackLayout.currentIndex === 3;
}
}
}

View File

@@ -49,15 +49,13 @@ ColumnLayout {
property int childWidth: 330 * theme.fontScale
property int childHeight: 400 + 166 * theme.fontScale
Repeater {
model: BuiltinProviderList
delegate: RemoteModelCard {
required property var data
model: ProviderListSort
RemoteModelCard {
width: parent.childWidth
height: parent.childHeight
provider: data
providerBaseUrl: data.baseUrl
providerName: data.name
providerImage: data.icon
provider: modelData
providerName: provider.name
providerImage: provider.icon
providerDesc: ({
'{20f963dc-1f99-441e-ad80-f30a0a06bcac}': qsTr(
'Groq offers a high-performance AI inference engine designed for low-latency and ' +
@@ -78,10 +76,18 @@ ColumnLayout {
'performance, making them a solid option for applications requiring scalable AI ' +
'solutions.<br><br>Get your API key: <a href="https://mistral.ai/">https://mistral.ai/</a>'
),
})[data.id.toString()]
modelWhitelist: data.modelWhitelist
})[provider.id.toString()]
}
}
RemoteModelCard {
width: parent.childWidth
height: parent.childHeight
providerUsesApiKey: false
providerName: qsTr("Ollama (Custom)")
providerImage: "qrc:/gpt4all/icons/antenna_3.svg"
providerDesc: qsTr("Configure a custom Ollama provider.")
}
// TODO(jared): add custom openai back to the list
/*
RemoteModelCard {
width: parent.childWidth

View File

@@ -0,0 +1,193 @@
import QtQuick
import QtQuick.Controls
import QtQuick.Layouts
import gpt4all.ProviderRegistry
Rectangle {
id: root
required property bool withApiKey
required property var createProvider
property alias providerName: providerNameLabel.text
property alias providerImage: myimage.source
property alias providerDesc: providerDescLabel.text
color: theme.conversationBackground
radius: 10
border.width: 1
border.color: theme.controlBorder
implicitHeight: topColumn.height + bottomColumn.height + 33 * theme.fontScale
ColumnLayout {
id: topColumn
anchors.left: parent.left
anchors.right: parent.right
anchors.top: parent.top
anchors.margins: 20
spacing: 15 * theme.fontScale
RowLayout {
Layout.alignment: Qt.AlignTop
spacing: 10
Item {
Layout.preferredWidth: 27 * theme.fontScale
Layout.preferredHeight: 27 * theme.fontScale
Layout.alignment: Qt.AlignLeft
Image {
id: myimage
anchors.centerIn: parent
sourceSize.width: parent.width
sourceSize.height: parent.height
mipmap: true
fillMode: Image.PreserveAspectFit
}
}
Label {
id: providerNameLabel
color: theme.textColor
font.pixelSize: theme.fontSizeBanner
}
}
Label {
id: providerDescLabel
Layout.fillWidth: true
wrapMode: Text.Wrap
color: theme.settingsTitleTextColor
font.pixelSize: theme.fontSizeLarge
onLinkActivated: function(link) { Qt.openUrlExternally(link); }
MouseArea {
anchors.fill: parent
acceptedButtons: Qt.NoButton // pass clicks to parent
cursorShape: parent.hoveredLink ? Qt.PointingHandCursor : Qt.ArrowCursor
}
}
}
ColumnLayout {
id: bottomColumn
anchors.left: parent.left
anchors.right: parent.right
anchors.bottom: parent.bottom
anchors.margins: 20
spacing: 30
ColumnLayout {
MySettingsLabel {
text: qsTr("Name")
font.bold: true
font.pixelSize: theme.fontSizeLarge
color: theme.settingsTitleTextColor
}
MyTextField {
id: nameField
Layout.fillWidth: true
font.pixelSize: theme.fontSizeLarge
wrapMode: Text.WrapAnywhere
placeholderText: qsTr("Provider Name")
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
}
}
ColumnLayout {
MySettingsLabel {
text: qsTr("Base URL")
font.bold: true
font.pixelSize: theme.fontSizeLarge
color: theme.settingsTitleTextColor
}
MyTextField {
id: baseUrlField
property bool ok: text.trim() !== ""
Layout.fillWidth: true
font.pixelSize: theme.fontSizeLarge
wrapMode: Text.WrapAnywhere
placeholderText: qsTr("Provider Base URL")
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
}
}
ColumnLayout {
visible: withApiKey
MySettingsLabel {
text: qsTr("API Key")
font.bold: true
font.pixelSize: theme.fontSizeLarge
color: theme.settingsTitleTextColor
}
MyTextField {
id: apiKeyField
Layout.fillWidth: true
font.pixelSize: theme.fontSizeLarge
wrapMode: Text.WrapAnywhere
echoMode: TextField.Password
placeholderText: qsTr("Provider API Key")
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
}
}
ColumnLayout {
MySettingsLabel {
text: qsTr("Status")
font.bold: true
font.pixelSize: theme.fontSizeLarge
color: theme.settingsTitleTextColor
}
RowLayout {
spacing: 10
MyTextField {
id: statusText
property var provider: null // owns the new provider
enabled: false
Layout.fillWidth: true
font.pixelSize: theme.fontSizeLarge
property var inputs: ({
name : nameField .text.trim(),
baseUrl : baseUrlField.text.trim(),
apiKey : apiKeyField .text.trim(),
})
function update() {
provider = null;
text = qsTr("...");
if (inputs.name === "" || inputs.baseUrl === "")
return;
const args = [inputs.name, inputs.baseUrl];
if (withApiKey)
args.push(inputs.apiKey);
let p = createProvider(...args);
if (p !== null)
p.get().statusQml().then(status => {
if (status !== null) {
if (status.ok) { provider = p; }
text = status.detail;
}
});
}
Component.onCompleted: update()
onInputsChanged: update()
}
}
}
MySettingsButton {
id: installButton
Layout.alignment: Qt.AlignRight
text: qsTr("Install")
font.pixelSize: theme.fontSizeLarge
enabled: statusText.provider !== null
onClicked: ProviderRegistry.addQml(statusText.provider)
Accessible.role: Accessible.Button
Accessible.name: qsTr("Install")
Accessible.description: qsTr("Install custom provider")
}
}
}

View File

@@ -1,30 +1,19 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import QtQuick.Layouts
import QtQuick.Dialogs
import Qt.labs.folderlistmodel
import Qt5Compat.GraphicalEffects
import llm
import chatlistmodel
import download
import modellist
import network
import gpt4all
import mysettings
import localdocs
Rectangle {
required property var provider
id: remoteModelCard
property var provider: null
property alias providerName: providerNameLabel.text
property alias providerImage: myimage.source
property alias providerDesc: providerDescLabel.text
property string providerBaseUrl: ""
property bool providerIsCustom: false
property var modelWhitelist: null
property bool providerUsesApiKey: true
// for internal use
property bool apiKeyRequired: provider === null ? providerUsesApiKey : "apiKey" in provider
property bool apiKeyGood: !apiKeyRequired // (overwritten later if required)
property bool baseUrlGood: provider !== null // (overwritten later if custom)
color: theme.conversationBackground
radius: 10
@@ -89,6 +78,8 @@ Rectangle {
spacing: 30
ColumnLayout {
visible: apiKeyRequired
MySettingsLabel {
text: qsTr("API Key")
font.bold: true
@@ -106,29 +97,19 @@ Rectangle {
messageToast.show(qsTr("ERROR: $API_KEY is empty."));
apiKeyField.placeholderTextColor = theme.textErrorColor;
}
Component.onCompleted: { text = provider.apiKey; }
Component.onCompleted: { if (parent.visible && provider !== null) { text = provider.apiKey; } }
onTextChanged: {
apiKeyField.placeholderTextColor = theme.mutedTextColor;
if (!providerIsCustom && provider.setApiKeyQml(text)) {
provider.listModelsQml().then(modelList => {
if (modelList !== null) {
if (modelWhitelist !== null)
models = models.filter(m => modelWhitelist.includes(m));
myModelList.model = models;
myModelList.currentIndex = -1;
}
});
}
if (provider !== null) { apiKeyGood = provider.setApiKeyQml(text) && text !== ""; }
}
placeholderText: qsTr("enter $API_KEY")
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
Accessible.description: qsTr("Whether the file hash is being calculated")
}
}
ColumnLayout {
visible: providerIsCustom
visible: provider === null
MySettingsLabel {
text: qsTr("Base Url")
font.bold: true
@@ -146,40 +127,16 @@ Rectangle {
}
onTextChanged: {
baseUrlField.placeholderTextColor = theme.mutedTextColor;
baseUrlGood = text.trim() !== "";
}
placeholderText: qsTr("enter $BASE_URL")
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
}
}
ColumnLayout {
visible: providerIsCustom
MySettingsLabel {
text: qsTr("Model Name")
font.bold: true
font.pixelSize: theme.fontSizeLarge
color: theme.settingsTitleTextColor
}
MyTextField {
id: modelNameField
Layout.fillWidth: true
font.pixelSize: theme.fontSizeLarge
wrapMode: Text.WrapAnywhere
function showError() {
messageToast.show(qsTr("ERROR: $MODEL_NAME is empty."))
modelNameField.placeholderTextColor = theme.textErrorColor;
}
onTextChanged: {
modelNameField.placeholderTextColor = theme.mutedTextColor;
}
placeholderText: qsTr("enter $MODEL_NAME")
Accessible.role: Accessible.EditableText
Accessible.name: placeholderText
}
}
ColumnLayout {
visible: myModelList.count > 0 && !providerIsCustom
visible: myModelList.count > 0
MySettingsLabel {
text: qsTr("Models")
@@ -194,7 +151,27 @@ Rectangle {
MyComboBox {
Layout.fillWidth: true
id: myModelList
currentIndex: -1;
currentIndex: -1
property bool ready: baseUrlGood && apiKeyGood
onReadyChanged: {
if (!ready) { return; }
let providerRef = null; // owns the new provider
let provider = remoteModelCard.provider;
if (provider === null) {
// TODO: custom OpenAI
providerRef = QmlFunctions.newCustomOllamaProvider("foo", baseUrlField.text.trim());
if (providerRef !== null)
provider = providerRef.get();
}
if (provider !== null) {
provider.listModelsQml().then(modelList => {
if (modelList !== null) {
model = modelList;
currentIndex = -1;
}
});
}
}
}
}
}
@@ -206,10 +183,10 @@ Rectangle {
font.pixelSize: theme.fontSizeLarge
property string apiKeyText: apiKeyField.text.trim()
property string baseUrlText: providerIsCustom ? baseUrlField.text.trim() : providerBaseUrl.trim()
property string modelNameText: providerIsCustom ? modelNameField.text.trim() : myModelList.currentText.trim()
property string baseUrlText: provider === null ? baseUrlField.text.trim() : provider.baseUrl
property string modelNameText: myModelList.currentText.trim()
enabled: apiKeyText !== "" && baseUrlText !== "" && modelNameText !== ""
enabled: baseUrlGood && apiKeyGood && modelNameText !== ""
onClicked: {
Download.installCompatibleModel(

View File

@@ -0,0 +1,18 @@
#pragma once
#include <memory>
namespace gpt4all::ui {
/// Helper mixin for classes derived from std::enable_shared_from_this.
template <typename T>
struct Creatable {
template <typename... Ts>
static auto create(Ts &&...args) -> std::shared_ptr<T>
{ return std::make_shared<T>(typename T::protected_t(), std::forward<Ts>(args)...); }
};
} // namespace gpt4all::ui

View File

@@ -37,6 +37,8 @@ public:
protected:
[[nodiscard]] virtual auto newInstanceImpl(QNetworkAccessManager *nam) const -> ChatLLMInstance * = 0;
template <typename T> friend struct Creatable;
};

View File

@@ -1,7 +1,14 @@
#include "llmodel_ollama.h"
#include "main.h"
#include "mysettings.h"
#include <QCoro/QCoroAsyncGenerator>
#include <QCoro/QCoroTask>
#include <fmt/format.h>
#include <gpt4all-backend/formatters.h> // IWYU pragma: keep
#include <QJSEngine>
using namespace Qt::Literals::StringLiterals;
@@ -21,6 +28,9 @@ auto OllamaGenerationParams::toMap() const -> QMap<QLatin1StringView, QVariant>
};
}
OllamaProvider::OllamaProvider()
{ QJSEngine::setObjectOwnership(this, QJSEngine::CppOwnership); }
OllamaProvider::~OllamaProvider() noexcept = default;
auto OllamaProvider::supportedGenerationParams() const -> QSet<GenerationParam>
@@ -33,9 +43,54 @@ auto OllamaProvider::makeGenerationParams(const QMap<GenerationParam, QVariant>
-> OllamaGenerationParams *
{ return new OllamaGenerationParams(values); }
auto OllamaProvider::status() -> QCoro::Task<ProviderStatus>
{
auto client = makeClient();
auto resp = co_await client.version();
if (resp)
co_return ProviderStatus(tr("Version: %1").arg(resp->version));
co_return ProviderStatus(resp.error());
}
auto OllamaProvider::listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>>
{
auto client = makeClient();
auto resp = co_await client.list();
if (!resp)
co_return std::unexpected(resp.error());
QStringList res;
for (auto &model : resp->models)
res << model.name;
co_return res;
}
QCoro::QmlTask OllamaProvider::statusQml()
{ return wrapQmlTask(this, &OllamaProvider::status, u"OllamaProvider::status"_s); }
QCoro::QmlTask OllamaProvider::listModelsQml()
{ return wrapQmlTask(this, &OllamaProvider::listModels, u"OllamaProvider::listModels"_s); }
auto OllamaProvider::newModel(const QByteArray &modelHash) const -> std::shared_ptr<OllamaModelDescription>
{ return std::static_pointer_cast<OllamaModelDescription>(newModelImpl(modelHash)); }
auto OllamaProvider::newModelImpl(const QVariant &key) const -> std::shared_ptr<ModelDescription>
{
if (!key.canConvert<QByteArray>())
throw std::invalid_argument(fmt::format("expected modelHash type QByteArray, got {}", key.typeName()));
return OllamaModelDescription::create(
std::shared_ptr<const OllamaProvider>(shared_from_this(), this), key.toByteArray()
);
}
auto OllamaProvider::makeClient() -> backend::OllamaClient
{
auto *mySettings = MySettings::globalInstance();
return backend::OllamaClient(m_baseUrl, mySettings->userAgent(), networkAccessManager());
}
/// load
OllamaProviderCustom::OllamaProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl)
: ModelProvider (std::move(id), std::move(name), std::move(baseUrl))
OllamaProviderCustom::OllamaProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl)
: ModelProvider (p, std::move(id), std::move(name), std::move(baseUrl))
, ModelProviderCustom(store)
{
if (auto res = m_store->acquire(m_id); !res)
@@ -43,14 +98,12 @@ OllamaProviderCustom::OllamaProviderCustom(ProviderStore *store, QUuid id, QStri
}
/// create
OllamaProviderCustom::OllamaProviderCustom(ProviderStore *store, QString name, QUrl baseUrl)
: ModelProvider (std::move(name), std::move(baseUrl))
OllamaProviderCustom::OllamaProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl)
: ModelProvider (p, QUuid::createUuid(), std::move(name), std::move(baseUrl))
, ModelProviderCustom(store)
{
auto data = m_store->create(m_name, m_baseUrl);
if (!data)
data.error().raise();
m_id = (*data)->id;
if (auto res = m_store->acquire(m_id); !res)
res.error().raise();
}
auto OllamaProviderCustom::asData() -> ModelProviderData
@@ -58,7 +111,7 @@ auto OllamaProviderCustom::asData() -> ModelProviderData
return {
.id = m_id,
.custom_details = CustomProviderDetails { m_name, m_baseUrl },
.provider_details = {},
.provider_details = std::monostate(),
};
}

View File

@@ -1,9 +1,13 @@
#pragma once
#include "creatable.h"
#include "llmodel_chat.h"
#include "llmodel_description.h"
#include "llmodel_provider.h"
#include <QCoro/QCoroQmlTask> // IWYU pragma: keep
#include <gpt4all-backend/ollama-client.h>
#include <QByteArray>
#include <QLatin1StringView> // IWYU pragma: keep
#include <QObject>
@@ -12,6 +16,8 @@
#include <QVariant>
#include <QtTypes> // IWYU pragma: keep
#include <utility>
class QNetworkAccessManager;
template <typename Key, typename T> class QMap;
template <typename T> class QSet;
@@ -21,6 +27,7 @@ namespace gpt4all::ui {
class OllamaChatModel;
class OllamaModelDescription;
struct OllamaGenerationParamsData {
uint n_predict;
@@ -38,9 +45,12 @@ protected:
};
class OllamaProvider : public QObject, public virtual ModelProvider {
Q_GADGET
Q_OBJECT
Q_PROPERTY(QUuid id READ id CONSTANT)
protected:
explicit OllamaProvider();
public:
~OllamaProvider() noexcept override = 0;
@@ -49,30 +59,48 @@ public:
auto supportedGenerationParams() const -> QSet<GenerationParam> override;
auto makeGenerationParams(const QMap<GenerationParam, QVariant> &values) const -> OllamaGenerationParams * override;
// endpoints
auto status () -> QCoro::Task<ProviderStatus > override;
auto listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>> override;
// QML wrapped endpoints
Q_INVOKABLE QCoro::QmlTask statusQml ();
Q_INVOKABLE QCoro::QmlTask listModelsQml();
[[nodiscard]] auto newModel(const QByteArray &modelHash) const -> std::shared_ptr<OllamaModelDescription>;
protected:
[[nodiscard]] auto newModelImpl(const QVariant &key) const -> std::shared_ptr<ModelDescription> final;
private:
backend::OllamaClient makeClient();
};
class OllamaProviderBuiltin : public OllamaProvider {
Q_GADGET
class OllamaProviderBuiltin : public OllamaProvider, public Creatable<OllamaProviderBuiltin> {
Q_OBJECT
Q_PROPERTY(QString name READ name CONSTANT)
Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT)
public:
/// Create a new built-in Ollama provider (transient).
explicit OllamaProviderBuiltin(QUuid id, QString name, QUrl baseUrl)
: ModelProvider(std::move(id), std::move(name), std::move(baseUrl)) {}
explicit OllamaProviderBuiltin(protected_t p, QUuid id, QString name, QUrl baseUrl)
: ModelProvider(p, std::move(id), std::move(name), std::move(baseUrl)) {}
};
class OllamaProviderCustom final : public OllamaProvider, public ModelProviderCustom {
class OllamaProviderCustom final
: public OllamaProvider, public ModelProviderCustom, public Creatable<OllamaProviderCustom>
{
Q_OBJECT
Q_PROPERTY(QString name READ name NOTIFY nameChanged )
Q_PROPERTY(QUrl baseUrl READ baseUrl NOTIFY baseUrlChanged)
public:
/// Load an existing OllamaProvider from disk.
explicit OllamaProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl);
explicit OllamaProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl);
/// Create a new OllamaProvider on disk.
explicit OllamaProviderCustom(ProviderStore *store, QString name, QUrl baseUrl);
explicit OllamaProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl);
Q_SIGNALS:
void nameChanged (const QString &value);
@@ -82,17 +110,13 @@ protected:
auto asData() -> ModelProviderData override;
};
class OllamaModelDescription : public ModelDescription {
class OllamaModelDescription : public ModelDescription, public Creatable<OllamaModelDescription> {
Q_GADGET
Q_PROPERTY(QByteArray modelHash READ modelHash CONSTANT)
public:
explicit OllamaModelDescription(protected_t, std::shared_ptr<const OllamaProvider> provider, QByteArray modelHash);
static auto create(std::shared_ptr<const OllamaProvider> provider, QByteArray modelHash)
-> std::shared_ptr<OllamaModelDescription>
{ return std::make_shared<OllamaModelDescription>(protected_t(), std::move(provider), std::move(modelHash)); }
// getters
[[nodiscard]] auto provider () const -> const OllamaProvider * override { return m_provider.get(); }
[[nodiscard]] QVariant key () const override { return m_modelHash; }

View File

@@ -14,6 +14,7 @@
#include <QAnyStringView>
#include <QByteArray>
#include <QJSEngine>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
@@ -88,6 +89,13 @@ auto OpenaiGenerationParams::toMap() const -> QMap<QLatin1StringView, QVariant>
};
}
OpenaiProvider::OpenaiProvider()
{ QJSEngine::setObjectOwnership(this, QJSEngine::CppOwnership); }
OpenaiProvider::OpenaiProvider(QString apiKey)
: m_apiKey(std::move(apiKey))
{ QJSEngine::setObjectOwnership(this, QJSEngine::CppOwnership); }
OpenaiProvider::~OpenaiProvider() noexcept = default;
Q_INVOKABLE bool OpenaiProvider::setApiKeyQml(QString value)
@@ -108,13 +116,22 @@ auto OpenaiProvider::makeGenerationParams(const QMap<GenerationParam, QVariant>
-> OpenaiGenerationParams *
{ return new OpenaiGenerationParams(values); }
auto OpenaiProvider::status() -> QCoro::Task<ProviderStatus>
{
auto resp = co_await listModels();
if (resp)
co_return ProviderStatus(tr("OK"));
co_return ProviderStatus(resp.error());
}
auto OpenaiProvider::listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>>
{
auto *mySettings = MySettings::globalInstance();
auto *nam = networkAccessManager();
QNetworkRequest request(m_baseUrl.resolved(u"models"_s));
request.setHeader (QNetworkRequest::ContentTypeHeader, "application/json"_ba);
request.setRawHeader("Authorization"_ba, fmt::format("Bearer {}", m_apiKey).c_str());
request.setHeader (QNetworkRequest::UserAgentHeader, mySettings->userAgent());
request.setRawHeader("authorization"_ba, fmt::format("Bearer {}", m_apiKey).c_str());
std::unique_ptr<QNetworkReply> reply(nam->get(request));
QRestReply restReply(reply.get());
@@ -146,20 +163,27 @@ auto OpenaiProvider::listModels() -> QCoro::Task<backend::DataOrRespErr<QStringL
co_return models;
}
QCoro::QmlTask OpenaiProvider::statusQml()
{ return wrapQmlTask(this, &OpenaiProvider::status, u"OpenaiProvider::status"_s); }
QCoro::QmlTask OpenaiProvider::listModelsQml()
{ return wrapQmlTask(this, &OpenaiProvider::listModels, u"OpenaiProvider::listModels"_s); }
auto OpenaiProvider::newModel(const QString &modelName) const -> std::shared_ptr<OpenaiModelDescription>
{ return std::static_pointer_cast<OpenaiModelDescription>(newModelImpl(modelName)); }
auto OpenaiProvider::newModelImpl(const QVariant &key) const -> std::shared_ptr<ModelDescription>
{
return [this]() -> QCoro::Task<QVariant> {
auto result = co_await listModels();
if (result)
co_return *result;
qWarning().noquote() << "OpenaiProvider::listModels failed:" << result.error().errorString();
co_return QVariant::fromValue(nullptr);
}();
if (!key.canConvert<QString>())
throw std::invalid_argument(fmt::format("expected modelName type QString, got {}", key.typeName()));
return OpenaiModelDescription::create(
std::shared_ptr<const OpenaiProvider>(shared_from_this(), this), key.toString()
);
}
OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl,
QStringList modelWhitelist)
: ModelProvider(std::move(id), std::move(name), std::move(baseUrl))
OpenaiProviderBuiltin::OpenaiProviderBuiltin(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl icon,
QUrl baseUrl, std::unordered_set<QString> modelWhitelist)
: ModelProvider(p, std::move(id), std::move(name), std::move(baseUrl))
, ModelProviderBuiltin(std::move(icon))
, ModelProviderMutable(store)
, m_modelWhitelist(std::move(modelWhitelist))
@@ -173,6 +197,15 @@ OpenaiProviderBuiltin::OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QSt
}
}
auto OpenaiProviderBuiltin::listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>>
{
auto models = co_await OpenaiProvider::listModels();
if (!models)
co_return std::unexpected(models.error());
models->removeIf([&](auto &m) { return !m_modelWhitelist.contains(m); });
co_return *models;
}
auto OpenaiProviderBuiltin::asData() -> ModelProviderData
{
return {
@@ -183,22 +216,25 @@ auto OpenaiProviderBuiltin::asData() -> ModelProviderData
}
/// load
OpenaiProviderCustom::OpenaiProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey)
: ModelProvider(std::move(id), std::move(name), std::move(baseUrl))
OpenaiProviderCustom::OpenaiProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl,
QString apiKey)
: ModelProvider(p, std::move(id), std::move(name), std::move(baseUrl))
, OpenaiProvider(std::move(apiKey))
, ModelProviderCustom(store)
{}
{
if (auto res = m_store->acquire(m_id); !res)
res.error().raise();
}
/// create
OpenaiProviderCustom::OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey)
: ModelProvider(std::move(name), std::move(baseUrl))
OpenaiProviderCustom::OpenaiProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl,
QString apiKey)
: ModelProvider(p, QUuid::createUuid(), std::move(name), std::move(baseUrl))
, ModelProviderCustom(std::move(store))
, OpenaiProvider(std::move(apiKey))
{
auto data = m_store->create(m_name, m_baseUrl, m_apiKey);
if (!data)
data.error().raise();
m_id = (*data)->id;
if (auto res = m_store->acquire(m_id); !res)
res.error().raise();
}
auto OpenaiProviderCustom::asData() -> ModelProviderData
@@ -330,8 +366,9 @@ auto OpenaiChatModel::generate(QStringView prompt, const GenerationParams *param
auto &provider = *m_description->provider();
QNetworkRequest request(provider.baseUrl().resolved(QUrl("/v1/chat/completions")));
request.setHeader(QNetworkRequest::UserAgentHeader, mySettings->userAgent());
request.setRawHeader("authorization", u"Bearer %1"_s.arg(provider.apiKey()).toUtf8());
request.setHeader (QNetworkRequest::UserAgentHeader, mySettings->userAgent());
request.setHeader (QNetworkRequest::ContentTypeHeader, "application/json"_ba );
request.setRawHeader("authorization"_ba, fmt::format("Bearer {}", provider.apiKey()).c_str());
QRestAccessManager restNam(m_nam);
std::unique_ptr<QNetworkReply> reply(restNam.post(request, QJsonDocument(reqBody)));

View File

@@ -1,5 +1,6 @@
#pragma once
#include "creatable.h"
#include "llmodel_chat.h"
#include "llmodel_description.h"
#include "llmodel_provider.h"
@@ -16,6 +17,7 @@
#include <QtTypes> // IWYU pragma: keep
#include <memory>
#include <unordered_set>
#include <utility>
class QNetworkAccessManager;
@@ -28,6 +30,7 @@ namespace gpt4all::ui {
class OpenaiChatModel;
class OpenaiModelDescription;
struct OpenaiGenerationParamsData {
uint n_predict;
@@ -51,9 +54,8 @@ class OpenaiProvider : public QObject, public virtual ModelProvider {
Q_PROPERTY(QString apiKey READ apiKey NOTIFY apiKeyChanged)
protected:
explicit OpenaiProvider() = default;
explicit OpenaiProvider(QString apiKey)
: m_apiKey(std::move(apiKey)) {}
explicit OpenaiProvider();
explicit OpenaiProvider(QString apiKey);
public:
~OpenaiProvider() noexcept override = 0;
@@ -69,50 +71,69 @@ public:
auto supportedGenerationParams() const -> QSet<GenerationParam> override;
auto makeGenerationParams(const QMap<GenerationParam, QVariant> &values) const -> OpenaiGenerationParams * override;
auto listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>>;
// endpoints
auto status () -> QCoro::Task<ProviderStatus > override;
auto listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>> override;
// QML wrapped endpoints
Q_INVOKABLE QCoro::QmlTask statusQml ();
Q_INVOKABLE QCoro::QmlTask listModelsQml();
[[nodiscard]] auto newModel(const QString &modelName) const -> std::shared_ptr<OpenaiModelDescription>;
Q_SIGNALS:
void apiKeyChanged(const QString &value);
protected:
[[nodiscard]] auto newModelImpl(const QVariant &key) const -> std::shared_ptr<ModelDescription> final;
QString m_apiKey;
};
class OpenaiProviderBuiltin : public OpenaiProvider, public ModelProviderBuiltin, public ModelProviderMutable {
class OpenaiProviderBuiltin
: public OpenaiProvider
, public ModelProviderBuiltin
, public ModelProviderMutable
, public Creatable<OpenaiProviderBuiltin>
{
Q_OBJECT
Q_PROPERTY(QString name READ name CONSTANT)
Q_PROPERTY(QUrl icon READ icon CONSTANT)
Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT)
Q_PROPERTY(QStringList modelWhitelist READ modelWhitelist CONSTANT)
public:
/// Create a new built-in OpenAI provider, loading its API key from disk if known.
explicit OpenaiProviderBuiltin(ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl,
QStringList modelWhitelist);
[[nodiscard]] const QStringList &modelWhitelist() { return m_modelWhitelist; }
explicit OpenaiProviderBuiltin(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl icon, QUrl baseUrl,
std::unordered_set<QString> modelWhitelist);
[[nodiscard]] DataStoreResult<> setApiKey(QString value) override
{ return setMemberProp<QString>(&OpenaiProviderBuiltin::m_apiKey, "apiKey", std::move(value), /*createName*/ m_name); }
// override for model whitelist
auto listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>> override;
Q_SIGNALS:
void apiKeyChanged(const QString &value);
protected:
auto asData() -> ModelProviderData override;
QStringList m_modelWhitelist;
std::unordered_set<QString> m_modelWhitelist;
};
class OpenaiProviderCustom final : public OpenaiProvider, public ModelProviderCustom {
class OpenaiProviderCustom final
: public OpenaiProvider, public ModelProviderCustom, public Creatable<OpenaiProviderCustom>
{
Q_OBJECT
Q_PROPERTY(QString name READ name NOTIFY nameChanged )
Q_PROPERTY(QUrl baseUrl READ baseUrl NOTIFY baseUrlChanged)
public:
/// Load an existing OpenaiProvider from disk.
explicit OpenaiProviderCustom(ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey);
explicit OpenaiProviderCustom(protected_t p, ProviderStore *store, QUuid id, QString name, QUrl baseUrl, QString apiKey);
/// Create a new OpenaiProvider on disk.
explicit OpenaiProviderCustom(ProviderStore *store, QString name, QUrl baseUrl, QString apiKey);
explicit OpenaiProviderCustom(protected_t p, ProviderStore *store, QString name, QUrl baseUrl, QString apiKey);
[[nodiscard]] DataStoreResult<> setApiKey(QString value) override
{ return setMemberProp<QString>(&OpenaiProviderCustom::m_apiKey, "apiKey", std::move(value)); }
@@ -126,17 +147,13 @@ protected:
auto asData() -> ModelProviderData override;
};
class OpenaiModelDescription : public ModelDescription {
class OpenaiModelDescription : public ModelDescription, public Creatable<OpenaiModelDescription> {
Q_GADGET
Q_PROPERTY(QString modelName READ modelName CONSTANT)
public:
explicit OpenaiModelDescription(protected_t, std::shared_ptr<const OpenaiProvider> provider, QString modelName);
static auto create(std::shared_ptr<const OpenaiProvider> provider, QByteArray modelHash)
-> std::shared_ptr<OpenaiModelDescription>
{ return std::make_shared<OpenaiModelDescription>(protected_t(), std::move(provider), std::move(modelHash)); }
// getters
[[nodiscard]] auto provider () const -> const OpenaiProvider * override { return m_provider.get(); }
[[nodiscard]] QVariant key () const override { return m_modelName; }

View File

@@ -8,10 +8,13 @@
#include <fmt/format.h>
#include <gpt4all-backend/formatters.h> // IWYU pragma: keep
#include <QtAssert>
#include <QModelIndex> // IWYU pragma: keep
#include <QVariant>
namespace fs = std::filesystem;
#include <algorithm>
namespace ranges = std::ranges;
namespace gpt4all::ui {
@@ -43,14 +46,50 @@ QVariant GenerationParams::tryParseValue(QMap<GenerationParam, QVariant> &values
return value;
}
ProviderStatus::ProviderStatus(const backend::ResponseError &error)
: m_ok(false)
{
auto &code = error.error();
if (auto *badStatus = std::get_if<backend::ResponseError::BadStatus>(&code)) {
m_detail = QObject::tr("HTTP %1%2%3").arg(
QString::number(badStatus->code),
badStatus->reason ? u" "_s : QString(),
badStatus->reason.value_or(QString())
);
return;
}
if (auto *netErr = std::get_if<QNetworkReply::NetworkError>(&code)) {
auto meta = QMetaEnum::fromType<QNetworkReply::NetworkError>();
m_detail = QString::fromUtf8(meta.valueToKey(*netErr));
return;
}
m_detail = QObject::tr("(unknown error)");
}
ModelProvider::ModelProvider(protected_t, QUuid id, QString name, QUrl baseUrl) // create built-in or load
: m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl))
{ Q_ASSERT(!m_id.isNull()); }
ModelProvider::~ModelProvider() noexcept = default;
auto ModelProvider::newModel(const QVariant &key) const -> std::shared_ptr<ModelDescription>
{ return newModelImpl(key); }
ModelProviderMutable::~ModelProviderMutable() noexcept
{
if (!m_id.isNull()) // (will be null if constructor throws)
if (auto res = m_store->release(m_id); !res)
res.error().raise(); // should not happen - will terminate program
}
auto ModelProviderCustom::persist() -> DataStoreResult<>
{
if (auto res = m_store->create(asData()); !res)
return res;
m_persisted = true;
return {};
}
ProviderRegistry::ProviderRegistry(PathSet paths)
: m_customStore (std::move(paths.custom ))
, m_builtinStore(std::move(paths.builtin))
@@ -70,18 +109,27 @@ ProviderRegistry *ProviderRegistry::globalInstance()
void ProviderRegistry::load()
{
size_t i = 0;
auto registerListener = [this](ModelProvider *provider) {
// listen for any change in the provider so we can tell the model about it
if (auto *mut = dynamic_cast<ModelProviderMutable *>(provider))
connect(mut->asQObject(), "apiKeyChanged", this, "onProviderChanged");
if (auto *cust = dynamic_cast<ModelProviderCustom *>(provider)) {
connect(cust->asQObject(), "nameChanged", this, "onProviderChanged");
connect(cust->asQObject(), "baseUrlChanged", this, "onProviderChanged");
}
};
for (auto &p : s_builtinProviders) { // (not all builtin providers are stored)
auto provider = std::make_shared<OpenaiProviderBuiltin>(
auto provider = OpenaiProviderBuiltin::create(
&m_builtinStore, p.id, p.name, p.icon, p.base_url,
QStringList(p.model_whitelist.begin(), p.model_whitelist.end())
std::unordered_set<QString>(p.model_whitelist.begin(), p.model_whitelist.end())
);
auto [_, unique] = m_providers.emplace(p.id, std::move(provider));
auto [it, unique] = m_providers.emplace(p.id, std::move(provider));
if (!unique)
throw std::logic_error(fmt::format("duplicate builtin provider id: {}", p.id.toString()));
m_builtinProviders[i++] = p.id;
m_providerList.push_back(&p.id);
registerListener(it->second.get());
}
for (auto &p : m_customStore.list()) { // disk is source of truth for custom providers
for (auto p : m_customStore.list()) { // disk is source of truth for custom providers
if (!p.custom_details) {
qWarning() << "ignoring builtin provider in custom store:" << p.id;
continue;
@@ -91,47 +139,55 @@ void ProviderRegistry::load()
switch (p.type()) {
using enum ProviderType;
case ollama:
provider = std::make_shared<OllamaProviderCustom>(
provider = OllamaProviderCustom::create(
&m_customStore, p.id, cust.name, cust.base_url
);
break;
case openai:
provider = std::make_shared<OpenaiProviderCustom>(
provider = OpenaiProviderCustom::create(
&m_customStore, p.id, cust.name, cust.base_url,
std::get<size_t(ProviderType::openai)>(p.provider_details).api_key
);
}
auto [_, unique] = m_providers.emplace(p.id, std::move(provider));
auto [it, unique] = m_providers.emplace(p.id, std::move(provider));
if (!unique)
qWarning() << "ignoring duplicate custom provider with id:" << p.id;
m_customProviders.push_back(std::make_unique<QUuid>(p.id));
m_providerList.push_back(&it->second->id());
registerListener(it->second.get());
}
}
[[nodiscard]]
bool ProviderRegistry::add(std::shared_ptr<ModelProviderCustom> provider)
auto ProviderRegistry::add(std::shared_ptr<ModelProviderCustom> provider) -> DataStoreResult<>
{
if (auto res = provider->persist(); !res)
return res;
auto [it, unique] = m_providers.emplace(provider->id(), std::move(provider));
if (unique) {
m_customProviders.push_back(std::make_unique<QUuid>(it->first));
emit customProviderAdded(m_customProviders.size() - 1);
if (!unique)
return std::unexpected(u"custom provider already registered: %1"_s.arg(provider->id().toString()));
m_providerList.push_back(&it->second->id());
emit customProviderAdded(m_providerList.size() - 1);
return {};
}
bool ProviderRegistry::addQml(QmlSharedPtr *provider)
{
auto obj = std::dynamic_pointer_cast<ModelProviderCustom>(provider->ptr());
if (!obj) {
qWarning() << "ProviderRegistry::add failed: Expected ModelProviderCustom, got"
<< provider->metaObject()->className();
return false;
}
return unique;
auto res = add(obj);
if (!res)
qWarning() << "ProviderRegistry::add failed:" << res.error().errorString();
return bool(res);
}
auto ProviderRegistry::customProviderAt(size_t i) const -> ModelProviderCustom *
auto ProviderRegistry::providerAt(size_t i) const -> const ModelProvider *
{
auto it = m_providers.find(*m_customProviders.at(i));
auto it = m_providers.find(*m_providerList.at(i));
Q_ASSERT(it != m_providers.end());
return &dynamic_cast<ModelProviderCustom &>(*it->second);
}
auto ProviderRegistry::builtinProviderAt(size_t i) const -> ModelProviderBuiltin *
{
auto it = m_providers.find(m_builtinProviders.at(i));
Q_ASSERT(it != m_providers.end());
return &dynamic_cast<ModelProviderBuiltin &>(*it->second);
return it->second.get();
}
auto ProviderRegistry::getSubdirs() -> PathSet
@@ -147,8 +203,8 @@ void ProviderRegistry::onModelPathChanged()
if (paths.builtin != m_builtinStore.path()) {
emit aboutToBeCleared();
// delete providers to release store locks
m_customProviders.clear();
m_providers.clear();
m_providerList.clear();
if (auto res = m_builtinStore.setPath(paths.builtin); !res)
res.error().raise(); // should not happen
if (auto res = m_customStore.setPath(paths.custom); !res)
@@ -157,42 +213,58 @@ void ProviderRegistry::onModelPathChanged()
}
}
auto BuiltinProviderList::roleNames() const -> QHash<int, QByteArray>
{ return { { Qt::DisplayRole, "data"_ba } }; }
QVariant BuiltinProviderList::data(const QModelIndex &index, int role) const
void ProviderRegistry::onProviderChanged()
{
auto *registry = ProviderRegistry::globalInstance();
if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole)
return QVariant::fromValue(registry->builtinProviderAt(index.row())->asQObject());
return {};
// notify that this provider has changed
auto *obj = &dynamic_cast<ModelProvider &>(*QObject::sender());
auto it = ranges::find_if(m_providerList, [&](auto *id) { return *id == obj->id(); });
if (it < m_providerList.end())
emit customProviderChanged(it - m_providerList.begin());
}
CustomProviderList::CustomProviderList()
: m_size(ProviderRegistry::globalInstance()->customProviderCount())
ProviderList::ProviderList()
: m_size(ProviderRegistry::globalInstance()->providerCount())
{
auto *registry = ProviderRegistry::globalInstance();
connect(registry, &ProviderRegistry::customProviderAdded, this, &CustomProviderList::onCustomProviderAdded);
connect(registry, &ProviderRegistry::aboutToBeCleared, this, &CustomProviderList::onAboutToBeCleared,
connect(registry, &ProviderRegistry::customProviderAdded, this, &ProviderList::onCustomProviderAdded);
connect(registry, &ProviderRegistry::customProviderRemoved, this, &ProviderList::onCustomProviderRemoved);
connect(registry, &ProviderRegistry::customProviderChanged, this, &ProviderList::onCustomProviderChanged);
connect(registry, &ProviderRegistry::aboutToBeCleared, this, &ProviderList::onAboutToBeCleared,
Qt::DirectConnection);
}
QVariant CustomProviderList::data(const QModelIndex &index, int role) const
auto ProviderList::roleNames() const -> QHash<int, QByteArray>
{ return { { Qt::DisplayRole, "provider"_ba } }; }
QVariant ProviderList::data(const QModelIndex &index, int role) const
{
auto *registry = ProviderRegistry::globalInstance();
if (index.isValid() && index.row() < rowCount() && role == Qt::DisplayRole)
return QVariant::fromValue(registry->customProviderAt(index.row())->asQObject());
return QVariant::fromValue(registry->providerAt(index.row())->asQObject());
return {};
}
void CustomProviderList::onCustomProviderAdded(size_t index)
void ProviderList::onCustomProviderAdded(size_t index)
{
beginInsertRows({}, m_size, m_size);
beginInsertRows({}, index, index);
m_size++;
endInsertRows();
}
void CustomProviderList::onAboutToBeCleared()
void ProviderList::onCustomProviderRemoved(size_t index)
{
beginRemoveRows({}, index, index);
m_size--;
endRemoveRows();
}
void ProviderList::onCustomProviderChanged(size_t index)
{
auto i = this->index(index);
emit dataChanged(i, i);
}
void ProviderList::onAboutToBeCleared()
{
beginResetModel();
m_size = 0;
@@ -203,8 +275,13 @@ bool ProviderListSort::lessThan(const QModelIndex &left, const QModelIndex &righ
{
auto *leftData = sourceModel()->data(left ).value<ModelProvider *>();
auto *rightData = sourceModel()->data(right).value<ModelProvider *>();
if (leftData && rightData)
return QString::localeAwareCompare(leftData->name(), rightData->name()) < 0;
if (leftData && rightData) {
if (leftData->isBuiltin() != rightData->isBuiltin())
return leftData->isBuiltin() > rightData->isBuiltin(); // builtins first
if (leftData->isBuiltin())
return left.row() < right.row(); // preserve order of builtins
return QString::localeAwareCompare(leftData->name(), rightData->name()) < 0; // sort by name
}
return true;
}

View File

@@ -2,23 +2,29 @@
#include "store_provider.h"
#include "qmlsharedptr.h" // IWYU pragma: keep
#include "utils.h" // IWYU pragma: keep
#include <gpt4all-backend/ollama-client.h>
#include <QAbstractListModel>
#include <QObject>
#include <QQmlEngine> // IWYU pragma: keep
#include <QSortFilterProxyModel>
#include <QString>
#include <QStringList> // IWYU pragma: keep
#include <QUrl>
#include <QUuid>
#include <QtPreprocessorSupport>
#include <array>
#include <cstddef>
#include <expected>
#include <filesystem>
#include <memory>
#include <optional>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -26,6 +32,10 @@
class QByteArray;
class QJSEngine;
template <typename Key, typename T> class QHash;
namespace QCoro {
template <typename T> class Task;
struct QmlTask;
}
namespace gpt4all::ui {
@@ -33,6 +43,30 @@ namespace gpt4all::ui {
Q_NAMESPACE
class ModelDescription;
namespace detail {
template <typename T>
struct is_expected_impl : std::false_type {};
template <typename T, typename E>
struct is_expected_impl<std::expected<T, E>> : std::true_type {};
template <typename T>
concept is_expected = is_expected_impl<std::remove_cvref_t<T>>::value;
} // namespace detail
/// Drop the type and error information from a QCoro::Task<DataOrRespErr<T>> so it can be used by QML.
template <typename C, typename F, typename... Args>
requires (!detail::is_expected<typename std::invoke_result_t<F, C *, Args...>::value_type>)
QCoro::QmlTask wrapQmlTask(std::shared_ptr<C> c, F f, QString prefix, Args &&...args);
template <typename C, typename F, typename... Args>
requires detail::is_expected<typename std::invoke_result_t<F, C *, Args...>::value_type>
QCoro::QmlTask wrapQmlTask(std::shared_ptr<C> c, F f, QString prefix, Args &&...args);
enum class GenerationParam {
NPredict,
Temperature,
@@ -61,12 +95,28 @@ protected:
void tryParseValue(this S &self, QMap<GenerationParam, QVariant> &values, GenerationParam key, T C::* dest);
};
class ModelProvider {
class ProviderStatus {
Q_GADGET
Q_PROPERTY(bool ok READ ok CONSTANT)
Q_PROPERTY(QString detail READ detail CONSTANT)
public:
explicit ProviderStatus(QString okMsg): m_ok(true), m_detail(std::move(okMsg)) {}
explicit ProviderStatus(const backend::ResponseError &error);
bool ok () const { return m_ok; }
const QString &detail() const { return m_detail; }
private:
bool m_ok;
QString m_detail;
};
class ModelProvider : public std::enable_shared_from_this<ModelProvider> {
protected:
explicit ModelProvider(QUuid id, QString name, QUrl baseUrl) // create built-in or load
: m_id(std::move(id)), m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {}
explicit ModelProvider(QString name, QUrl baseUrl) // create custom
: m_name(std::move(name)), m_baseUrl(std::move(baseUrl)) {}
struct protected_t { explicit protected_t() = default; };
explicit ModelProvider(protected_t, QUuid id, QString name, QUrl baseUrl);
public:
virtual ~ModelProvider() noexcept = 0;
@@ -74,6 +124,8 @@ public:
virtual QObject *asQObject() = 0;
virtual const QObject *asQObject() const = 0;
virtual bool isBuiltin() const = 0;
// getters
[[nodiscard]] const QUuid &id () const { return m_id; }
[[nodiscard]] const QString &name () const { return m_name; }
@@ -82,13 +134,24 @@ public:
virtual auto supportedGenerationParams() const -> QSet<GenerationParam> = 0;
virtual auto makeGenerationParams(const QMap<GenerationParam, QVariant> &values) const -> GenerationParams * = 0;
// endpoints
virtual auto status () -> QCoro::Task<ProviderStatus > = 0;
virtual auto listModels() -> QCoro::Task<backend::DataOrRespErr<QStringList>> = 0;
/// create a model using this provider
[[nodiscard]] auto newModel(const QVariant &key) const -> std::shared_ptr<ModelDescription>;
friend bool operator==(const ModelProvider &a, const ModelProvider &b)
{ return a.m_id == b.m_id; }
protected:
[[nodiscard]] virtual auto newModelImpl(const QVariant &key) const -> std::shared_ptr<ModelDescription> = 0;
QUuid m_id;
QString m_name;
QUrl m_baseUrl;
template <typename T> friend struct Creatable;
};
class ModelProviderBuiltin : public virtual ModelProvider {
@@ -97,6 +160,8 @@ protected:
: m_icon(std::move(icon)) {}
public:
bool isBuiltin() const final { return true; }
[[nodiscard]] const QUrl &icon() const { return m_icon; }
protected:
@@ -119,6 +184,8 @@ protected:
[[nodiscard]] DataStoreResult<> setMemberProp(this S &self, T C::* member, std::string_view name, T value,
std::optional<QString> createName = {});
[[nodiscard]] virtual bool persisted() const { return true; }
ProviderStore *m_store;
};
@@ -128,11 +195,20 @@ protected:
: ModelProviderMutable(store) {}
public:
bool isBuiltin() const final { return false; }
// setters
[[nodiscard]] DataStoreResult<> setName (QString value)
{ return setMemberProp<QString>(&ModelProviderCustom::m_name, "name", std::move(value)); }
[[nodiscard]] DataStoreResult<> setBaseUrl(QUrl value)
{ return setMemberProp<QUrl >(&ModelProviderCustom::m_baseUrl, "baseUrl", std::move(value)); }
[[nodiscard]] auto persist() -> DataStoreResult<>;
protected:
[[nodiscard]] bool persisted() const override { return m_persisted; }
bool m_persisted = false;
};
class ProviderRegistry : public QObject {
@@ -156,17 +232,21 @@ protected:
public:
static ProviderRegistry *globalInstance();
[[nodiscard]] bool add(std::shared_ptr<ModelProviderCustom> provider);
[[nodiscard]] auto add(std::shared_ptr<ModelProviderCustom> provider) -> DataStoreResult<>;
Q_INVOKABLE bool addQml(QmlSharedPtr *provider);
// TODO(jared): implement a way to remove custom providers via the model
auto operator[](const QUuid &id) -> const ModelProvider * { return m_providers.at(id).get(); }
// TODO(jared): implement a way to remove custom providers via the model
[[nodiscard]] size_t customProviderCount () const { return m_customProviders.size(); }
[[nodiscard]] auto customProviderAt (size_t i) const -> ModelProviderCustom *;
[[nodiscard]] size_t builtinProviderCount() const { return m_builtinProviders.size(); }
[[nodiscard]] auto builtinProviderAt (size_t i) const -> ModelProviderBuiltin *;
[[nodiscard]] size_t providerCount() const { return m_providers.size(); }
[[nodiscard]] auto providerAt(size_t i) const -> const ModelProvider *;
ProviderStore *customStore() { return &m_customStore; }
Q_SIGNALS:
void customProviderAdded(size_t index);
void customProviderRemoved(size_t index); // TODO: use
void customProviderChanged(size_t index);
void aboutToBeCleared();
private:
@@ -175,6 +255,7 @@ private:
private Q_SLOTS:
void onModelPathChanged();
void onProviderChanged();
private:
static constexpr size_t N_BUILTIN = 3;
@@ -183,51 +264,32 @@ private:
ProviderStore m_customStore;
ProviderStore m_builtinStore;
std::unordered_map<QUuid, std::shared_ptr<ModelProvider>> m_providers;
std::vector<std::unique_ptr<QUuid>> m_customProviders;
std::array<QUuid, N_BUILTIN> m_builtinProviders;
std::vector<const QUuid *> m_providerList; // TODO: implement
};
// TODO: api keys are allowed to change for here and also below. That should emit dataChanged.
class BuiltinProviderList : public QAbstractListModel {
class ProviderList : public QAbstractListModel {
Q_OBJECT
QML_SINGLETON
QML_ELEMENT
public:
explicit BuiltinProviderList()
: m_size(ProviderRegistry::globalInstance()->builtinProviderCount()) {}
explicit ProviderList();
static BuiltinProviderList *create(QQmlEngine *, QJSEngine *) { return new BuiltinProviderList(); }
static ProviderList *create(QQmlEngine *, QJSEngine *) { return new ProviderList(); }
auto roleNames() const -> QHash<int, QByteArray> override;
int rowCount(const QModelIndex &parent = {}) const override
{ Q_UNUSED(parent) return int(m_size); }
QVariant data(const QModelIndex &index, int role) const override;
private:
size_t m_size;
};
class CustomProviderList : public QAbstractListModel {
Q_OBJECT
public:
explicit CustomProviderList();
int rowCount(const QModelIndex &parent = {}) const override
{ Q_UNUSED(parent) return int(m_size); }
QVariant data(const QModelIndex &index, int role) const override;
private Q_SLOTS:
void onCustomProviderAdded(size_t index);
void onCustomProviderAdded (size_t index);
void onCustomProviderRemoved(size_t index);
void onCustomProviderChanged(size_t index);
void onAboutToBeCleared();
private:
size_t m_size;
};
// todo: don't have singletons use singletons directly
// TODO: actually use the provider sort, here, rather than unsorted, for builtins
class ProviderListSort : public QSortFilterProxyModel {
Q_OBJECT
QML_SINGLETON
@@ -243,8 +305,7 @@ protected:
bool lessThan(const QModelIndex &left, const QModelIndex &right) const override;
private:
// TODO: support custom providers as well
BuiltinProviderList m_model;
ProviderList m_model;
};

View File

@@ -1,9 +1,43 @@
#include <fmt/format.h>
#include <QCoro/QCoroQmlTask>
#include <QCoro/QCoroTask>
#include <QDebug>
#include <QVariant>
#include <QtLogging>
#include <expected>
#include <functional>
namespace gpt4all::ui {
template <typename C, typename F, typename... Args>
requires (!detail::is_expected<typename std::invoke_result_t<F, C *, Args...>::value_type>)
QCoro::QmlTask wrapQmlTask(C *obj, F f, QString prefix, Args &&...args)
{
std::shared_ptr<C> ptr(obj->shared_from_this(), obj);
return [](std::shared_ptr<C> ptr, F f, QString prefix, Args &&...args) -> QCoro::Task<QVariant> {
co_return QVariant::fromValue(co_await std::invoke(f, ptr.get(), std::forward<Args>(args)...));
}(std::move(ptr), std::move(f), std::move(prefix), std::forward<Args>(args)...);
}
template <typename C, typename F, typename... Args>
requires detail::is_expected<typename std::invoke_result_t<F, C *, Args...>::value_type>
QCoro::QmlTask wrapQmlTask(C *obj, F f, QString prefix, Args &&...args)
{
std::shared_ptr<C> ptr(obj->shared_from_this(), obj);
return [](std::shared_ptr<C> ptr, F f, QString prefix, Args &&...args) -> QCoro::Task<QVariant> {
auto result = co_await std::invoke(f, ptr.get(), std::forward<Args>(args)...);
if (result)
co_return QVariant::fromValue(*result);
qWarning().noquote() << prefix << "failed:" << result.error().errorString();
co_return QVariant::fromValue(nullptr);
}(std::move(ptr), std::move(f), std::move(prefix), std::forward<Args>(args)...);
}
template <typename T, typename S, typename C>
void GenerationParams::tryParseValue(this S &self, QMap<GenerationParam, QVariant> &values, GenerationParam key,
T C::* dest)
@@ -20,9 +54,11 @@ auto ModelProviderMutable::setMemberProp(this S &self, T C::* member, std::strin
auto &cur = self.*member;
if (cur != value) {
cur = std::move(value);
if (mpc.persisted()) {
auto data = mpc.asData();
if (auto res = mpc.m_store->setData(std::move(data), createName); !res)
return res;
}
QMetaObject::invokeMethod(self.asQObject(), fmt::format("{}Changed", name).c_str(), cur);
}
return {};

View File

@@ -2,6 +2,7 @@
#include "config.h"
#include "download.h"
#include "llm.h"
#include "llmodel_provider.h"
#include "localdocs.h"
#include "logger.h"
#include "modellist.h"
@@ -153,6 +154,7 @@ int main(int argc, char *argv[])
qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance());
qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance());
qmlRegisterSingletonInstance("toollist", 1, 0, "ToolList", ToolModel::globalInstance());
qmlRegisterSingletonInstance("gpt4all.ProviderRegistry", 1, 0, "ProviderRegistry", ProviderRegistry::globalInstance());
qmlRegisterUncreatableMetaObject(ToolEnums::staticMetaObject, "toolenums", 1, 0, "ToolEnums", "Error: only enums");
qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums");

View File

@@ -178,9 +178,9 @@ MySettings::MySettings()
{
}
const QString &MySettings::userAgent()
const QByteArray &MySettings::userAgent()
{
static const QString s_userAgent = QStringLiteral("gpt4all/" APP_VERSION);
static const QByteArray s_userAgent = QByteArrayLiteral("gpt4all/" APP_VERSION);
return s_userAgent;
}

View File

@@ -88,7 +88,7 @@ public Q_SLOTS:
public:
static MySettings *globalInstance();
static const QString &userAgent();
static const QByteArray &userAgent();
Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);

View File

@@ -0,0 +1,41 @@
#include "qmlfunctions.h"
#include "llmodel_ollama.h"
#include "llmodel_openai.h"
#include "llmodel_provider.h"
#include <exception>
#include <utility>
namespace gpt4all::ui {
QmlSharedPtr *QmlFunctions::newCustomOpenaiProvider(QString name, QUrl baseUrl, QString apiKey) const
{
auto *store = ProviderRegistry::globalInstance()->customStore();
std::shared_ptr<OpenaiProviderCustom> ptr;
try {
ptr = OpenaiProviderCustom::create(store, std::move(name), std::move(baseUrl), std::move(apiKey));
} catch (const std::exception &e) {
qWarning() << "newCustomOpenaiProvider failed:" << e.what();
return nullptr;
}
return new QmlSharedPtr(std::move(ptr));
}
QmlSharedPtr *QmlFunctions::newCustomOllamaProvider(QString name, QUrl baseUrl) const
{
auto *store = ProviderRegistry::globalInstance()->customStore();
std::shared_ptr<OllamaProviderCustom> ptr;
try {
ptr = OllamaProviderCustom::create(store, std::move(name), std::move(baseUrl));
} catch (const std::exception &e) {
qWarning() << "newCustomOllamaProvider failed:" << e.what();
return nullptr;
}
return new QmlSharedPtr(std::move(ptr));
}
} // namespace gpt4all::ui

View File

@@ -0,0 +1,31 @@
#pragma once
#include "qmlsharedptr.h" // IWYU pragma: keep
#include <QObject>
#include <QQmlEngine>
#include <QString> // IWYU pragma: keep
#include <QUrl> // IWYU pragma: keep
namespace gpt4all::ui {
// The singleton through which all static methods and free functions are called in QML.
class QmlFunctions : public QObject {
Q_OBJECT
QML_ELEMENT
QML_SINGLETON
explicit QmlFunctions() = default;
public:
static QmlFunctions *create(QQmlEngine *, QJSEngine *) { return new QmlFunctions; }
Q_INVOKABLE QmlSharedPtr *newCustomOpenaiProvider(QString name, QUrl baseUrl, QString apiKey) const;
Q_INVOKABLE QmlSharedPtr *newCustomOllamaProvider(QString name, QUrl baseUrl) const;
};
} // namespace gpt4all::ui

View File

@@ -0,0 +1,14 @@
#include "qmlsharedptr.h"
#include <QJSEngine>
namespace gpt4all::ui {
QmlSharedPtr::QmlSharedPtr(std::shared_ptr<QObject> ptr)
: m_ptr(std::move(ptr))
{ if (m_ptr) { QJSEngine::setObjectOwnership(m_ptr.get(), QJSEngine::CppOwnership); } }
} // namespace gpt4all::ui

View File

@@ -0,0 +1,25 @@
#pragma once
#include <QObject>
#include <memory>
namespace gpt4all::ui {
class QmlSharedPtr : public QObject {
Q_OBJECT
public:
explicit QmlSharedPtr(std::shared_ptr<QObject> ptr);
const std::shared_ptr<QObject> &ptr() { return m_ptr; }
Q_INVOKABLE QObject *get() { return m_ptr.get(); }
private:
std::shared_ptr<QObject> m_ptr;
};
} // namespace gpt4all::ui

View File

@@ -2,6 +2,7 @@
#include <fmt/format.h>
#include <gpt4all-backend/formatters.h> // IWYU pragma: keep
#include <tl/generator.hpp>
#include <QByteArray>
#include <QDebug>
@@ -157,10 +158,11 @@ auto DataStoreBase::read(QFileDevice &file, json::stream_parser &parser) -> Data
}
};
auto inner = [&] -> DataStoreResult<> {
bool partialRead = false;
auto chunkIt = iterChunks();
// read JSON data
parser.reset();
for (auto &chunk : chunkIt) {
if (!chunk)
return std::unexpected(chunk.error());
@@ -183,14 +185,6 @@ auto DataStoreBase::read(QFileDevice &file, json::stream_parser &parser) -> Data
return std::unexpected(u"unexpected data after json: \"%1\""_s.arg(*chunk));
}
}
return {};
};
auto res = inner();
if (!res) {
parser.reset();
return std::unexpected(res.error());
}
return parser.release();
}

View File

@@ -4,7 +4,6 @@
#include <boost/json.hpp> // IWYU pragma: keep
#include <boost/system.hpp> // IWYU pragma: keep
#include <tl/generator.hpp>
#include <QFile>
#include <QFileDevice>
@@ -22,6 +21,8 @@
#include <utility>
#include <variant>
#include <ranges>
class QByteArray;
class QSaveFile;
@@ -96,7 +97,7 @@ class DataStore : public DataStoreBase {
public:
explicit DataStore(std::filesystem::path path);
auto list() -> tl::generator<const T &>;
auto list() { return m_entries | std::views::transform([](auto &e) { return e.second; }); }
auto setData(T data, std::optional<QString> createName = {}) -> DataStoreResult<>;
auto remove(const QUuid &id) -> DataStoreResult<>;
@@ -109,7 +110,7 @@ public:
{ auto it = m_entries.find(id); return it == m_entries.end() ? std::nullopt : std::optional(&it->second); }
protected:
auto createImpl(T data, const QString &name) -> DataStoreResult<const T *>;
auto createImpl(T data, const QString &name) -> DataStoreResult<>;
auto clear() -> DataStoreResult<> final;
CacheInsertResult cacheInsert(const boost::json::value &jv) override;

View File

@@ -6,8 +6,11 @@
#include <QSaveFile>
#include <QtAssert>
#include <ranges>
#include <system_error>
namespace views = std::views;
namespace gpt4all::ui {
@@ -21,14 +24,7 @@ DataStore<T>::DataStore(std::filesystem::path path)
}
template <typename T>
auto DataStore<T>::list() -> tl::generator<const T &>
{
for (auto &[_, value] : m_entries)
co_yield value;
}
template <typename T>
auto DataStore<T>::createImpl(T data, const QString &name) -> DataStoreResult<const T *>
auto DataStore<T>::createImpl(T data, const QString &name) -> DataStoreResult<>
{
// acquire path
auto file = openNew(name);
@@ -42,12 +38,7 @@ auto DataStore<T>::createImpl(T data, const QString &name) -> DataStoreResult<co
// insert
auto [it, unique] = m_entries.emplace(data.id, std::move(data));
Q_ASSERT(unique);
// acquire data ownership
if (auto res = acquire(data.id); !res)
return std::unexpected(res.error());
return &it->second;
return {};
}
template <typename T>

View File

@@ -57,25 +57,9 @@ auto tag_invoke(const boost::json::value_to_tag<ModelProviderData> &, const boos
};
}
auto ProviderStore::create(QString name, QUrl base_url, QString api_key)
-> DataStoreResult<const ModelProviderData *>
auto ProviderStore::create(ModelProviderData data) -> DataStoreResult<>
{
ModelProviderData data {
.id = QUuid::createUuid(),
.custom_details = CustomProviderDetails { name, std::move(base_url) },
.provider_details = OpenaiProviderDetails { std::move(api_key) },
};
return createImpl(std::move(data), name);
}
auto ProviderStore::create(QString name, QUrl base_url)
-> DataStoreResult<const ModelProviderData *>
{
ModelProviderData data {
.id = QUuid::createUuid(),
.custom_details = CustomProviderDetails { name, std::move(base_url) },
.provider_details = {},
};
auto name = data.custom_details.value().name;
return createImpl(std::move(data), name);
}

View File

@@ -49,10 +49,7 @@ private:
public:
using Super::Super;
/// OpenAI
auto create(QString name, QUrl base_url, QString api_key) -> DataStoreResult<const ModelProviderData *>;
/// Ollama
auto create(QString name, QUrl base_url) -> DataStoreResult<const ModelProviderData *>;
auto create(ModelProviderData data) -> DataStoreResult<>;
};