mirror of https://github.com/nomic-ai/gpt4all.git
synced 2025-07-16 16:31:30 +00:00

commit 1dc9f22d5b (parent 1ba555a174): WIP
gpt4all-backend/CMakeLists.txt
@@ -13,7 +13,10 @@ add_subdirectory(src)
 target_sources(gpt4all-backend PUBLIC
     FILE_SET public_headers TYPE HEADERS BASE_DIRS include FILES
         include/gpt4all-backend/formatters.h
+        include/gpt4all-backend/generation-params.h
         include/gpt4all-backend/json-helpers.h
         include/gpt4all-backend/ollama-client.h
+        include/gpt4all-backend/ollama-model.h
         include/gpt4all-backend/ollama-types.h
+        include/gpt4all-backend/rest.h
 )
gpt4all-backend/include/gpt4all-backend/generation-params.h (new file, 22 lines)
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <QtTypes>
+
+
+namespace gpt4all::backend {
+
+
+struct GenerationParams {
+    uint n_predict;
+    float temperature;
+    float top_p;
+    // int32_t top_k = 40;
+    // float min_p = 0.0f;
+    // int32_t n_batch = 9;
+    // float repeat_penalty = 1.10f;
+    // int32_t repeat_last_n = 64; // last n tokens to penalize
+    // float contextErase = 0.5f; // percent of context to erase if we exceed the context window
+};
+
+
+} // namespace gpt4all::backend
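
[Editor's note, illustrative sketch rather than part of the commit: GenerationParams is the sampling-parameter struct that replaces LLModel::PromptContext in the chat code further down. A minimal hypothetical call site, using only the three fields the struct declares so far; the values are placeholders, in the app they come from MySettings:]

    #include <gpt4all-backend/generation-params.h>

    // Hypothetical values; see genParamsFromSettings() in chatllm.cpp below.
    gpt4all::backend::GenerationParams params {
        .n_predict   = 4096, // maximum number of tokens to generate
        .temperature = 0.7f, // sampling temperature
        .top_p       = 0.4f, // nucleus-sampling cutoff
    };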
gpt4all-backend/include/gpt4all-backend/ollama-client.h
@@ -22,6 +22,7 @@ class QRestReply;
 
 namespace gpt4all::backend {
 
+
 struct ResponseError {
 public:
     struct BadStatus { int code; };
@@ -52,7 +53,7 @@ using DataOrRespErr = std::expected<T, ResponseError>;
 
 class OllamaClient {
 public:
-    OllamaClient(QUrl baseUrl, QString m_userAgent = QStringLiteral("GPT4All"))
+    OllamaClient(QUrl baseUrl, QString m_userAgent)
         : m_baseUrl(baseUrl)
         , m_userAgent(std::move(m_userAgent))
     {}
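
[Editor's note, illustrative sketch rather than part of the commit: with the default argument removed, callers must now pass the user agent explicitly. A hypothetical construction, assuming Ollama's conventional local endpoint:]

    gpt4all::backend::OllamaClient client(
        QUrl(QStringLiteral("http://localhost:11434")), // assumed default Ollama URL
        QStringLiteral("GPT4All")
    );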
gpt4all-backend/include/gpt4all-backend/ollama-model.h (new file, 13 lines)
@@ -0,0 +1,13 @@
+#pragma once
+
+namespace gpt4all::backend {
+
+
+class OllamaClient;
+
+class OllamaModel {
+    OllamaClient *client;
+};
+
+
+} // namespace gpt4all::backend
gpt4all-backend/include/gpt4all-backend/rest.h (new file, 13 lines)
@@ -0,0 +1,13 @@
+#pragma once
+
+class QRestReply;
+class QString;
+
+
+namespace gpt4all::backend {
+
+
+QString restErrorString(const QRestReply &reply);
+
+
+} // namespace gpt4all::backend
gpt4all-backend/src/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library(${TARGET} STATIC
     ollama-client.cpp
     ollama-types.cpp
     qt-json-stream.cpp
+    rest.cpp
 )
 target_compile_features(${TARGET} PUBLIC cxx_std_23)
 gpt4all_add_warning_options(${TARGET})
gpt4all-backend/src/ollama-client.cpp
@@ -2,6 +2,7 @@
 
 #include "json-helpers.h" // IWYU pragma: keep
 #include "qt-json-stream.h"
+#include "rest.h"
 
 #include <QCoro/QCoroIODevice> // IWYU pragma: keep
 #include <QCoro/QCoroNetworkReply> // IWYU pragma: keep
@@ -26,22 +27,14 @@ namespace gpt4all::backend {
 
 ResponseError::ResponseError(const QRestReply *reply)
 {
-    auto *nr = reply->networkReply();
     if (reply->hasError()) {
-        error = nr->error();
-        errorString = nr->errorString();
+        error = reply->networkReply()->error();
     } else if (!reply->isHttpStatusSuccess()) {
-        auto code = reply->httpStatus();
-        auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
-        error = BadStatus(code);
-        errorString = u"HTTP %1%2%3 for URL \"%4\""_s.arg(
-            QString::number(code),
-            reason.isValid() ? u" "_s : QString(),
-            reason.toString(),
-            nr->request().url().toString()
-        );
+        error = BadStatus(reply->httpStatus());
     } else
         Q_UNREACHABLE();
+
+    errorString = restErrorString(*reply);
 }
 
 QNetworkRequest OllamaClient::makeRequest(const QString &path) const
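
[Editor's note, illustrative sketch rather than part of the commit: ollama-client.h defines DataOrRespErr<T> as std::expected<T, ResponseError>, so after this refactor every failure path carries the message built by restErrorString(). A hypothetical helper for consuming such results; unwrap() is not part of the API:]

    #include <stdexcept>
    #include <utility>

    template <typename T>
    T unwrap(gpt4all::backend::DataOrRespErr<T> result)
    {
        if (!result) // holds a ResponseError instead of a value
            throw std::runtime_error(result.error().errorString.toStdString());
        return std::move(*result);
    }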
gpt4all-backend/src/rest.cpp (new file, 34 lines)
@@ -0,0 +1,34 @@
+#include "rest.h"
+
+#include <QNetworkReply>
+#include <QRestReply>
+#include <QString>
+
+using namespace Qt::Literals::StringLiterals;
+
+
+namespace gpt4all::backend {
+
+
+QString restErrorString(const QRestReply &reply)
+{
+    auto *nr = reply.networkReply();
+    if (reply.hasError())
+        return nr->errorString();
+
+    if (!reply.isHttpStatusSuccess()) {
+        auto code = reply.httpStatus();
+        auto reason = nr->attribute(QNetworkRequest::HttpReasonPhraseAttribute);
+        return u"HTTP %1%2%3 for URL \"%4\""_s.arg(
+            QString::number(code),
+            reason.isValid() ? u" "_s : QString(),
+            reason.toString(),
+            nr->request().url().toString()
+        );
+    }
+
+    Q_UNREACHABLE();
+}
+
+
+} // namespace gpt4all::backend
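
[Editor's note, illustrative sketch rather than part of the commit: restErrorString() folds both failure modes of a QRestReply (a transport error, or a non-2xx HTTP status) into one diagnostic string. A hypothetical caller outside OllamaClient:]

    #include <QDebug>
    #include <QNetworkReply>
    #include <QRestReply>
    #include "rest.h"

    // Log a uniform diagnostic for any failed reply.
    void warnIfFailed(QNetworkReply *nr)
    {
        QRestReply reply(nr);
        if (reply.hasError() || !reply.isHttpStatusSuccess())
            qWarning().noquote() << gpt4all::backend::restErrorString(reply);
    }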
gpt4all-chat/CMakeLists.txt
@@ -227,9 +227,10 @@ if (APPLE)
 endif()
 
 qt_add_executable(chat
+    src/llmodel/provider.cpp src/llmodel/provider.h
+    src/llmodel/openai.cpp src/llmodel/openai.h
     src/main.cpp
     src/chat.cpp src/chat.h
-    src/chatapi.cpp src/chatapi.h
     src/chatlistmodel.cpp src/chatlistmodel.h
     src/chatllm.cpp src/chatllm.h
     src/chatmodel.h src/chatmodel.cpp
@@ -456,8 +457,16 @@ else()
     # Link PDFium
     target_link_libraries(chat PRIVATE pdfium)
 endif()
-target_link_libraries(chat
-    PRIVATE gpt4all-backend llmodel nlohmann_json::nlohmann_json SingleApplication fmt::fmt duckx::duckx QXlsx)
+target_link_libraries(chat PRIVATE
+    QCoro6::Core QCoro6::Network
+    QXlsx
+    SingleApplication
+    duckx::duckx
+    fmt::fmt
+    gpt4all-backend
+    llmodel
+    nlohmann_json::nlohmann_json
+)
 target_include_directories(chat PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/deps/minja/include)
 
 if (APPLE)
gpt4all-chat/src/chat.cpp
@@ -412,21 +412,6 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
     emit tokenSpeedChanged();
 }
 
-QString Chat::deviceBackend() const
-{
-    return m_llmodel->deviceBackend();
-}
-
-QString Chat::device() const
-{
-    return m_llmodel->device();
-}
-
-QString Chat::fallbackReason() const
-{
-    return m_llmodel->fallbackReason();
-}
-
 void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
 {
     m_databaseResults = results;
gpt4all-chat/src/chat.h
@@ -39,9 +39,6 @@ class Chat : public QObject
     Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
     Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
     Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged)
-    Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
-    Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
-    Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
     // 0=no, 1=waiting, 2=working
     Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
@@ -122,10 +119,6 @@ public:
     QString modelLoadingError() const { return m_modelLoadingError; }
 
     QString tokenSpeed() const { return m_tokenSpeed; }
-    QString deviceBackend() const;
-    QString device() const;
-    // not loaded -> QString(), no fallback -> QString("")
-    QString fallbackReason() const;
 
     int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
 
@@ -159,8 +152,6 @@ Q_SIGNALS:
     void isServerChanged();
     void collectionListChanged(const QList<QString> &collectionList);
     void tokenSpeedChanged();
-    void deviceChanged();
-    void fallbackReasonChanged();
     void collectionModelChanged();
     void trySwitchContextInProgressChanged();
     void loadedModelInfoChanged();
@@ -192,8 +183,6 @@ private:
     ModelInfo m_modelInfo;
     QString m_modelLoadingError;
     QString m_tokenSpeed;
-    QString m_device;
-    QString m_fallbackReason;
     QList<QString> m_collections;
     QList<QString> m_generatedQuestions;
     ChatModel *m_chatModel;
gpt4all-chat/src/chatapi.cpp (deleted, 359 lines)
@@ -1,359 +0,0 @@
-#include "chatapi.h"
-
-#include "utils.h"
-
-#include <fmt/format.h>
-#include <gpt4all-backend/formatters.h>
-
-#include <QAnyStringView>
-#include <QCoreApplication>
-#include <QDebug>
-#include <QGuiApplication>
-#include <QJsonArray>
-#include <QJsonDocument>
-#include <QJsonObject>
-#include <QJsonValue>
-#include <QLatin1String>
-#include <QNetworkAccessManager>
-#include <QNetworkRequest>
-#include <QStringView>
-#include <QThread>
-#include <QUrl>
-#include <QUtf8StringView> // IWYU pragma: keep
-#include <QVariant>
-#include <QXmlStreamReader>
-#include <Qt>
-#include <QtAssert>
-#include <QtLogging>
-
-#include <expected>
-#include <functional>
-#include <iostream>
-#include <utility>
-
-using namespace Qt::Literals::StringLiterals;
-
-//#define DEBUG
-
-
-ChatAPI::ChatAPI()
-    : QObject(nullptr)
-    , m_modelName("gpt-3.5-turbo")
-    , m_requestURL("")
-    , m_responseCallback(nullptr)
-{
-}
-
-size_t ChatAPI::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
-{
-    Q_UNUSED(modelPath);
-    Q_UNUSED(n_ctx);
-    Q_UNUSED(ngl);
-    return 0;
-}
-
-bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
-{
-    Q_UNUSED(modelPath);
-    Q_UNUSED(n_ctx);
-    Q_UNUSED(ngl);
-    return true;
-}
-
-void ChatAPI::setThreadCount(int32_t n_threads)
-{
-    Q_UNUSED(n_threads);
-}
-
-int32_t ChatAPI::threadCount() const
-{
-    return 1;
-}
-
-ChatAPI::~ChatAPI()
-{
-}
-
-bool ChatAPI::isModelLoaded() const
-{
-    return true;
-}
-
-static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
-{
-    QJsonArray messages;
-
-    auto xmlError = [&xml] {
-        return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
-    };
-
-    if (xml.hasError())
-        return xmlError();
-    if (xml.atEnd())
-        return messages;
-
-    // skip header
-    bool foundElement = false;
-    do {
-        switch (xml.readNext()) {
-            using enum QXmlStreamReader::TokenType;
-        case Invalid:
-            return xmlError();
-        case EndDocument:
-            return messages;
-        default:
-            foundElement = true;
-        case StartDocument:
-        case Comment:
-        case DTD:
-        case ProcessingInstruction:
-            ;
-        }
-    } while (!foundElement);
-
-    // document body loop
-    bool foundRoot = false;
-    for (;;) {
-        switch (xml.tokenType()) {
-            using enum QXmlStreamReader::TokenType;
-        case StartElement:
-            {
-                auto name = xml.name();
-                if (!foundRoot) {
-                    if (name != "chat"_L1)
-                        return std::unexpected(u"unexpected tag: %1"_s.arg(name));
-                    foundRoot = true;
-                } else {
-                    if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
-                        return std::unexpected(u"unknown role: %1"_s.arg(name));
-                    auto content = xml.readElementText();
-                    if (xml.tokenType() != EndElement)
-                        return xmlError();
-                    messages << makeJsonObject({
-                        { "role"_L1,    name.toString().trimmed() },
-                        { "content"_L1, content                   },
-                    });
-                }
-                break;
-            }
-        case Characters:
-            if (!xml.isWhitespace())
-                return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
-        case Comment:
-        case ProcessingInstruction:
-        case EndElement:
-            break;
-        case EndDocument:
-            return messages;
-        case Invalid:
-            return xmlError();
-        default:
-            return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
-        }
-        xml.readNext();
-    }
-}
-
-void ChatAPI::prompt(
-    std::string_view prompt,
-    const PromptCallback &promptCallback,
-    const ResponseCallback &responseCallback,
-    const PromptContext &promptCtx
-) {
-    Q_UNUSED(promptCallback)
-
-    if (!isModelLoaded())
-        throw std::invalid_argument("Attempted to prompt an unloaded model.");
-    if (!promptCtx.n_predict)
-        return; // nothing requested
-
-    // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
-    // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
-    // the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
-    // the tokenization used by the OpenAI model we're calling. OpenAI has not introduced any means of
-    // using the REST API to count tokens in a prompt.
-    auto root = makeJsonObject({
-        { "model"_L1,       m_modelName     },
-        { "stream"_L1,      true            },
-        { "temperature"_L1, promptCtx.temp  },
-        { "top_p"_L1,       promptCtx.top_p },
-    });
-
-    // conversation history
-    {
-        QUtf8StringView promptUtf8(prompt);
-        QXmlStreamReader xml(promptUtf8);
-        auto messages = parsePrompt(xml);
-        if (!messages) {
-            auto error = fmt::format("Failed to parse API model prompt: {}", messages.error());
-            qDebug().noquote() << "ChatAPI ERROR:" << error << "Prompt:\n\n" << promptUtf8 << '\n';
-            throw std::invalid_argument(error);
-        }
-        root.insert("messages"_L1, *messages);
-    }
-
-    QJsonDocument doc(root);
-
-#if defined(DEBUG)
-    qDebug().noquote() << "ChatAPI::prompt begin network request" << doc.toJson();
-#endif
-
-    m_responseCallback = responseCallback;
-
-    // The following code sets up a worker thread and object to perform the actual api request to
-    // chatgpt and then blocks until it is finished
-    QThread workerThread;
-    ChatAPIWorker worker(this);
-    worker.moveToThread(&workerThread);
-    connect(&worker, &ChatAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
-    connect(this, &ChatAPI::request, &worker, &ChatAPIWorker::request, Qt::QueuedConnection);
-    workerThread.start();
-    emit request(m_apiKey, doc.toJson(QJsonDocument::Compact));
-    workerThread.wait();
-
-    m_responseCallback = nullptr;
-
-#if defined(DEBUG)
-    qDebug() << "ChatAPI::prompt end network request";
-#endif
-}
-
-bool ChatAPI::callResponse(int32_t token, const std::string& string)
-{
-    Q_ASSERT(m_responseCallback);
-    if (!m_responseCallback) {
-        std::cerr << "ChatAPI ERROR: no response callback!\n";
-        return false;
-    }
-    return m_responseCallback(token, string);
-}
-
-void ChatAPIWorker::request(const QString &apiKey, const QByteArray &array)
-{
-    QUrl apiUrl(m_chat->url());
-    const QString authorization = u"Bearer %1"_s.arg(apiKey).trimmed();
-    QNetworkRequest request(apiUrl);
-    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
-    request.setRawHeader("Authorization", authorization.toUtf8());
-#if defined(DEBUG)
-    qDebug() << "ChatAPI::request"
-             << "API URL: " << apiUrl.toString()
-             << "Authorization: " << authorization.toUtf8();
-#endif
-    m_networkManager = new QNetworkAccessManager(this);
-    QNetworkReply *reply = m_networkManager->post(request, array);
-    connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
-    connect(reply, &QNetworkReply::finished, this, &ChatAPIWorker::handleFinished);
-    connect(reply, &QNetworkReply::readyRead, this, &ChatAPIWorker::handleReadyRead);
-    connect(reply, &QNetworkReply::errorOccurred, this, &ChatAPIWorker::handleErrorOccurred);
-}
-
-void ChatAPIWorker::handleFinished()
-{
-    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply) {
-        emit finished();
-        return;
-    }
-
-    QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
-
-    if (!response.isValid()) {
-        m_chat->callResponse(
-            -1,
-            tr("ERROR: Network error occurred while connecting to the API server")
-                .toStdString()
-        );
-        return;
-    }
-
-    bool ok;
-    int code = response.toInt(&ok);
-    if (!ok || code != 200) {
-        bool isReplyEmpty(reply->readAll().isEmpty());
-        if (isReplyEmpty)
-            m_chat->callResponse(
-                -1,
-                tr("ChatAPIWorker::handleFinished got HTTP Error %1 %2")
-                    .arg(code)
-                    .arg(reply->errorString())
-                    .toStdString()
-            );
-        qWarning().noquote() << "ERROR: ChatAPIWorker::handleFinished got HTTP Error" << code << "response:"
-                             << reply->errorString();
-    }
-    reply->deleteLater();
-    emit finished();
-}
-
-void ChatAPIWorker::handleReadyRead()
-{
-    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply) {
-        emit finished();
-        return;
-    }
-
-    QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
-
-    if (!response.isValid())
-        return;
-
-    bool ok;
-    int code = response.toInt(&ok);
-    if (!ok || code != 200) {
-        m_chat->callResponse(
-            -1,
-            u"ERROR: ChatAPIWorker::handleReadyRead got HTTP Error %1 %2: %3"_s
-                .arg(code).arg(reply->errorString(), reply->readAll()).toStdString()
-        );
-        emit finished();
-        return;
-    }
-
-    while (reply->canReadLine()) {
-        QString jsonData = reply->readLine().trimmed();
-        if (jsonData.startsWith("data:"))
-            jsonData.remove(0, 5);
-        jsonData = jsonData.trimmed();
-        if (jsonData.isEmpty())
-            continue;
-        if (jsonData == "[DONE]")
-            continue;
-#if defined(DEBUG)
-        qDebug().noquote() << "line" << jsonData;
-#endif
-        QJsonParseError err;
-        const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
-        if (err.error != QJsonParseError::NoError) {
-            m_chat->callResponse(-1, u"ERROR: ChatAPI responded with invalid json \"%1\""_s
-                .arg(err.errorString()).toStdString());
-            continue;
-        }
-
-        const QJsonObject root = document.object();
-        const QJsonArray choices = root.value("choices").toArray();
-        const QJsonObject choice = choices.first().toObject();
-        const QJsonObject delta = choice.value("delta").toObject();
-        const QString content = delta.value("content").toString();
-        m_currentResponse += content;
-        if (!m_chat->callResponse(0, content.toStdString())) {
-            reply->abort();
-            emit finished();
-            return;
-        }
-    }
-}
-
-void ChatAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
-{
-    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply || reply->error() == QNetworkReply::OperationCanceledError /*when we call abort on purpose*/) {
-        emit finished();
-        return;
-    }
-
-    qWarning().noquote() << "ERROR: ChatAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:"
-                         << reply->errorString();
-    emit finished();
-}
gpt4all-chat/src/chatapi.h (deleted, 173 lines)
@@ -1,173 +0,0 @@
-#ifndef CHATAPI_H
-#define CHATAPI_H
-
-#include <gpt4all-backend/llmodel.h>
-
-#include <QByteArray>
-#include <QNetworkReply>
-#include <QObject>
-#include <QString>
-#include <QtPreprocessorSupport>
-
-#include <cstddef>
-#include <cstdint>
-#include <span>
-#include <stdexcept>
-#include <string>
-#include <string_view>
-#include <unordered_map>
-#include <vector>
-
-// IWYU pragma: no_forward_declare QByteArray
-class ChatAPI;
-class QNetworkAccessManager;
-
-
-class ChatAPIWorker : public QObject {
-    Q_OBJECT
-public:
-    ChatAPIWorker(ChatAPI *chatAPI)
-        : QObject(nullptr)
-        , m_networkManager(nullptr)
-        , m_chat(chatAPI) {}
-    virtual ~ChatAPIWorker() {}
-
-    QString currentResponse() const { return m_currentResponse; }
-
-    void request(const QString &apiKey, const QByteArray &array);
-
-Q_SIGNALS:
-    void finished();
-
-private Q_SLOTS:
-    void handleFinished();
-    void handleReadyRead();
-    void handleErrorOccurred(QNetworkReply::NetworkError code);
-
-private:
-    ChatAPI *m_chat;
-    QNetworkAccessManager *m_networkManager;
-    QString m_currentResponse;
-};
-
-class ChatAPI : public QObject, public LLModel {
-    Q_OBJECT
-public:
-    ChatAPI();
-    virtual ~ChatAPI();
-
-    bool supportsEmbedding() const override { return false; }
-    bool supportsCompletion() const override { return true; }
-    bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
-    bool isModelLoaded() const override;
-    size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
-
-    // All three of the state virtual functions are handled custom inside of chatllm save/restore
-    size_t stateSize() const override
-        { throwNotImplemented(); }
-    size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
-        { Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
-    size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
-        { Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
-
-    void prompt(std::string_view prompt,
-                const PromptCallback &promptCallback,
-                const ResponseCallback &responseCallback,
-                const PromptContext &ctx) override;
-
-    [[noreturn]]
-    int32_t countPromptTokens(std::string_view prompt) const override
-        { Q_UNUSED(prompt); throwNotImplemented(); }
-
-    void setThreadCount(int32_t n_threads) override;
-    int32_t threadCount() const override;
-
-    void setModelName(const QString &modelName) { m_modelName = modelName; }
-    void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
-    void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; }
-    QString url() const { return m_requestURL; }
-
-    bool callResponse(int32_t token, const std::string &string);
-
-    [[noreturn]]
-    int32_t contextLength() const override
-        { throwNotImplemented(); }
-
-    auto specialTokens() -> std::unordered_map<std::string, std::string> const override
-        { return {}; }
-
-Q_SIGNALS:
-    void request(const QString &apiKey, const QByteArray &array);
-
-protected:
-    // We have to implement these as they are pure virtual in base class, but we don't actually use
-    // them as they are only called from the default implementation of 'prompt' which we override and
-    // completely replace
-
-    [[noreturn]]
-    static void throwNotImplemented() { throw std::logic_error("not implemented"); }
-
-    [[noreturn]]
-    std::vector<Token> tokenize(std::string_view str) const override
-        { Q_UNUSED(str); throwNotImplemented(); }
-
-    [[noreturn]]
-    bool isSpecialToken(Token id) const override
-        { Q_UNUSED(id); throwNotImplemented(); }
-
-    [[noreturn]]
-    std::string tokenToString(Token id) const override
-        { Q_UNUSED(id); throwNotImplemented(); }
-
-    [[noreturn]]
-    void initSampler(const PromptContext &ctx) override
-        { Q_UNUSED(ctx); throwNotImplemented(); }
-
-    [[noreturn]]
-    Token sampleToken() const override
-        { throwNotImplemented(); }
-
-    [[noreturn]]
-    bool evalTokens(int32_t nPast, std::span<const Token> tokens) const override
-        { Q_UNUSED(nPast); Q_UNUSED(tokens); throwNotImplemented(); }
-
-    [[noreturn]]
-    void shiftContext(const PromptContext &promptCtx, int32_t *nPast) override
-        { Q_UNUSED(promptCtx); Q_UNUSED(nPast); throwNotImplemented(); }
-
-    [[noreturn]]
-    int32_t inputLength() const override
-        { throwNotImplemented(); }
-
-    [[noreturn]]
-    int32_t computeModelInputPosition(std::span<const Token> input) const override
-        { Q_UNUSED(input); throwNotImplemented(); }
-
-    [[noreturn]]
-    void setModelInputPosition(int32_t pos) override
-        { Q_UNUSED(pos); throwNotImplemented(); }
-
-    [[noreturn]]
-    void appendInputToken(Token tok) override
-        { Q_UNUSED(tok); throwNotImplemented(); }
-
-    [[noreturn]]
-    const std::vector<Token> &endTokens() const override
-        { throwNotImplemented(); }
-
-    [[noreturn]]
-    bool shouldAddBOS() const override
-        { throwNotImplemented(); }
-
-    [[noreturn]]
-    std::span<const Token> inputTokens() const override
-        { throwNotImplemented(); }
-
-private:
-    ResponseCallback m_responseCallback;
-    QString m_modelName;
-    QString m_apiKey;
-    QString m_requestURL;
-};
-
-#endif // CHATAPI_H
gpt4all-chat/src/chatllm.cpp
@@ -1,17 +1,19 @@
 #include "chatllm.h"
 
 #include "chat.h"
-#include "chatapi.h"
 #include "chatmodel.h"
 #include "jinja_helpers.h"
+#include "llmodel/chat.h"
+#include "llmodel/openai.h"
 #include "localdocs.h"
 #include "mysettings.h"
 #include "network.h"
 #include "tool.h"
-#include "toolmodel.h"
 #include "toolcallparser.h"
+#include "toolmodel.h"
 
 #include <fmt/format.h>
+#include <gpt4all-backend/generation-params.h>
 #include <minja/minja.hpp>
 #include <nlohmann/json.hpp>
 
@@ -64,6 +66,8 @@
 
 using namespace Qt::Literals::StringLiterals;
 using namespace ToolEnums;
+using namespace gpt4all;
+using namespace gpt4all::ui;
 namespace ranges = std::ranges;
 using json = nlohmann::ordered_json;
 
@@ -115,14 +119,12 @@ public:
 };
 
 static auto promptModelWithTools(
-    LLModel *model, const LLModel::PromptCallback &promptCallback, BaseResponseHandler &respHandler,
-    const LLModel::PromptContext &ctx, const QByteArray &prompt, const QStringList &toolNames
+    ChatLLModel *model, BaseResponseHandler &respHandler, const backend::GenerationParams &params,
+    const QByteArray &prompt, const QStringList &toolNames
 ) -> std::pair<QStringList, bool>
 {
     ToolCallParser toolCallParser(toolNames);
-    auto handleResponse = [&toolCallParser, &respHandler](LLModel::Token token, std::string_view piece) -> bool {
-        Q_UNUSED(token)
-
+    auto handleResponse = [&toolCallParser, &respHandler](std::string_view piece) -> bool {
         toolCallParser.update(piece.data());
 
         // Split the response into two if needed
@@ -157,7 +159,7 @@ static auto promptModelWithTools(
 
         return !shouldExecuteToolCall && !respHandler.getStopGenerating();
     };
-    model->prompt(std::string_view(prompt), promptCallback, handleResponse, ctx);
+    model->prompt(std::string_view(prompt), promptCallback, handleResponse, params);
 
     const bool shouldExecuteToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete
         && toolCallParser.startTag() != ToolCallConstants::ThinkStartTag;
@@ -217,9 +219,8 @@ void LLModelStore::destroy()
     m_availableModel.reset();
 }
 
-void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) {
+void LLModelInfo::resetModel(ChatLLM *cllm, ChatLLModel *model) {
     this->model.reset(model);
-    fallbackReason.reset();
     emit cllm->loadedModelInfoChanged();
 }
 
@@ -232,8 +233,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
-    , m_forceMetal(MySettings::globalInstance()->forceMetal())
-    , m_reloadingToChangeVariant(false)
     , m_chatModel(parent->chatModel())
 {
     moveToThread(&m_llmThread);
@@ -243,8 +242,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
         Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
-    connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
-    connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);
 
     // The following are blocking operations and will block the llm thread
     connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
@@ -284,31 +281,6 @@ void ChatLLM::handleThreadStarted()
     emit threadStarted();
 }
 
-void ChatLLM::handleForceMetalChanged(bool forceMetal)
-{
-#if defined(Q_OS_MAC) && defined(__aarch64__)
-    m_forceMetal = forceMetal;
-    if (isModelLoaded() && m_shouldBeLoaded) {
-        m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
-        m_reloadingToChangeVariant = false;
-    }
-#else
-    Q_UNUSED(forceMetal);
-#endif
-}
-
-void ChatLLM::handleDeviceChanged()
-{
-    if (isModelLoaded() && m_shouldBeLoaded) {
-        m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
-        m_reloadingToChangeVariant = false;
-    }
-}
-
 bool ChatLLM::loadDefaultModel()
 {
     ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
@@ -325,10 +297,9 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
     // and if so we just acquire it from the store and switch the context and return true. If the
     // store doesn't have it or we're already loaded or in any other case just return false.
 
-    // If we're already loaded or a server or we're reloading to change the variant/device or the
-    // modelInfo is empty, then this should fail
+    // If we're already loaded or a server or the modelInfo is empty, then this should fail
     if (
-        isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
+        isModelLoaded() || m_isServer || modelInfo.name().isEmpty() || !m_shouldBeLoaded
     ) {
         emit trySwitchContextOfLoadedModelCompleted(0);
         return;
@@ -409,7 +380,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     }
 
     // Check if the store just gave us exactly the model we were looking for
-    if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
+    if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) {
 #if defined(DEBUG_MODEL_LOADING)
         qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
@@ -482,7 +453,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);
         emit loadedModelInfoChanged();
 
-        modelLoadProps.insert("requestedDevice", MySettings::globalInstance()->device());
         modelLoadProps.insert("model", modelInfo.filename());
         Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
     } else {
@@ -504,43 +474,17 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
     QElapsedTimer modelLoadTimer;
     modelLoadTimer.start();
 
-    QString requestedDevice = MySettings::globalInstance()->device();
     int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
     int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
 
     std::string backend = "auto";
-#ifdef Q_OS_MAC
-    if (requestedDevice == "CPU") {
-        backend = "cpu";
-    } else if (m_forceMetal) {
-#ifdef __aarch64__
-        backend = "metal";
-#endif
-    }
-#else // !defined(Q_OS_MAC)
-    if (requestedDevice.startsWith("CUDA: "))
-        backend = "cuda";
-#endif
 
     QString filePath = modelInfo.dirpath + modelInfo.filename();
 
-    auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx](std::string const &backend) {
+    auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx]() {
         QString constructError;
         m_llModelInfo.resetModel(this);
-        try {
-            auto *model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx);
-            m_llModelInfo.resetModel(this, model);
-        } catch (const LLModel::MissingImplementationError &e) {
-            modelLoadProps.insert("error", "missing_model_impl");
-            constructError = e.what();
-        } catch (const LLModel::UnsupportedModelError &e) {
-            modelLoadProps.insert("error", "unsupported_model_file");
-            constructError = e.what();
-        } catch (const LLModel::BadArchError &e) {
-            constructError = e.what();
-            modelLoadProps.insert("error", "unsupported_model_arch");
-            modelLoadProps.insert("model_arch", QString::fromStdString(e.arch()));
-        }
-
+        auto *model = LLModel::Implementation::construct(filePath.toStdString(), "", n_ctx);
+        m_llModelInfo.resetModel(this, model);
         if (!m_llModelInfo.model) {
             if (!m_isServer)
@@ -558,7 +502,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
         return true;
     };
 
-    if (!construct(backend))
+    if (!construct())
        return true;
 
     if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
@@ -572,58 +516,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
         }
     }
 
-    auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
-        float memGB = dev->heapSize / float(1024 * 1024 * 1024);
-        return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
-    };
-
-    std::vector<LLModel::GPUDevice> availableDevices;
-    const LLModel::GPUDevice *defaultDevice = nullptr;
-    {
-        const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx, ngl);
-        availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
-        // Pick the best device
-        // NB: relies on the fact that Kompute devices are listed first
-        if (!availableDevices.empty() && availableDevices.front().type == 2 /*a discrete gpu*/) {
-            defaultDevice = &availableDevices.front();
-            float memGB = defaultDevice->heapSize / float(1024 * 1024 * 1024);
-            memGB = std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
-            modelLoadProps.insert("default_device", QString::fromStdString(defaultDevice->name));
-            modelLoadProps.insert("default_device_mem", approxDeviceMemGB(defaultDevice));
-            modelLoadProps.insert("default_device_backend", QString::fromStdString(defaultDevice->backendName()));
-        }
-    }
-
-    bool actualDeviceIsCPU = true;
-
-#if defined(Q_OS_MAC) && defined(__aarch64__)
-    if (m_llModelInfo.model->implementation().buildVariant() == "metal")
-        actualDeviceIsCPU = false;
-#else
-    if (requestedDevice != "CPU") {
-        const auto *device = defaultDevice;
-        if (requestedDevice != "Auto") {
-            // Use the selected device
-            for (const LLModel::GPUDevice &d : availableDevices) {
-                if (QString::fromStdString(d.selectionName()) == requestedDevice) {
-                    device = &d;
-                    break;
-                }
-            }
-        }
-
-        std::string unavail_reason;
-        if (!device) {
-            // GPU not available
-        } else if (!m_llModelInfo.model->initializeGPUDevice(device->index, &unavail_reason)) {
-            m_llModelInfo.fallbackReason = QString::fromStdString(unavail_reason);
-        } else {
-            actualDeviceIsCPU = false;
-            modelLoadProps.insert("requested_device_mem", approxDeviceMemGB(device));
-        }
-    }
-#endif
-
     bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
 
     if (!m_shouldBeLoaded) {
@@ -635,35 +527,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
         return false;
     }
 
-    if (actualDeviceIsCPU) {
-        // we asked llama.cpp to use the CPU
-    } else if (!success) {
-        // llama_init_from_file returned nullptr
-        m_llModelInfo.fallbackReason = "GPU loading failed (out of VRAM?)";
-        modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed");
-
-        // For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls
-        if (backend == "cuda" && !construct("auto"))
-            return true;
-
-        success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
-
-        if (!m_shouldBeLoaded) {
-            m_llModelInfo.resetModel(this);
-            if (!m_isServer)
-                LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
-            resetModel();
-            emit modelLoadingPercentageChanged(0.0f);
-            return false;
-        }
-    } else if (!m_llModelInfo.model->usingGPUDevice()) {
-        // ggml_vk_init was not called in llama.cpp
-        // We might have had to fallback to CPU after load if the model is not possible to accelerate
-        // for instance if the quantization method is not supported on Vulkan yet
-        m_llModelInfo.fallbackReason = "model or quant has no GPU support";
-        modelLoadProps.insert("cpu_fallback_reason", "gpu_unsupported_model");
-    }
-
     if (!success) {
         m_llModelInfo.resetModel(this);
         if (!m_isServer)
@@ -756,7 +619,7 @@ void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
     }
 }
 
-static LLModel::PromptContext promptContextFromSettings(const ModelInfo &modelInfo)
+static backend::GenerationParams genParamsFromSettings(const ModelInfo &modelInfo)
 {
     auto *mySettings = MySettings::globalInstance();
     return {
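
[Editor's note, illustrative sketch rather than part of the commit: the new initializer list of genParamsFromSettings() is cut off by this hunk. Given the fields GenerationParams declares, the mapping is presumably along these lines; the MySettings getter names here are assumptions, not confirmed by the diff:]

    static backend::GenerationParams genParamsFromSettings(const ModelInfo &modelInfo)
    {
        auto *mySettings = MySettings::globalInstance();
        return {
            .n_predict   = uint(mySettings->modelMaxLength(modelInfo)),    // assumed getter
            .temperature = float(mySettings->modelTemperature(modelInfo)), // assumed getter
            .top_p       = float(mySettings->modelTopP(modelInfo)),        // assumed getter
        };
    }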
@@ -779,7 +642,7 @@ void ChatLLM::prompt(const QStringList &enabledCollections)
     }
 
     try {
-        promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo));
+        promptInternalChat(enabledCollections, genParamsFromSettings(m_modelInfo));
     } catch (const std::exception &e) {
         // FIXME(jared): this is neither translated nor serialized
         m_chatModel->setResponseValue(u"Error: %1"_s.arg(QString::fromUtf8(e.what())));
@@ -906,7 +769,7 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const MessageItem> items) const
     Q_UNREACHABLE();
 }
 
-auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
+auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const backend::GenerationParams &params,
                                  qsizetype startOffset) -> ChatPromptResult
 {
     Q_ASSERT(isModelLoaded());
@@ -944,7 +807,7 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
     auto messageItems = getChat();
     messageItems.pop_back(); // exclude new response
 
-    auto result = promptInternal(messageItems, ctx, !databaseResults.isEmpty());
+    auto result = promptInternal(messageItems, params, !databaseResults.isEmpty());
     return {
         /*PromptResult*/ {
             .response = std::move(result.response),
@@ -1014,7 +877,7 @@ private:
 
 auto ChatLLM::promptInternal(
     const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
-    const LLModel::PromptContext &ctx,
+    const backend::GenerationParams params,
     bool usedLocalDocs
 ) -> PromptResult
 {
@@ -1052,13 +915,6 @@ auto ChatLLM::promptInternal(
 
     PromptResult result {};
 
-    auto handlePrompt = [this, &result](std::span<const LLModel::Token> batch, bool cached) -> bool {
-        Q_UNUSED(cached)
-        result.promptTokens += batch.size();
-        m_timer->start();
-        return !m_stopGenerating;
-    };
-
     QElapsedTimer totalTime;
     totalTime.start();
     ChatViewResponseHandler respHandler(this, &totalTime, &result);
@@ -1070,8 +926,10 @@ auto ChatLLM::promptInternal(
     emit promptProcessing();
     m_llModelInfo.model->setThreadCount(mySettings->threadCount());
     m_stopGenerating = false;
+    // TODO: set result.promptTokens based on ollama prompt_eval_count
+    // TODO: support interruption via m_stopGenerating
     std::tie(finalBuffers, shouldExecuteTool) = promptModelWithTools(
-        m_llModelInfo.model.get(), handlePrompt, respHandler, ctx,
+        m_llModelInfo.model.get(), handlePrompt, respHandler, params,
         QByteArray::fromRawData(conversation.data(), conversation.size()),
         ToolCallConstants::AllTagNames
     );
@@ -1251,10 +1109,10 @@ void ChatLLM::generateName()
     NameResponseHandler respHandler(this);
 
     try {
+        // TODO: support interruption via m_stopGenerating
         promptModelWithTools(
             m_llModelInfo.model.get(),
-            /*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
-            respHandler, promptContextFromSettings(m_modelInfo),
+            respHandler, genParamsFromSettings(m_modelInfo),
             applyJinjaTemplate(forkConversation(chatNamePrompt)).c_str(),
             { ToolCallConstants::ThinkTagName }
         );
@@ -1327,10 +1185,10 @@ void ChatLLM::generateQuestions(qint64 elapsed)
     QElapsedTimer totalTime;
     totalTime.start();
     try {
+        // TODO: support interruption via m_stopGenerating
         promptModelWithTools(
             m_llModelInfo.model.get(),
-            /*promptCallback*/ [this](auto &&...) { return !m_stopGenerating; },
-            respHandler, promptContextFromSettings(m_modelInfo),
+            respHandler, genParamsFromSettings(m_modelInfo),
             applyJinjaTemplate(forkConversation(suggestedFollowUpPrompt)).c_str(),
             { ToolCallConstants::ThinkTagName }
         );
@ -3,10 +3,9 @@
|
|||||||
|
|
||||||
#include "chatmodel.h"
|
#include "chatmodel.h"
|
||||||
#include "database.h"
|
#include "database.h"
|
||||||
|
#include "llmodel/chat.h"
|
||||||
#include "modellist.h"
|
#include "modellist.h"
|
||||||
|
|
||||||
#include <gpt4all-backend/llmodel.h>
|
|
||||||
|
|
||||||
#include <QByteArray>
|
#include <QByteArray>
|
||||||
#include <QElapsedTimer>
|
#include <QElapsedTimer>
|
||||||
#include <QFileInfo>
|
#include <QFileInfo>
|
||||||
@ -91,14 +90,9 @@ inline LLModelTypeV1 parseLLModelTypeV0(int v0)
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct LLModelInfo {
|
struct LLModelInfo {
|
||||||
std::unique_ptr<LLModel> model;
|
std::unique_ptr<gpt4all::ui::ChatLLModel> model;
|
||||||
QFileInfo fileInfo;
|
QFileInfo fileInfo;
|
||||||
std::optional<QString> fallbackReason;
|
void resetModel(ChatLLM *cllm, gpt4all::ui::ChatLLModel *model = nullptr);
|
||||||
|
|
||||||
// NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
|
|
||||||
// must be able to serialize the information even if it is in the unloaded state
|
|
||||||
|
|
||||||
void resetModel(ChatLLM *cllm, LLModel *model = nullptr);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TokenTimer : public QObject {
|
class TokenTimer : public QObject {
|
||||||
@ -145,9 +139,6 @@ class Chat;
|
|||||||
class ChatLLM : public QObject
|
class ChatLLM : public QObject
|
||||||
{
|
{
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
|
|
||||||
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
|
|
||||||
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
|
|
||||||
public:
|
public:
|
||||||
ChatLLM(Chat *parent, bool isServer = false);
|
ChatLLM(Chat *parent, bool isServer = false);
|
||||||
virtual ~ChatLLM();
|
virtual ~ChatLLM();
|
||||||
@@ -175,27 +166,6 @@ public:
     void acquireModel();
     void resetModel();
 
-    QString deviceBackend() const
-    {
-        if (!isModelLoaded()) return QString();
-        std::string name = LLModel::GPUDevice::backendIdToName(m_llModelInfo.model->backendName());
-        return QString::fromStdString(name);
-    }
-
-    QString device() const
-    {
-        if (!isModelLoaded()) return QString();
-        const char *name = m_llModelInfo.model->gpuDeviceName();
-        return name ? QString(name) : u"CPU"_s;
-    }
-
-    // not loaded -> QString(), no fallback -> QString("")
-    QString fallbackReason() const
-    {
-        if (!isModelLoaded()) return QString();
-        return m_llModelInfo.fallbackReason.value_or(u""_s);
-    }
-
     bool serialize(QDataStream &stream, int version);
     bool deserialize(QDataStream &stream, int version);
 
@@ -211,8 +181,6 @@ public Q_SLOTS:
     void handleChatIdChanged(const QString &id);
     void handleShouldBeLoadedChanged();
    void handleThreadStarted();
-    void handleForceMetalChanged(bool forceMetal);
-    void handleDeviceChanged();
 
 Q_SIGNALS:
     void loadedModelInfoChanged();
@@ -233,8 +201,6 @@ Q_SIGNALS:
     void trySwitchContextOfLoadedModelCompleted(int value);
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
-    void reportDevice(const QString &device);
-    void reportFallbackReason(const QString &fallbackReason);
     void databaseResultsChanged(const QList<ResultInfo>&);
     void modelInfoChanged(const ModelInfo &modelInfo);
 
@@ -249,12 +215,11 @@ protected:
         QList<ResultInfo> databaseResults;
     };
 
-    ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
-                                        qsizetype startOffset = 0);
+    auto promptInternalChat(const QStringList &enabledCollections, const gpt4all::backend::GenerationParams &params,
+                            qsizetype startOffset = 0) -> ChatPromptResult;
     // passing a string_view directly skips templating and uses the raw string
-    PromptResult promptInternal(const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
-                                const LLModel::PromptContext &ctx,
-                                bool usedLocalDocs);
+    auto promptInternal(const std::variant<std::span<const MessageItem>, std::string_view> &prompt,
+                        const gpt4all::backend::GenerationParams &params, bool usedLocalDocs) -> PromptResult;
 
 private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
@@ -282,8 +247,6 @@ private:
     std::atomic<bool> m_forceUnloadModel;
     std::atomic<bool> m_markedForDeletion;
     bool m_isServer;
-    bool m_forceMetal;
-    bool m_reloadingToChangeVariant;
     friend class ChatViewResponseHandler;
     friend class SimpleResponseHandler;
 };
@@ -88,73 +88,14 @@ bool EmbeddingLLMWorker::loadModel()
         return false;
     }
 
-    QString requestedDevice = MySettings::globalInstance()->localDocsEmbedDevice();
-    std::string backend = "auto";
-#ifdef Q_OS_MAC
-    if (requestedDevice == "Auto" || requestedDevice == "CPU")
-        backend = "cpu";
-#else
-    if (requestedDevice.startsWith("CUDA: "))
-        backend = "cuda";
-#endif
-
     try {
-        m_model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx);
+        m_model = LLModel::Implementation::construct(filePath.toStdString(), "", n_ctx);
     } catch (const std::exception &e) {
         qWarning() << "embllm WARNING: Could not load embedding model:" << e.what();
         return false;
     }
 
-    bool actualDeviceIsCPU = true;
-
-#if defined(Q_OS_MAC) && defined(__aarch64__)
-    if (m_model->implementation().buildVariant() == "metal")
-        actualDeviceIsCPU = false;
-#else
-    if (requestedDevice != "CPU") {
-        const LLModel::GPUDevice *device = nullptr;
-        std::vector<LLModel::GPUDevice> availableDevices = m_model->availableGPUDevices(0);
-        if (requestedDevice != "Auto") {
-            // Use the selected device
-            for (const LLModel::GPUDevice &d : availableDevices) {
-                if (QString::fromStdString(d.selectionName()) == requestedDevice) {
-                    device = &d;
-                    break;
-                }
-            }
-        }
-
-        std::string unavail_reason;
-        if (!device) {
-            // GPU not available
-        } else if (!m_model->initializeGPUDevice(device->index, &unavail_reason)) {
-            qWarning().noquote() << "embllm WARNING: Did not use GPU:" << QString::fromStdString(unavail_reason);
-        } else {
-            actualDeviceIsCPU = false;
-        }
-    }
-#endif
-
     bool success = m_model->loadModel(filePath.toStdString(), n_ctx, 100);
 
-    // CPU fallback
-    if (!actualDeviceIsCPU && !success) {
-        // llama_init_from_file returned nullptr
-        qWarning() << "embllm WARNING: Did not use GPU: GPU loading failed (out of VRAM?)";
-
-        if (backend == "cuda") {
-            // For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls
-            try {
-                m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto", n_ctx);
-            } catch (const std::exception &e) {
-                qWarning() << "embllm WARNING: Could not load embedding model:" << e.what();
-                return false;
-            }
-        }
-
-        success = m_model->loadModel(filePath.toStdString(), n_ctx, 0);
-    }
-
     if (!success) {
         qWarning() << "embllm WARNING: Could not load embedding model";
         delete m_model;
gpt4all-chat/src/llmodel/chat.h (new file, 32 lines)
@@ -0,0 +1,32 @@
+#pragma once
+
+#include <QStringView>
+
+class QString;
+namespace QCoro { template <typename T> class AsyncGenerator; }
+namespace gpt4all::backend { struct GenerationParams; }
+
+
+namespace gpt4all::ui {
+
+
+struct ChatResponseMetadata {
+    int nPromptTokens;
+    int nResponseTokens;
+};
+
+// TODO: implement two of these; one based on Ollama (TBD) and the other based on OpenAI (chatapi.h)
+class ChatLLModel {
+public:
+    virtual ~ChatLLModel() = 0;
+
+    [[nodiscard]]
+    virtual QString name() = 0;
+
+    virtual void preload() = 0;
+    virtual auto chat(QStringView prompt, const backend::GenerationParams &params,
+                      /*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator<QString> = 0;
+};
+
+
+} // namespace gpt4all::ui
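Note (illustration, not part of the diff): chat() returns a QCoro::AsyncGenerator<QString>, so callers consume the stream from inside a coroutine. A minimal sketch of a hypothetical caller:

#include <QCoro/QCoroAsyncGenerator>
#include <QCoro/QCoroTask>
#include <QDebug>

// Works with any ChatLLModel implementation (e.g. the OpenaiLLModel below).
QCoro::Task<> runChat(gpt4all::ui::ChatLLModel &model,
                      const gpt4all::backend::GenerationParams &params)
{
    gpt4all::ui::ChatResponseMetadata meta {};
    auto gen = model.chat(u"<chat><user>Hello!</user></chat>", params, meta);
    QCORO_FOREACH(const QString &chunk, gen) { // co_awaits each streamed chunk
        qDebug().noquote() << chunk;
    }
}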
gpt4all-chat/src/llmodel/openai.cpp (new file, 217 lines)
@@ -0,0 +1,217 @@
+#include "openai.h"
+
+#include "mysettings.h"
+#include "utils.h"
+
+#include <QCoro/QCoroAsyncGenerator> // IWYU pragma: keep
+#include <QCoro/QCoroNetworkReply> // IWYU pragma: keep
+#include <fmt/format.h>
+#include <gpt4all-backend/formatters.h>
+#include <gpt4all-backend/generation-params.h>
+#include <gpt4all-backend/rest.h>
+
+#include <QByteArray>
+#include <QJsonArray>
+#include <QJsonDocument>
+#include <QJsonObject>
+#include <QJsonValue>
+#include <QLatin1String>
+#include <QNetworkAccessManager>
+#include <QNetworkRequest>
+#include <QRestAccessManager>
+#include <QRestReply>
+#include <QStringView>
+#include <QUrl>
+#include <QUtf8StringView> // IWYU pragma: keep
+#include <QVariant>
+#include <QXmlStreamReader>
+#include <Qt>
+
+#include <expected>
+#include <optional>
+#include <utility>
+
+using namespace Qt::Literals::StringLiterals;
+
+//#define DEBUG
+
+
+static auto processRespLine(const QByteArray &line) -> std::optional<QString>
+{
+    auto jsonData = line.trimmed();
+    if (jsonData.startsWith("data:"_ba))
+        jsonData.remove(0, 5);
+    jsonData = jsonData.trimmed();
+    if (jsonData.isEmpty())
+        return std::nullopt;
+    if (jsonData == "[DONE]")
+        return std::nullopt;
+
+    QJsonParseError err;
+    auto document = QJsonDocument::fromJson(jsonData, &err);
+    if (document.isNull())
+        throw std::runtime_error(fmt::format("OpenAI chat response parsing failed: {}", err.errorString()));
+
+    auto root = document.object();
+    auto choices = root.value("choices").toArray();
+    auto choice = choices.first().toObject();
+    auto delta = choice.value("delta").toObject();
+    return delta.value("content").toString();
+}
+
+
+namespace gpt4all::ui {
+
+
+void OpenaiModelDescription::setDisplayName(QString value)
+{
+    if (m_displayName != value) {
+        m_displayName = std::move(value);
+        emit displayNameChanged(m_displayName);
+    }
+}
+
+void OpenaiModelDescription::setModelName(QString value)
+{
+    if (m_modelName != value) {
+        m_modelName = std::move(value);
+        emit modelNameChanged(m_modelName);
+    }
+}
+
+OpenaiLLModel::OpenaiLLModel(OpenaiConnectionDetails connDetails, QNetworkAccessManager *nam)
+    : m_connDetails(std::move(connDetails))
+    , m_nam(nam)
+{}
+
+static auto parsePrompt(QXmlStreamReader &xml) -> std::expected<QJsonArray, QString>
+{
+    QJsonArray messages;
+
+    auto xmlError = [&xml] {
+        return std::unexpected(u"%1:%2: %3"_s.arg(xml.lineNumber()).arg(xml.columnNumber()).arg(xml.errorString()));
+    };
+
+    if (xml.hasError())
+        return xmlError();
+    if (xml.atEnd())
+        return messages;
+
+    // skip header
+    bool foundElement = false;
+    do {
+        switch (xml.readNext()) {
+            using enum QXmlStreamReader::TokenType;
+        case Invalid:
+            return xmlError();
+        case EndDocument:
+            return messages;
+        default:
+            foundElement = true;
+        case StartDocument:
+        case Comment:
+        case DTD:
+        case ProcessingInstruction:
+            ;
+        }
+    } while (!foundElement);
+
+    // document body loop
+    bool foundRoot = false;
+    for (;;) {
+        switch (xml.tokenType()) {
+            using enum QXmlStreamReader::TokenType;
+        case StartElement:
+        {
+            auto name = xml.name();
+            if (!foundRoot) {
+                if (name != "chat"_L1)
+                    return std::unexpected(u"unexpected tag: %1"_s.arg(name));
+                foundRoot = true;
+            } else {
+                if (name != "user"_L1 && name != "assistant"_L1 && name != "system"_L1)
+                    return std::unexpected(u"unknown role: %1"_s.arg(name));
+                auto content = xml.readElementText();
+                if (xml.tokenType() != EndElement)
+                    return xmlError();
+                messages << makeJsonObject({
+                    { "role"_L1,    name.toString().trimmed() },
+                    { "content"_L1, content                   },
+                });
+            }
+            break;
+        }
+        case Characters:
+            if (!xml.isWhitespace())
+                return std::unexpected(u"unexpected text: %1"_s.arg(xml.text()));
+        case Comment:
+        case ProcessingInstruction:
+        case EndElement:
+            break;
+        case EndDocument:
+            return messages;
+        case Invalid:
+            return xmlError();
+        default:
+            return std::unexpected(u"unexpected token: %1"_s.arg(xml.tokenString()));
+        }
+        xml.readNext();
+    }
+}
+
+auto OpenaiLLModel::chat(QStringView prompt, const backend::GenerationParams &params,
+                         /*out*/ ChatResponseMetadata &metadata) -> QCoro::AsyncGenerator<QString>
+{
+    auto *mySettings = MySettings::globalInstance();
+
+    if (!params.n_predict)
+        co_return; // nothing requested
+
+    auto reqBody = makeJsonObject({
+        { "model"_L1,                 m_connDetails.modelName  },
+        { "max_completion_tokens"_L1, qint64(params.n_predict) },
+        { "stream"_L1,                true                     },
+        { "temperature"_L1,           params.temperature       },
+        { "top_p"_L1,                 params.top_p             },
+    });
+
+    // conversation history
+    {
+        QXmlStreamReader xml(prompt);
+        auto messages = parsePrompt(xml);
+        if (!messages)
+            throw std::invalid_argument(fmt::format("Failed to parse OpenAI prompt: {}", messages.error()));
+        reqBody.insert("messages"_L1, *messages);
+    }
+
+    QNetworkRequest request(m_connDetails.baseUrl.resolved(QUrl("/v1/chat/completions")));
+    request.setHeader(QNetworkRequest::UserAgentHeader, mySettings->userAgent());
+    request.setRawHeader("authorization", u"Bearer %1"_s.arg(m_connDetails.apiKey).toUtf8());
+
+    QRestAccessManager restNam(m_nam);
+    std::unique_ptr<QNetworkReply> reply(restNam.post(request, QJsonDocument(reqBody)));
+
+    auto makeError = [](const QRestReply &reply) {
+        return std::runtime_error(fmt::format("OpenAI chat request failed: {}", backend::restErrorString(reply)));
+    };
+
+    QRestReply restReply(reply.get());
+    if (reply->error())
+        throw makeError(restReply);
+
+    auto coroReply = qCoro(reply.get());
+    for (;;) {
+        auto line = co_await coroReply.readLine();
+        if (!restReply.isSuccess())
+            throw makeError(restReply);
+        if (line.isEmpty()) {
+            Q_ASSERT(reply->atEnd());
+            break;
+        }
+        if (auto chunk = processRespLine(line))
+            co_yield *chunk;
+    }
+}
+
+
+} // namespace gpt4all::ui
gpt4all-chat/src/llmodel/openai.h (new file, 75 lines)
@@ -0,0 +1,75 @@
+#pragma once
+
+#include "chat.h"
+#include "provider.h"
+
+#include <QObject>
+#include <QQmlEngine>
+#include <QString>
+#include <QUrl>
+
+class QNetworkAccessManager;
+
+
+namespace gpt4all::ui {
+
+
+class OpenaiModelDescription : public QObject {
+    Q_OBJECT
+    QML_ELEMENT
+
+public:
+    explicit OpenaiModelDescription(OpenaiProvider *provider, QString displayName, QString modelName)
+        : QObject(provider)
+        , m_provider(provider)
+        , m_displayName(std::move(displayName))
+        , m_modelName(std::move(modelName))
+    {}
+
+    // getters
+    [[nodiscard]] OpenaiProvider *provider   () const { return m_provider;    }
+    [[nodiscard]] const QString  &displayName() const { return m_displayName; }
+    [[nodiscard]] const QString  &modelName  () const { return m_modelName;   }
+
+    // setters
+    void setDisplayName(QString value);
+    void setModelName  (QString value);
+
+Q_SIGNALS:
+    void displayNameChanged(const QString &value);
+    void modelNameChanged  (const QString &value);
+
+private:
+    OpenaiProvider *m_provider;
+    QString         m_displayName;
+    QString         m_modelName;
+};
+
+struct OpenaiConnectionDetails {
+    QUrl    baseUrl;
+    QString modelName;
+    QString apiKey;
+
+    OpenaiConnectionDetails(const OpenaiModelDescription *desc)
+        : baseUrl(desc->provider()->baseUrl())
+        , apiKey(desc->provider()->apiKey())
+        , modelName(desc->modelName())
+    {}
+};
+
+class OpenaiLLModel : public ChatLLModel {
+public:
+    explicit OpenaiLLModel(OpenaiConnectionDetails connDetails, QNetworkAccessManager *nam);
+
+    void preload() override { /* not supported -> no-op */ }
+
+    auto chat(QStringView prompt, const backend::GenerationParams &params, /*out*/ ChatResponseMetadata &metadata)
+        -> QCoro::AsyncGenerator<QString> override;
+
+private:
+    OpenaiConnectionDetails  m_connDetails;
+    QNetworkAccessManager   *m_nam;
+};
+
+
+} // namespace gpt4all::ui
gpt4all-chat/src/llmodel/provider.cpp (new file, 26 lines)
@@ -0,0 +1,26 @@
+#include "provider.h"
+
+#include <utility>
+
+
+namespace gpt4all::ui {
+
+
+void OpenaiProvider::setBaseUrl(QUrl value)
+{
+    if (m_baseUrl != value) {
+        m_baseUrl = std::move(value);
+        emit baseUrlChanged(m_baseUrl);
+    }
+}
+
+void OpenaiProvider::setApiKey(QString value)
+{
+    if (m_apiKey != value) {
+        m_apiKey = std::move(value);
+        emit apiKeyChanged(m_apiKey);
+    }
+}
+
+
+} // namespace gpt4all::ui
gpt4all-chat/src/llmodel/provider.h (new file, 47 lines)
@@ -0,0 +1,47 @@
+#pragma once
+
+#include <QObject>
+#include <QQmlEngine>
+#include <QString>
+#include <QUrl>
+
+
+namespace gpt4all::ui {
+
+
+class ModelProvider : public QObject {
+    Q_OBJECT
+
+    Q_PROPERTY(QString name READ name CONSTANT)
+
+public:
+    [[nodiscard]] virtual QString name() = 0;
+};
+
+class OpenaiProvider : public ModelProvider {
+    Q_OBJECT
+    QML_ELEMENT
+
+    Q_PROPERTY(QUrl baseUrl READ baseUrl WRITE setBaseUrl NOTIFY baseUrlChanged)
+    Q_PROPERTY(QString apiKey READ apiKey WRITE setApiKey NOTIFY apiKeyChanged)
+
+public:
+    [[nodiscard]] QString name() override { return m_name; }
+    [[nodiscard]] const QUrl    &baseUrl() { return m_baseUrl; }
+    [[nodiscard]] const QString &apiKey () { return m_apiKey;  }
+
+    void setBaseUrl(QUrl    value);
+    void setApiKey (QString value);
+
+Q_SIGNALS:
+    void baseUrlChanged(const QUrl    &value);
+    void apiKeyChanged (const QString &value);
+
+private:
+    QString m_name;
+    QUrl    m_baseUrl;
+    QString m_apiKey;
+};
+
+
+} // namespace gpt4all::ui
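Note (illustration, not part of the diff): how the new classes compose -- a provider carries the endpoint and API key, a model description names one model on that provider, and OpenaiConnectionDetails snapshots both for an OpenaiLLModel session. Endpoint, key, and model names below are placeholders:

#include "llmodel/openai.h"

#include <QNetworkAccessManager>
#include <QUrl>

using namespace gpt4all::ui;
using namespace Qt::Literals::StringLiterals;

static OpenaiLLModel *makeOpenaiModel(QNetworkAccessManager *nam)
{
    auto *provider = new OpenaiProvider;
    provider->setBaseUrl(QUrl(u"https://api.openai.com"_s)); // placeholder endpoint
    provider->setApiKey(u"sk-placeholder"_s);                // placeholder key

    // Parented to the provider via the QObject ctor (see openai.h).
    auto *desc = new OpenaiModelDescription(provider, u"GPT-4o mini"_s, u"gpt-4o-mini"_s);

    // Connection details are copied out of provider + description at call time.
    return new OpenaiLLModel(OpenaiConnectionDetails(desc), nam);
}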
@@ -1,6 +1,7 @@
 #include "mysettings.h"
 
 #include "chatllm.h"
+#include "config.h"
 #include "modellist.h"
 
 #include <gpt4all-backend/llmodel.h>
@@ -48,7 +49,6 @@ namespace ModelSettingsKey { namespace {
 namespace defaults {
 
 static const int threadCount = std::min(4, (int32_t) std::thread::hardware_concurrency());
-static const bool forceMetal = false;
 static const bool networkIsActive = false;
 static const bool networkUsageStatsActive = false;
 static const QString device = "Auto";
@@ -71,7 +71,6 @@ static const QVariantMap basicDefaults {
     { "localdocs/fileExtensions", QStringList { "docx", "pdf", "txt", "md", "rst" } },
     { "localdocs/useRemoteEmbed", false },
     { "localdocs/nomicAPIKey", "" },
-    { "localdocs/embedDevice", "Auto" },
     { "network/attribution", "" },
 };
 
@@ -174,11 +173,16 @@ MySettings *MySettings::globalInstance()
 MySettings::MySettings()
     : QObject(nullptr)
     , m_deviceList(getDevices())
-    , m_embeddingsDeviceList(getDevices(/*skipKompute*/ true))
     , m_uiLanguages(getUiLanguages(modelPath()))
 {
 }
 
+const QString &MySettings::userAgent()
+{
+    static const QString s_userAgent = QStringLiteral("gpt4all/" APP_VERSION);
+    return s_userAgent;
+}
+
 QVariant MySettings::checkJinjaTemplateError(const QString &tmpl)
 {
     if (auto err = ChatLLM::checkJinjaTemplateError(tmpl.toStdString()))
@@ -256,7 +260,6 @@ void MySettings::restoreApplicationDefaults()
     setNetworkPort(basicDefaults.value("networkPort").toInt());
     setModelPath(defaultLocalModelsPath());
     setUserDefaultModel(basicDefaults.value("userDefaultModel").toString());
-    setForceMetal(defaults::forceMetal);
     setSuggestionMode(basicDefaults.value("suggestionMode").value<SuggestionMode>());
     setLanguageAndLocale(defaults::languageAndLocale);
 }
@@ -269,7 +272,6 @@ void MySettings::restoreLocalDocsDefaults()
     setLocalDocsFileExtensions(basicDefaults.value("localdocs/fileExtensions").toStringList());
     setLocalDocsUseRemoteEmbed(basicDefaults.value("localdocs/useRemoteEmbed").toBool());
     setLocalDocsNomicAPIKey(basicDefaults.value("localdocs/nomicAPIKey").toString());
-    setLocalDocsEmbedDevice(basicDefaults.value("localdocs/embedDevice").toString());
 }
 
 void MySettings::eraseModel(const ModelInfo &info)
@@ -628,7 +630,6 @@ bool MySettings::localDocsShowReferences() const { return getBasicSetting
 QStringList MySettings::localDocsFileExtensions() const { return getBasicSetting("localdocs/fileExtensions").toStringList(); }
 bool MySettings::localDocsUseRemoteEmbed() const { return getBasicSetting("localdocs/useRemoteEmbed").toBool(); }
 QString MySettings::localDocsNomicAPIKey() const { return getBasicSetting("localdocs/nomicAPIKey" ).toString(); }
-QString MySettings::localDocsEmbedDevice() const { return getBasicSetting("localdocs/embedDevice" ).toString(); }
 QString MySettings::networkAttribution() const { return getBasicSetting("network/attribution" ).toString(); }
 
 ChatTheme MySettings::chatTheme() const { return ChatTheme (getEnumSetting("chatTheme", chatThemeNames)); }
@@ -646,7 +647,6 @@ void MySettings::setLocalDocsShowReferences(bool value) { setBasic
 void MySettings::setLocalDocsFileExtensions(const QStringList &value) { setBasicSetting("localdocs/fileExtensions", value, "localDocsFileExtensions"); }
 void MySettings::setLocalDocsUseRemoteEmbed(bool value) { setBasicSetting("localdocs/useRemoteEmbed", value, "localDocsUseRemoteEmbed"); }
 void MySettings::setLocalDocsNomicAPIKey(const QString &value) { setBasicSetting("localdocs/nomicAPIKey", value, "localDocsNomicAPIKey"); }
-void MySettings::setLocalDocsEmbedDevice(const QString &value) { setBasicSetting("localdocs/embedDevice", value, "localDocsEmbedDevice"); }
 void MySettings::setNetworkAttribution(const QString &value) { setBasicSetting("network/attribution", value, "networkAttribution"); }
 
 void MySettings::setChatTheme(ChatTheme value) { setBasicSetting("chatTheme", chatThemeNames .value(int(value))); }
@@ -706,19 +706,6 @@ void MySettings::setDevice(const QString &value)
     }
 }
 
-bool MySettings::forceMetal() const
-{
-    return m_forceMetal;
-}
-
-void MySettings::setForceMetal(bool value)
-{
-    if (m_forceMetal != value) {
-        m_forceMetal = value;
-        emit forceMetalChanged(value);
-    }
-}
-
 bool MySettings::networkIsActive() const
 {
     return m_settings.value("network/isActive", defaults::networkIsActive).toBool();
@@ -62,7 +62,6 @@ class MySettings : public QObject
     Q_PROPERTY(ChatTheme chatTheme READ chatTheme WRITE setChatTheme NOTIFY chatThemeChanged)
     Q_PROPERTY(FontSize fontSize READ fontSize WRITE setFontSize NOTIFY fontSizeChanged)
     Q_PROPERTY(QString languageAndLocale READ languageAndLocale WRITE setLanguageAndLocale NOTIFY languageAndLocaleChanged)
-    Q_PROPERTY(bool forceMetal READ forceMetal WRITE setForceMetal NOTIFY forceMetalChanged)
     Q_PROPERTY(QString lastVersionStarted READ lastVersionStarted WRITE setLastVersionStarted NOTIFY lastVersionStartedChanged)
     Q_PROPERTY(int localDocsChunkSize READ localDocsChunkSize WRITE setLocalDocsChunkSize NOTIFY localDocsChunkSizeChanged)
     Q_PROPERTY(int localDocsRetrievalSize READ localDocsRetrievalSize WRITE setLocalDocsRetrievalSize NOTIFY localDocsRetrievalSizeChanged)
@@ -70,13 +69,11 @@ class MySettings : public QObject
     Q_PROPERTY(QStringList localDocsFileExtensions READ localDocsFileExtensions WRITE setLocalDocsFileExtensions NOTIFY localDocsFileExtensionsChanged)
     Q_PROPERTY(bool localDocsUseRemoteEmbed READ localDocsUseRemoteEmbed WRITE setLocalDocsUseRemoteEmbed NOTIFY localDocsUseRemoteEmbedChanged)
     Q_PROPERTY(QString localDocsNomicAPIKey READ localDocsNomicAPIKey WRITE setLocalDocsNomicAPIKey NOTIFY localDocsNomicAPIKeyChanged)
-    Q_PROPERTY(QString localDocsEmbedDevice READ localDocsEmbedDevice WRITE setLocalDocsEmbedDevice NOTIFY localDocsEmbedDeviceChanged)
     Q_PROPERTY(QString networkAttribution READ networkAttribution WRITE setNetworkAttribution NOTIFY networkAttributionChanged)
     Q_PROPERTY(bool networkIsActive READ networkIsActive WRITE setNetworkIsActive NOTIFY networkIsActiveChanged)
     Q_PROPERTY(bool networkUsageStatsActive READ networkUsageStatsActive WRITE setNetworkUsageStatsActive NOTIFY networkUsageStatsActiveChanged)
     Q_PROPERTY(QString device READ device WRITE setDevice NOTIFY deviceChanged)
     Q_PROPERTY(QStringList deviceList MEMBER m_deviceList CONSTANT)
-    Q_PROPERTY(QStringList embeddingsDeviceList MEMBER m_embeddingsDeviceList CONSTANT)
     Q_PROPERTY(int networkPort READ networkPort WRITE setNetworkPort NOTIFY networkPortChanged)
     Q_PROPERTY(SuggestionMode suggestionMode READ suggestionMode WRITE setSuggestionMode NOTIFY suggestionModeChanged)
     Q_PROPERTY(QStringList uiLanguages MEMBER m_uiLanguages CONSTANT)
@@ -91,6 +88,8 @@ public Q_SLOTS:
 public:
     static MySettings *globalInstance();
 
+    static const QString &userAgent();
+
     Q_INVOKABLE static QVariant checkJinjaTemplateError(const QString &tmpl);
 
     // Restore methods
@@ -172,8 +171,6 @@ public:
     void setChatTheme(ChatTheme value);
     FontSize fontSize() const;
     void setFontSize(FontSize value);
-    bool forceMetal() const;
-    void setForceMetal(bool value);
     QString device();
     void setDevice(const QString &value);
     int32_t contextLength() const;
@@ -203,8 +200,6 @@ public:
     void setLocalDocsUseRemoteEmbed(bool value);
     QString localDocsNomicAPIKey() const;
     void setLocalDocsNomicAPIKey(const QString &value);
-    QString localDocsEmbedDevice() const;
-    void setLocalDocsEmbedDevice(const QString &value);
 
     // Network settings
     QString networkAttribution() const;
@@ -243,7 +238,6 @@ Q_SIGNALS:
     void userDefaultModelChanged();
     void chatThemeChanged();
     void fontSizeChanged();
-    void forceMetalChanged(bool);
     void lastVersionStartedChanged();
     void localDocsChunkSizeChanged();
     void localDocsRetrievalSizeChanged();
@@ -251,7 +245,6 @@ Q_SIGNALS:
     void localDocsFileExtensionsChanged();
     void localDocsUseRemoteEmbedChanged();
     void localDocsNomicAPIKeyChanged();
-    void localDocsEmbedDeviceChanged();
     void networkAttributionChanged();
     void networkIsActiveChanged();
     void networkPortChanged();
@@ -287,9 +280,7 @@ private:
 
 private:
     QSettings m_settings;
-    bool m_forceMetal;
     const QStringList m_deviceList;
-    const QStringList m_embeddingsDeviceList;
     const QStringList m_uiLanguages;
     std::unique_ptr<QTranslator> m_translator;
 
@@ -372,8 +372,6 @@ void Network::trackChatEvent(const QString &ev, QVariantMap props)
     Q_ASSERT(curChat);
     if (!props.contains("model"))
         props.insert("model", curChat->modelInfo().filename());
-    props.insert("device_backend", curChat->deviceBackend());
-    props.insert("actualDevice", curChat->device());
     props.insert("doc_collections_enabled", curChat->collectionList().count());
     props.insert("doc_collections_total", LocalDocs::globalInstance()->localDocsModel()->rowCount());
     props.insert("datalake_active", MySettings::globalInstance()->networkIsActive());
@@ -7,6 +7,7 @@
 
 #include <fmt/format.h>
 #include <gpt4all-backend/formatters.h>
+#include <gpt4all-backend/generation-params.h>
 #include <gpt4all-backend/llmodel.h>
 
 #include <QByteArray>
@@ -126,7 +127,7 @@ class BaseCompletionRequest {
 public:
     QString model; // required
     // NB: some parameters are not supported yet
-    int32_t max_tokens = 16;
+    uint max_tokens = 16;
     qint64 n = 1;
     float temperature = 1.f;
     float top_p = 1.f;
@@ -161,7 +162,7 @@ protected:
 
     value = reqValue("max_tokens", Integer, false, /*min*/ 1);
     if (!value.isNull())
-        this->max_tokens = int32_t(qMin(value.toInteger(), INT32_MAX));
+        this->max_tokens = uint(qMin(value.toInteger(), UINT32_MAX));
 
     value = reqValue("n", Integer, false, /*min*/ 1);
     if (!value.isNull())
@@ -666,7 +667,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
     m_chatModel->appendResponse();
 
     // FIXME(jared): taking parameters from the UI inhibits reproducibility of results
-    LLModel::PromptContext promptCtx {
+    backend::GenerationParams genParams {
         .n_predict = request.max_tokens,
         .top_k = mySettings->modelTopK(modelInfo),
         .top_p = request.top_p,
@@ -685,7 +686,7 @@ auto Server::handleCompletionRequest(const CompletionRequest &request)
     PromptResult result;
     try {
         result = promptInternal(std::string_view(promptUtf8.cbegin(), promptUtf8.cend()),
-                                promptCtx,
+                                genParams,
                                 /*usedLocalDocs*/ false);
     } catch (const std::exception &e) {
         m_chatModel->setResponseValue(e.what());
@@ -779,7 +780,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
     auto startOffset = m_chatModel->appendResponseWithHistory(messages);
 
     // FIXME(jared): taking parameters from the UI inhibits reproducibility of results
-    LLModel::PromptContext promptCtx {
+    backend::GenerationParams genParams {
         .n_predict = request.max_tokens,
         .top_k = mySettings->modelTopK(modelInfo),
         .top_p = request.top_p,
@@ -796,7 +797,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
     for (int i = 0; i < request.n; ++i) {
         ChatPromptResult result;
         try {
-            result = promptInternalChat(m_collections, promptCtx, startOffset);
+            result = promptInternalChat(m_collections, genParams, startOffset);
         } catch (const std::exception &e) {
            m_chatModel->setResponseValue(e.what());
             m_chatModel->setError();