make it build - still plenty of TODOs

Jared Van Bortel 2025-03-11 17:08:11 -04:00
parent 7745f208bc
commit f7cd880f96
23 changed files with 129 additions and 77 deletions

deps/CMakeLists.txt (vendored)

@@ -11,3 +11,13 @@ set(QCORO_WITH_QTQUICK OFF)
 set(QCORO_WITH_QML OFF)
 set(QCORO_WITH_QTTEST OFF)
 add_subdirectory(qcoro)
+
+set(GPT4ALL_BOOST_TAG 1.87.0)
+FetchContent_Declare(
+    boost
+    URL "https://github.com/boostorg/boost/releases/download/boost-${GPT4ALL_BOOST_TAG}/boost-${GPT4ALL_BOOST_TAG}-cmake.tar.xz"
+    URL_HASH "SHA256=7da75f171837577a52bbf217e17f8ea576c7c246e4594d617bfde7fafd408be5"
+)
+set(BOOST_INCLUDE_LIBRARIES json describe system)
+FetchContent_MakeAvailable(boost)
+
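
The Boost build is scoped to exactly three libraries: json, describe, and system. A minimal, self-contained sketch of what that combination enables (the struct and field names here are invented for illustration; the Boost APIs are real, and Boost.JSON's direct conversion of described classes requires Boost >= 1.81, satisfied by the pinned 1.87.0):

    #include <boost/describe/class.hpp>
    #include <boost/json.hpp>
    #include <iostream>
    #include <string>

    // Hypothetical wire type; Boost.JSON converts described classes directly.
    struct ModelEntry {
        std::string name;
        std::string digest;
    };
    BOOST_DESCRIBE_STRUCT(ModelEntry, (), (name, digest))

    int main()
    {
        ModelEntry e { "llama3", "sha256:abc123" };
        // struct -> JSON text, via the BOOST_DESCRIBE_STRUCT metadata
        std::cout << boost::json::serialize(boost::json::value_from(e)) << '\n';
        // JSON text -> struct
        auto back = boost::json::value_to<ModelEntry>(
            boost::json::parse(R"({"name":"llama3","digest":"sha256:abc123"})"));
        std::cout << back.name << '\n';
    }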


@@ -2,15 +2,4 @@ include(FetchContent)
 
 set(BUILD_SHARED_LIBS OFF)
 
-# suppress warnings during boost build
-add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:BOOST_ALLOW_DEPRECATED_HEADERS>)
-
-set(GPT4ALL_BOOST_TAG 1.87.0)
-FetchContent_Declare(
-    boost
-    URL "https://github.com/boostorg/boost/releases/download/boost-${GPT4ALL_BOOST_TAG}/boost-${GPT4ALL_BOOST_TAG}-cmake.tar.xz"
-    URL_HASH "SHA256=7da75f171837577a52bbf217e17f8ea576c7c246e4594d617bfde7fafd408be5"
-)
-FetchContent_MakeAvailable(boost)
-
 add_subdirectory(date)


@@ -240,7 +240,8 @@ qt_add_executable(chat
     src/jinja_replacements.cpp src/jinja_replacements.h
     src/json-helpers.cpp src/json-helpers.h
     src/llm.cpp src/llm.h
-    src/llmodel_chat.h
+    src/llmodel_chat.h src/llmodel_chat.cpp
+    src/llmodel_description.h src/llmodel_description.cpp
     src/llmodel_ollama.cpp src/llmodel_ollama.h
     src/llmodel_openai.cpp src/llmodel_openai.h
     src/llmodel_provider.cpp src/llmodel_provider.h


@@ -127,7 +127,7 @@ struct PromptModelWithToolsResult {
     bool shouldExecuteToolCall;
 };
 
 static auto promptModelWithTools(
-    ChatLLMInstance *model, BaseResponseHandler &respHandler, const GenerationParams &params, const QByteArray &prompt,
+    ChatLLMInstance *model, BaseResponseHandler &respHandler, const GenerationParams *params, const QByteArray &prompt,
     const QStringList &toolNames
 ) -> QCoro::Task<PromptModelWithToolsResult>
@@ -499,8 +499,8 @@ void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
     }
 }
 
-auto ChatLLM::modelDescription() -> const ModelDescription *
-{ return m_llmInstance->description(); }
+auto ChatLLM::modelProvider() -> const ModelProvider *
+{ return m_llmInstance->description()->provider(); }
 
 void ChatLLM::prompt(const QStringList &enabledCollections)
 {
@@ -512,7 +512,7 @@ void ChatLLM::prompt(const QStringList &enabledCollections)
     }
 
     try {
-        promptInternalChat(enabledCollections, mySettings->modelGenParams(m_modelInfo));
+        QCoro::waitFor(promptInternalChat(enabledCollections, mySettings->modelGenParams(m_modelInfo).get()));
     } catch (const std::exception &e) {
         // FIXME(jared): this is neither translated nor serialized
         m_chatModel->setResponseValue(u"Error: %1"_s.arg(QString::fromUtf8(e.what())));
@@ -641,8 +641,8 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) cons
     Q_UNREACHABLE();
 }
 
-auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const GenerationParams &params,
-                                 qsizetype startOffset) -> ChatPromptResult
+auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const GenerationParams *params,
+                                 qsizetype startOffset) -> QCoro::Task<ChatPromptResult>
 {
     Q_ASSERT(isModelLoaded());
     Q_ASSERT(m_chatModel);
@@ -679,8 +679,8 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const Ge
     auto messageItems = getChat();
     messageItems.pop_back(); // exclude new response
 
-    auto result = promptInternal(messageItems, params, !databaseResults.isEmpty());
-    return {
+    auto result = co_await promptInternal(messageItems, params, !databaseResults.isEmpty());
+    co_return {
         /*PromptResult*/ {
             .response = std::move(result.response),
             .promptTokens = result.promptTokens,
@@ -748,7 +748,7 @@ private:
 };
 
 auto ChatLLM::promptInternal(
-    const std::variant<std::span<const MessageItem>, std::string_view> &prompt, const GenerationParams &params,
+    const std::variant<std::span<const MessageItem>, std::string_view> &prompt, const GenerationParams *params,
     bool usedLocalDocs
 ) -> QCoro::Task<PromptResult>
 {
@@ -967,7 +967,7 @@ void ChatLLM::generateName()
     // TODO: support interruption via m_stopGenerating
     promptModelWithTools(
         m_llmInstance.get(),
-        respHandler, mySettings->modelGenParams(m_modelInfo),
+        respHandler, mySettings->modelGenParams(m_modelInfo).get(),
         applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(),
         { ToolCallConstants::ThinkTagName }
     );
@@ -1043,7 +1043,7 @@ void ChatLLM::generateQuestions(qint64 elapsed)
     // TODO: support interruption via m_stopGenerating
     promptModelWithTools(
         m_llmInstance.get(),
-        respHandler, mySettings->modelGenParams(m_modelInfo),
+        respHandler, mySettings->modelGenParams(m_modelInfo).get(),
         applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
         { ToolCallConstants::ThinkTagName }
     );
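
Two threads run through these hunks: GenerationParams moves from a reference to an owned pointer (hence the .get() calls on modelGenParams()'s new unique_ptr return), and promptInternalChat() becomes a QCoro::Task, so the synchronous prompt() entry point bridges into it with QCoro::waitFor(), which spins an event loop until the task completes. A self-contained sketch of that bridge (computeAsync is invented; QCoro::Task and QCoro::waitFor are the real API):

    #include <QCoro/QCoroTask>
    #include <QDebug>

    QCoro::Task<int> computeAsync()
    {
        // a real task would co_await network or model I/O here
        co_return 42;
    }

    void synchronousCaller()
    {
        // Blocks this function but keeps Qt events flowing until the
        // coroutine finishes; requires a running QCoreApplication.
        int result = QCoro::waitFor(computeAsync());
        qDebug() << "task returned" << result;
    }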


@@ -6,6 +6,8 @@
 #include "llmodel_chat.h"
 #include "modellist.h"
 
+#include <QCoro/QCoroTask> // IWYU pragma: keep
+
 #include <QByteArray>
 #include <QElapsedTimer>
 #include <QFileInfo>
@@ -32,6 +34,7 @@ using namespace Qt::Literals::StringLiterals;
 class ChatLLM;
 class QDataStream;
 namespace QCoro { template <typename T> class Task; }
+namespace gpt4all::ui { class ModelProvider; }
 
 
 // NOTE: values serialized to disk, do not change or reuse
@@ -210,13 +213,13 @@ protected:
         QList<ResultInfo> databaseResults;
     };
 
-    auto modelDescription() -> const gpt4all::ui::ModelDescription *;
+    auto modelProvider() -> const gpt4all::ui::ModelProvider *;
 
-    auto promptInternalChat(const QStringList &enabledCollections, const gpt4all::ui::GenerationParams &params,
-                            qsizetype startOffset = 0) -> ChatPromptResult;
+    auto promptInternalChat(const QStringList &enabledCollections, const gpt4all::ui::GenerationParams *params,
+                            qsizetype startOffset = 0) -> QCoro::Task<ChatPromptResult>;
     // passing a string_view directly skips templating and uses the raw string
     auto promptInternal(const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
-                        const gpt4all::ui::GenerationParams &params, bool usedLocalDocs) -> QCoro::Task<PromptResult>;
+                        const gpt4all::ui::GenerationParams *params, bool usedLocalDocs) -> QCoro::Task<PromptResult>;
 
 private:
     auto loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps) -> QCoro::Task<bool>;


@@ -0,0 +1,10 @@
+#include "llmodel_chat.h"
+
+
+namespace gpt4all::ui {
+
+
+ChatLLMInstance::~ChatLLMInstance() noexcept = default;
+
+
+} // namespace gpt4all::ui
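
This new .cpp exists because a pure virtual destructor still needs a definition: every derived destructor calls the base destructor directly, so the symbol must exist somewhere even though the class stays abstract. A generic sketch of the pattern, using nothing beyond standard C++:

    struct Base {
        virtual ~Base() noexcept = 0; // pure: Base itself stays abstract
    };
    Base::~Base() noexcept = default; // definition still required

    struct Derived : Base {};

    int main() { Derived d; } // ~Derived() calls Base::~Base(); links only with the definition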


@@ -22,11 +22,11 @@ struct ChatResponseMetadata {
 
 // TODO: implement two of these; one based on Ollama (TBD) and the other based on OpenAI (chatapi.h)
 class ChatLLMInstance {
 public:
-    virtual ~ChatLLMInstance() = 0;
+    virtual ~ChatLLMInstance() noexcept = 0;
 
     virtual auto description() const -> const ModelDescription * = 0;
     virtual auto preload() -> QCoro::Task<void> = 0;
-    virtual auto generate(QStringView prompt, const GenerationParams &params, /*out*/ ChatResponseMetadata &metadata)
+    virtual auto generate(QStringView prompt, const GenerationParams *params, /*out*/ ChatResponseMetadata &metadata)
         -> QCoro::AsyncGenerator<QString> = 0;
 };
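
Callers consume generate()'s QCoro::AsyncGenerator<QString> as an asynchronous stream of response chunks. A self-contained sketch of that consumption pattern (fakeGenerate is an invented stand-in with the same streaming shape; QCORO_FOREACH is QCoro's real async range-for macro):

    #include <QCoro/QCoroAsyncGenerator>
    #include <QCoro/QCoroTask>
    #include <QString>
    #include <QStringView>

    // Stand-in mirroring the shape of ChatLLMInstance::generate().
    QCoro::AsyncGenerator<QString> fakeGenerate(QStringView prompt)
    {
        co_yield QStringLiteral("echo: ");
        co_yield prompt.toString();
    }

    QCoro::Task<QString> collectResponse()
    {
        QString out;
        // Each chunk arrives as the generator suspends and resumes.
        QCORO_FOREACH(const QString &chunk, fakeGenerate(u"hello")) {
            out += chunk;
        }
        co_return out;
    }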


@@ -7,6 +7,8 @@
 
 namespace gpt4all::ui {
 
+ModelDescription::~ModelDescription() noexcept = default;
+
 auto ModelDescription::newInstance(QNetworkAccessManager *nam) const -> std::unique_ptr<ChatLLMInstance>
 { return std::unique_ptr<ChatLLMInstance>(newInstanceImpl(nam)); }
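
newInstance() is the classic non-virtual factory wrapper: the public function fixes the unique_ptr ownership policy once, while a virtual newInstanceImpl() lets subclasses return covariant raw pointers. A generic, self-contained sketch of the idiom with invented names:

    #include <memory>

    struct Widget { virtual ~Widget() = default; };
    struct FancyWidget : Widget {};

    class Maker {
    public:
        // public non-virtual wrapper owns the smart-pointer policy
        std::unique_ptr<Widget> make() const { return std::unique_ptr<Widget>(makeImpl()); }
    protected:
        virtual Widget *makeImpl() const = 0; // overrides may use covariant returns
    };

    class FancyMaker : public Maker {
    protected:
        FancyWidget *makeImpl() const override { return new FancyWidget; } // covariant
    };

    int main()
    {
        FancyMaker fm;
        std::unique_ptr<Widget> w = fm.make(); // callers always get an owning pointer
    }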


@@ -1,5 +1,7 @@
 #pragma once
 
+#include "llmodel_provider.h" // IWYU pragma: keep
+
 #include <QObject>
 #include <QVariant>
@@ -12,7 +14,6 @@ namespace gpt4all::ui {
 
 class ChatLLMInstance;
-class ModelProvider;
 
 // TODO: implement shared_from_this guidance for restricted construction
 class ModelDescription : public std::enable_shared_from_this<ModelDescription> {
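
The TODO refers to the usual enable_shared_from_this discipline: shared_from_this() is undefined behavior unless the object is already owned by a shared_ptr, so construction gets funneled through a factory. A self-contained sketch of one common "passkey" variant (Key, create, and Desc are illustrative names, not the project's):

    #include <memory>

    class Desc : public std::enable_shared_from_this<Desc> {
        struct Key { explicit Key() = default; }; // only Desc can mint a Key

    public:
        explicit Desc(Key) {} // public, but uncallable without a Key

        static std::shared_ptr<Desc> create()
        { return std::make_shared<Desc>(Key {}); }

        std::shared_ptr<Desc> ref() { return shared_from_this(); } // now always safe
    };

    int main()
    {
        auto d = Desc::create(); // a plain `Desc d{...};` won't compile
        auto r = d->ref();
    }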


@@ -21,6 +21,8 @@ auto OllamaGenerationParams::toMap() const -> QMap<QLatin1StringView, QVariant
     };
 }
 
+OllamaProvider::~OllamaProvider() noexcept = default;
+
 auto OllamaProvider::supportedGenerationParams() const -> QSet<GenerationParam>
 {
     using enum GenerationParam;
@@ -70,7 +72,7 @@ auto OllamaChatModel::preload() -> QCoro::Task<>
     co_return;
 }
 
-auto OllamaChatModel::generate(QStringView prompt, const GenerationParams &params,
+auto OllamaChatModel::generate(QStringView prompt, const GenerationParams *params,
                                /*out*/ ChatResponseMetadata &metadata)
     -> QCoro::AsyncGenerator<QString>
 {


@@ -50,7 +50,7 @@ public:
     auto makeGenerationParams(const QMap<GenerationParam, QVariant> &values) const -> OllamaGenerationParams * override;
 };
 
-class OllamaProviderBuiltin : public ModelProviderBuiltin, public OllamaProvider {
+class OllamaProviderBuiltin : public OllamaProvider, public ModelProviderBuiltin {
     Q_GADGET
 
 public:
@@ -109,7 +109,7 @@ public:
     auto preload() -> QCoro::Task<void> override;
-    auto generate(QStringView prompt, const GenerationParams &params, /*out*/ ChatResponseMetadata &metadata)
+    auto generate(QStringView prompt, const GenerationParams *params, /*out*/ ChatResponseMetadata &metadata)
         -> QCoro::AsyncGenerator<QString> override;
 
 private:
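
Reordering the bases puts the concrete provider interface first; the hierarchy still forms a diamond on ModelProvider, which stays unambiguous because ModelProviderBuiltin inherits it virtually (see llmodel_provider.h below; this sketch assumes OllamaProvider does the same). A minimal self-contained illustration:

    struct ModelProvider { virtual ~ModelProvider() = default; };

    // both intermediate bases inherit virtually...
    struct OllamaProvider       : virtual ModelProvider {};
    struct ModelProviderBuiltin : virtual ModelProvider {};

    // ...so the most-derived class holds a single ModelProvider subobject
    struct OllamaProviderBuiltin : OllamaProvider, ModelProviderBuiltin {};

    int main()
    {
        OllamaProviderBuiltin p;
        ModelProvider &base = p; // unambiguous: exactly one shared base
    }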


@@ -84,6 +84,8 @@ auto OpenaiGenerationParams::toMap() const -> QMap<QLatin1StringView, QVariant
     };
 }
 
+OpenaiProvider::~OpenaiProvider() noexcept = default;
+
 auto OpenaiProvider::supportedGenerationParams() const -> QSet<GenerationParam>
 {
     using enum GenerationParam;
@@ -212,22 +214,22 @@ static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QStr
     }
 }
 
-auto preload() -> QCoro::Task<>
+auto OpenaiChatModel::preload() -> QCoro::Task<>
 { co_return; /* not supported -> no-op */ }
 
-auto OpenaiChatModel::generate(QStringView prompt, const GenerationParams &params,
+auto OpenaiChatModel::generate(QStringView prompt, const GenerationParams *params,
                                /*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator<QString>
 {
     auto *mySettings = MySettings::globalInstance();
 
-    if (params.isNoop())
+    if (params->isNoop())
        co_return; // nothing requested
 
     auto reqBody = makeJsonObject({
         { "model"_L1,  m_description->modelName() },
         { "stream"_L1, true                       },
     });
-    extend(reqBody, params.toMap());
+    extend(reqBody, params->toMap());
 
     // conversation history
     {
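
generate() now dereferences params unconditionally, so the implicit contract is "never pass null". If a null pointer were instead meant to signal "no overrides", the guard would have to check the pointer first. A tiny self-contained sketch of that alternative reading (Params is an invented stand-in, not the project's GenerationParams):

    #include <QMap>
    #include <QString>
    #include <QVariant>

    struct Params {
        QMap<QString, QVariant> overrides;
        bool isNoop() const { return overrides.isEmpty(); }
    };

    // Sketch: treat a null params pointer like an empty parameter set
    // instead of dereferencing it unconditionally.
    QMap<QString, QVariant> effectiveOverrides(const Params *params)
    {
        if (!params || params->isNoop())
            return {};
        return params->overrides;
    }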


@@ -63,7 +63,7 @@ protected:
     QString m_apiKey;
 };
 
-class OpenaiProviderBuiltin : public ModelProviderBuiltin, public OpenaiProvider {
+class OpenaiProviderBuiltin : public OpenaiProvider, public ModelProviderBuiltin {
     Q_GADGET
     Q_PROPERTY(QString apiKey READ apiKey CONSTANT)
@@ -127,7 +127,7 @@ public:
     auto preload() -> QCoro::Task<void> override;
-    auto generate(QStringView prompt, const GenerationParams &params, /*out*/ ChatResponseMetadata &metadata)
+    auto generate(QStringView prompt, const GenerationParams *params, /*out*/ ChatResponseMetadata &metadata)
         -> QCoro::AsyncGenerator<QString> override;
 
 private:
private: private:


@@ -14,6 +14,8 @@ namespace fs = std::filesystem;
 
 namespace gpt4all::ui {
 
+GenerationParams::~GenerationParams() noexcept = default;
+
 void GenerationParams::parse(QMap<GenerationParam, QVariant> values)
 {
     parseInner(values);
@@ -38,6 +40,8 @@ QVariant GenerationParams::tryParseValue(QMap<GenerationParam, QVariant> &values
     return value;
 }
 
+ModelProvider::~ModelProvider() noexcept = default;
+
 ModelProviderCustom::~ModelProviderCustom() noexcept
 {
     if (auto res = m_store->release(m_id); !res)


@@ -97,9 +97,6 @@ class ModelProviderBuiltin : public virtual ModelProvider {
     Q_GADGET
     Q_PROPERTY(QString name READ name CONSTANT)
     Q_PROPERTY(QUrl baseUrl READ baseUrl CONSTANT)
-
-public:
-    ~ModelProviderBuiltin() noexcept override = 0;
 };
 
 class ModelProviderCustom : public virtual ModelProvider {


@@ -2,6 +2,7 @@
 
 #include "download.h"
 #include "jinja_replacements.h"
+#include "llmodel_description.h"
 #include "mysettings.h"
 #include "network.h"
@@ -46,6 +47,7 @@
 #include <utility>
 
 using namespace Qt::Literals::StringLiterals;
+using namespace gpt4all::ui;
 
 //#define USE_LOCAL_MODELSJSON
@@ -90,6 +92,12 @@ void ModelInfo::setId(const QString &id)
     m_id = id;
 }
 
+void ModelInfo::setModelDesc(std::shared_ptr<const ModelDescription> value)
+{ m_modelDesc = std::move(value); }
+
+void ModelInfo::setModelDescQt(const ModelDescription *value)
+{ return setModelDesc(value->shared_from_this()); }
+
 QString ModelInfo::name() const
 {
     return MySettings::globalInstance()->modelName(*this);


@@ -77,7 +77,7 @@ private:
 struct ModelInfo {
     Q_GADGET
     Q_PROPERTY(QString id READ id WRITE setId)
-    Q_PROPERTY(const ModelDescription *modelDesc READ modelDescQt WRITE setModelDescQt)
+    Q_PROPERTY(const gpt4all::ui::ModelDescription *modelDesc READ modelDescQt WRITE setModelDescQt)
     Q_PROPERTY(QString name READ name WRITE setName)
     Q_PROPERTY(QString filename READ filename WRITE setFilename)
     Q_PROPERTY(QString dirpath MEMBER dirpath)
@@ -140,12 +140,13 @@ public:
     QString id() const;
     void setId(const QString &id);
 
-    auto modelDesc() const -> const std::shared_ptr<const gpt4all::ui::ModelDescription> &;
+    auto modelDesc() const -> const std::shared_ptr<const gpt4all::ui::ModelDescription> &
+    { return m_modelDesc; }
     auto modelDescQt() const -> const gpt4all::ui::ModelDescription *
     { return modelDesc().get(); }
     void setModelDesc(std::shared_ptr<const gpt4all::ui::ModelDescription> value);
-    void setModelDescQt(const gpt4all::ui::ModelDescription *); // TODO: implement
+    void setModelDescQt(const gpt4all::ui::ModelDescription *value);
 
     QString name() const;
     void setName(const QString &name);
@@ -257,7 +258,7 @@ private:
     QVariant getField(QLatin1StringView name) const;
 
     QString m_id;
-    std::shared_ptr<const gpt4all::ui::ModelDescription> m_modelDesc;
+    std::shared_ptr<const gpt4all::ui::ModelDescription> m_modelDesc; // TODO: set this somewhere
     QString m_name;
     QString m_filename;
     QString m_description;


@@ -2,6 +2,7 @@
 
 #include "chatllm.h"
 #include "config.h"
+#include "llmodel_provider.h"
 #include "modellist.h"
 
 #include <gpt4all-backend/llmodel.h>
@@ -31,6 +32,7 @@
 #endif
 
 using namespace Qt::Literals::StringLiterals;
+using namespace gpt4all::ui;
 
 // used only for settings serialization, do not translate
@@ -352,6 +354,28 @@ int MySettings::modelRepeatPenaltyTokens (const ModelInfo &info) const
 QString MySettings::modelChatNamePrompt         (const ModelInfo &info) const { return getModelSetting("chatNamePrompt", info).toString(); }
 QString MySettings::modelSuggestedFollowUpPrompt(const ModelInfo &info) const { return getModelSetting("suggestedFollowUpPrompt", info).toString(); }
 
+auto MySettings::modelGenParams(const ModelInfo &info) -> std::unique_ptr<GenerationParams>
+{
+#if 0
+    // this code is copied from server.cpp.
+    std::unique_ptr<GenerationParams> genParams;
+    {
+        using enum GenerationParam;
+        QMap<GenerationParam, QVariant> values;
+        if (auto v = request.max_tokens ) values.insert(NPredict,    *v);
+        if (auto v = request.temperature) values.insert(Temperature, *v);
+        if (auto v = request.top_p      ) values.insert(TopP,        *v);
+        if (auto v = request.min_p      ) values.insert(MinP,        *v);
+        try {
+            genParams.reset(modelProvider()->makeGenerationParams(values));
+        } catch (const std::exception &e) {
+            throw InvalidRequestError(e.what());
+        }
+    }
+#endif
+    return nullptr; // TODO: implement
+}
+
 auto MySettings::getUpgradeableModelSetting(
         const ModelInfo &info, QLatin1StringView legacyKey, QLatin1StringView newKey
 ) const -> UpgradeableSetting
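
A speculative sketch of how the TODO might eventually be filled in, adapting the #if 0 block to read this class's per-model settings instead of request fields. The modelMaxLength/modelTemperature/modelTopP/modelMinP getters and the provider lookup via ModelInfo::modelDesc() are assumptions based on the surrounding code, not the commit's actual implementation:

    // Hypothetical implementation sketch; not part of this commit.
    auto MySettings::modelGenParams(const ModelInfo &info) -> std::unique_ptr<GenerationParams>
    {
        using enum GenerationParam;
        QMap<GenerationParam, QVariant> values {
            { NPredict,    modelMaxLength  (info) }, // assumed existing getters
            { Temperature, modelTemperature(info) },
            { TopP,        modelTopP       (info) },
            { MinP,        modelMinP       (info) },
        };
        // assumes info.modelDesc() is populated (see the TODO in modellist.h)
        return std::unique_ptr<GenerationParams>(
            info.modelDesc()->provider()->makeGenerationParams(values));
    }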


@@ -156,8 +156,7 @@ public:
     QString modelSuggestedFollowUpPrompt(const ModelInfo &info) const;
     Q_INVOKABLE void setModelSuggestedFollowUpPrompt(const ModelInfo &info, const QString &value, bool force = false);
 
-    // TODO: implement
-    auto modelGenParams(const ModelInfo &info) -> gpt4all::ui::GenerationParams;
+    auto modelGenParams(const ModelInfo &info) -> std::unique_ptr<gpt4all::ui::GenerationParams>;
 
     // Application settings
     bool systemTray() const;


@@ -3,9 +3,11 @@
 
 #include "chat.h"
 #include "chatmodel.h"
 #include "llmodel_description.h"
+#include "llmodel_provider.h"
 #include "modellist.h"
 #include "mysettings.h"
 
+#include <QCoro/QCoroTask>
 #include <fmt/format.h>
 #include <gpt4all-backend/formatters.h>
 #include <gpt4all-backend/generation-params.h>
@@ -527,7 +529,7 @@ void Server::start()
 #endif
             CompletionRequest req;
             parseRequest(req, std::move(reqObj));
-            auto [resp, respObj] = handleCompletionRequest(req);
+            auto [resp, respObj] = QCoro::waitFor(handleCompletionRequest(req));
 #if defined(DEBUG)
             if (respObj)
                 qDebug().noquote() << "/v1/completions reply" << QJsonDocument(*respObj).toJson(QJsonDocument::Indented);
@@ -551,7 +553,7 @@ void Server::start()
 #endif
             ChatRequest req;
             parseRequest(req, std::move(reqObj));
-            auto [resp, respObj] = handleChatRequest(req);
+            auto [resp, respObj] = QCoro::waitFor(handleChatRequest(req));
             (void)respObj;
 #if defined(DEBUG)
             if (respObj)
@@ -628,7 +630,7 @@ static auto makeError(auto &&...args) -> std::pair<QHttpServerResponse, std::opt
 }
 
 auto Server::handleCompletionRequest(const CompletionRequest &request)
-    -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
+    -> QCoro::Task<std::pair<QHttpServerResponse, std::optional<QJsonObject>>>
 {
     Q_ASSERT(m_chatModel);
@@ -649,7 +651,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
 
     if (modelInfo.filename().isEmpty()) {
         std::cerr << "ERROR: couldn't load default model " << request.model.toStdString() << std::endl;
-        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+        co_return makeError(QHttpServerResponder::StatusCode::InternalServerError);
     }
 
     emit requestResetResponseState(); // blocks
@@ -657,10 +659,9 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
     if (prevMsgIndex >= 0)
         m_chatModel->updateCurrentResponse(prevMsgIndex, false);
 
-    // NB: this resets the context, regardless of whether this model is already loaded
-    if (!loadModel(modelInfo)) {
+    if (!co_await loadModel(modelInfo)) {
         std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
-        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+        co_return makeError(QHttpServerResponder::StatusCode::InternalServerError);
     }
 
     std::unique_ptr<GenerationParams> genParams;
@@ -672,7 +673,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
         if (auto v = request.top_p      ) values.insert(TopP,        *v);
         if (auto v = request.min_p      ) values.insert(MinP,        *v);
         try {
-            genParams.reset(modelDescription()->makeGenerationParams(values));
+            genParams.reset(modelProvider()->makeGenerationParams(values));
         } catch (const std::exception &e) {
             throw InvalidRequestError(e.what());
         }
@@ -689,14 +690,14 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
     for (int i = 0; i < request.n; ++i) {
         PromptResult result;
         try {
-            result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
-                                    *genParams,
-                                    /*usedLocalDocs*/ false);
+            result = co_await promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
+                                             genParams.get(),
+                                             /*usedLocalDocs*/ false);
         } catch (const std::exception &e) {
             m_chatModel->setResponseValue(e.what());
             m_chatModel->setError();
             emit responseStopped(0);
-            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+            co_return makeError(QHttpServerResponder::StatusCode::InternalServerError);
         }
         QString resp = QString::fromUtf8(result.response);
         if (request.echo)
@@ -731,11 +732,11 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
         { "total_tokens",      promptTokens + responseTokens },
     });
 
-    return {QHttpServerResponse(responseObject), responseObject};
+    co_return { QHttpServerResponse(responseObject), responseObject };
 }
 
 auto Server::handleChatRequest(const ChatRequest &request)
-    -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>
+    -> QCoro::Task<std::pair<QHttpServerResponse, std::optional<QJsonObject>>>
 {
     ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
     const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
@@ -754,15 +755,14 @@ auto Server::handleChatRequest(const ChatRequest &request)
 
     if (modelInfo.filename().isEmpty()) {
         std::cerr << "ERROR: couldn't load default model " << request.model.toStdString() << std::endl;
-        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+        co_return makeError(QHttpServerResponder::StatusCode::InternalServerError);
     }
 
     emit requestResetResponseState(); // blocks
 
-    // NB: this resets the context, regardless of whether this model is already loaded
-    if (!loadModel(modelInfo)) {
+    if (!co_await loadModel(modelInfo)) {
         std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
-        return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+        co_return makeError(QHttpServerResponder::StatusCode::InternalServerError);
     }
 
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
@@ -790,7 +790,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
         if (auto v = request.top_p      ) values.insert(TopP,        *v);
         if (auto v = request.min_p      ) values.insert(MinP,        *v);
         try {
-            genParams.reset(modelDescription()->makeGenerationParams(values));
+            genParams.reset(modelProvider()->makeGenerationParams(values));
         } catch (const std::exception &e) {
             throw InvalidRequestError(e.what());
         }
@@ -802,12 +802,12 @@ auto Server::handleChatRequest(const ChatRequest &request)
     for (int i = 0; i < request.n; ++i) {
         ChatPromptResult result;
         try {
-            result = promptInternalChat(m_collections, *genParams, startOffset);
+            result = co_await promptInternalChat(m_collections, genParams.get(), startOffset);
         } catch (const std::exception &e) {
             m_chatModel->setResponseValue(e.what());
             m_chatModel->setError();
             emit responseStopped(0);
-            return makeError(QHttpServerResponder::StatusCode::InternalServerError);
+            co_return makeError(QHttpServerResponder::StatusCode::InternalServerError);
        }
         responses.emplace_back(result.response, result.databaseResults);
         if (i == 0)
@@ -855,5 +855,5 @@ auto Server::handleChatRequest(const ChatRequest &request)
         { "total_tokens",      promptTokens + responseTokens },
     });
 
-    return {QHttpServerResponse(responseObject), responseObject};
+    co_return {QHttpServerResponse(responseObject), responseObject};
 }


@@ -18,6 +18,7 @@
 class Chat;
 class ChatRequest;
 class CompletionRequest;
+namespace QCoro { template <typename T> class Task; }
 
 class Server : public ChatLLM
@@ -35,8 +36,8 @@ Q_SIGNALS:
     void requestResetResponseState();
 
 private:
-    auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
-    auto handleChatRequest(const ChatRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
+    auto handleCompletionRequest(const CompletionRequest &request) -> QCoro::Task<std::pair<QHttpServerResponse, std::optional<QJsonObject>>>;
+    auto handleChatRequest(const ChatRequest &request) -> QCoro::Task<std::pair<QHttpServerResponse, std::optional<QJsonObject>>>;
 
 private Q_SLOTS:
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }


@@ -17,7 +17,7 @@ auto ProviderStore::create(QString name, QUrl base_url, QString api_key)
 
 auto ProviderStore::create(QString name, QUrl base_url)
     -> DataStoreResult<const ModelProviderData *>
 {
-    ModelProviderData data { QUuid::createUuid(), ProviderType::ollama, name, std::move(base_url) };
+    ModelProviderData data { QUuid::createUuid(), ProviderType::ollama, name, std::move(base_url), {} };
     return createImpl(std::move(data), name);
 }
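
The extra {} value-initializes a trailing member of the ModelProviderData aggregate that this overload has no value for; judging by the three-argument create() above, that member is presumably the API key, though that reading is an inference. A self-contained sketch of the mechanics:

    #include <string>

    struct Data {
        int         id;
        std::string name;
        std::string apiKey; // newly added trailing member
    };

    int main()
    {
        // Omitting the trailing {} would still value-initialize apiKey, but
        // spelling it out makes the "no key" case explicit and silences
        // -Wmissing-field-initializers where that warning is enabled.
        Data d { 1, "ollama", {} };
        return d.apiKey.empty() ? 0 : 1;
    }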


@@ -36,9 +36,7 @@ inline auto toFSPath(const QString &str) -> std::filesystem::path
                                 reinterpret_cast<const char16_t *>(str.cend  ()) };
 }
 
-FileError::FileError(const QString &str, QFileDevice::FileError code)
+inline FileError::FileError(const QString &str, QFileDevice::FileError code)
     : std::runtime_error(str.toUtf8().constData())
     , m_code(code)
-{
-    Q_ASSERT(code);
-}
+{ Q_ASSERT(code); }
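
The added inline matters because this constructor is defined in a header (alongside the already-inline toFSPath above): without it, every translation unit including the header would emit its own definition and the link would fail with a multiple-definition / ODR error. A minimal self-contained sketch of the same pattern, with invented names:

    // util.h -- included from several .cpp files
    #pragma once
    #include <stdexcept>
    #include <string>

    struct FileErrorDemo : std::runtime_error {
        FileErrorDemo(const std::string &what, int code);
        int code;
    };

    // 'inline' lets this out-of-class definition live in the header safely
    inline FileErrorDemo::FileErrorDemo(const std::string &what, int code)
        : std::runtime_error(what)
        , code(code)
    {}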