server: improve correctness of request parsing and responses (#2929)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel (committed by GitHub)
Date: 2024-09-09 10:48:57 -04:00
Parent: 1aae4ffe0a
Commit: 39005288c5
22 changed files with 790 additions and 328 deletions

View File

@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix a typo in Model Settings (by [@3Simplex](https://github.com/3Simplex) in [#2916](https://github.com/nomic-ai/gpt4all/pull/2916))
 - Fix the antenna icon tooltip when using the local server ([#2922](https://github.com/nomic-ai/gpt4all/pull/2922))
 - Fix a few issues with locating files and handling errors when loading remote models on startup ([#2875](https://github.com/nomic-ai/gpt4all/pull/2875))
+- Significantly improve API server request parsing and response correctness ([#2929](https://github.com/nomic-ai/gpt4all/pull/2929))

 ## [3.2.1] - 2024-08-13

View File

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.16)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 if(APPLE)
if(APPLE)
@@ -64,6 +64,12 @@ message(STATUS "Qt 6 root directory: ${Qt6_ROOT_DIR}")
 set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+set(FMT_INSTALL OFF)
+set(BUILD_SHARED_LIBS_SAVED "${BUILD_SHARED_LIBS}")
+set(BUILD_SHARED_LIBS OFF)
+add_subdirectory(deps/fmt)
+set(BUILD_SHARED_LIBS "${BUILD_SHARED_LIBS_SAVED}")
 add_subdirectory(../gpt4all-backend llmodel)
 set(CHAT_EXE_RESOURCES)
@@ -240,7 +246,7 @@ else()
         PRIVATE Qt6::Quick Qt6::Svg Qt6::HttpServer Qt6::Sql Qt6::Pdf)
 endif()
 target_link_libraries(chat
-    PRIVATE llmodel SingleApplication)
+    PRIVATE llmodel SingleApplication fmt::fmt)

 # -- install --

gpt4all-chat/deps/fmt Submodule

Submodule gpt4all-chat/deps/fmt added at 0c9fce2ffe
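
The {fmt} library is vendored as a pinned submodule and forced to build statically (BUILD_SHARED_LIBS is saved, set OFF around add_subdirectory, then restored), then linked into the chat target as fmt::fmt above. A minimal sketch of the kind of call this dependency enables; the message text and port are illustrative, not taken from this commit:

```cpp
// Illustrative only: not code from this commit.
#include <fmt/format.h>

#include <string>

int main() {
    // fmt::format returns a std::string; format strings are checked at
    // compile time in recent {fmt} releases.
    std::string msg = fmt::format("listening on {}:{}", "localhost", 4891);
    fmt::print("{}\n", msg);
}
```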

View File

@@ -239,16 +239,17 @@ void Chat::newPromptResponsePair(const QString &prompt)
     resetResponseState();
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt("Prompt: ", prompt);
-    m_chatModel->appendResponse("Response: ", prompt);
+    m_chatModel->appendResponse("Response: ", QString());
     emit resetResponseRequested();
 }

 // the server needs to block until response is reset, so it calls resetResponse on its own m_llmThread
 void Chat::serverNewPromptResponsePair(const QString &prompt)
 {
     resetResponseState();
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt("Prompt: ", prompt);
-    m_chatModel->appendResponse("Response: ", prompt);
+    m_chatModel->appendResponse("Response: ", QString());
 }

 bool Chat::restoringFromText() const
bool Chat::restoringFromText() const

View File

@@ -93,7 +93,7 @@ void ChatAPI::prompt(const std::string &prompt,
                      bool allowContextShift,
                      PromptContext &promptCtx,
                      bool special,
-                     std::string *fakeReply) {
+                     std::optional<std::string_view> fakeReply) {

     Q_UNUSED(promptCallback);
     Q_UNUSED(allowContextShift);
@@ -121,7 +121,7 @@ void ChatAPI::prompt(const std::string &prompt,
     if (fakeReply) {
         promptCtx.n_past += 1;
         m_context.append(formattedPrompt);
-        m_context.append(QString::fromStdString(*fakeReply));
+        m_context.append(QString::fromUtf8(fakeReply->data(), fakeReply->size()));
         return;
     }
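
Changing fakeReply from std::string* to std::optional<std::string_view> lets callers pass a string literal (or nothing) without materializing a named std::string and taking its address. A small sketch of the call-site difference; promptModel is a hypothetical stand-in, not the real LLModel::prompt signature:

```cpp
#include <cstdio>
#include <optional>
#include <string_view>

// Hypothetical stand-in mirroring the new parameter style.
void promptModel(std::optional<std::string_view> fakeReply = std::nullopt) {
    if (fakeReply)
        std::printf("using fake reply: %.*s\n",
                    (int)fakeReply->size(), fakeReply->data());
    else
        std::printf("generating a real response\n");
}

int main() {
    promptModel();                  // the old style needed a nullptr here
    promptModel("canned response"); // literal binds without taking an address
}
```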

View File

@@ -12,9 +12,10 @@
 #include <cstddef>
 #include <cstdint>
-#include <stdexcept>
+#include <functional>
+#include <stdexcept>
 #include <string>
+#include <string_view>
 #include <vector>

 class QNetworkAccessManager;
@@ -72,7 +73,7 @@ public:
                 bool allowContextShift,
                 PromptContext &ctx,
                 bool special,
-                std::string *fakeReply) override;
+                std::optional<std::string_view> fakeReply) override;

     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
@@ -97,7 +98,7 @@ protected:
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace
-    std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
+    std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override
     {
         (void)ctx;
         (void)str;

View File

@@ -626,16 +626,16 @@ void ChatLLM::regenerateResponse()
     m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
-    m_response = std::string();
-    emit responseChanged(QString::fromStdString(m_response));
+    m_response = m_trimmedResponse = std::string();
+    emit responseChanged(QString::fromStdString(m_trimmedResponse));
 }

 void ChatLLM::resetResponse()
 {
     m_promptTokens = 0;
     m_promptResponseTokens = 0;
-    m_response = std::string();
-    emit responseChanged(QString::fromStdString(m_response));
+    m_response = m_trimmedResponse = std::string();
+    emit responseChanged(QString::fromStdString(m_trimmedResponse));
 }

 void ChatLLM::resetContext()
@@ -645,9 +645,12 @@ void ChatLLM::resetContext()
     m_ctx = LLModel::PromptContext();
 }

-QString ChatLLM::response() const
+QString ChatLLM::response(bool trim) const
 {
-    return QString::fromStdString(remove_leading_whitespace(m_response));
+    std::string resp = m_response;
+    if (trim)
+        resp = remove_leading_whitespace(resp);
+    return QString::fromStdString(resp);
 }

 ModelInfo ChatLLM::modelInfo() const
@@ -705,7 +708,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     // check for error
     if (token < 0) {
         m_response.append(response);
-        emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
+        m_trimmedResponse = remove_leading_whitespace(m_response);
+        emit responseChanged(QString::fromStdString(m_trimmedResponse));
         return false;
     }
@@ -715,7 +719,8 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     m_timer->inc();
     Q_ASSERT(!response.empty());
     m_response.append(response);
-    emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
+    m_trimmedResponse = remove_leading_whitespace(m_response);
+    emit responseChanged(QString::fromStdString(m_trimmedResponse));
     return !m_stopGenerating;
 }
@@ -741,7 +746,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
 bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
     int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-    int32_t repeat_penalty_tokens)
+    int32_t repeat_penalty_tokens, std::optional<QString> fakeReply)
 {
     if (!isModelLoaded())
         return false;
@@ -751,7 +756,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     QList<ResultInfo> databaseResults;
     const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
-    if (!collectionList.isEmpty()) {
+    if (!fakeReply && !collectionList.isEmpty()) {
         emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
         emit databaseResultsChanged(databaseResults);
     }
@@ -797,7 +802,8 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }
     m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
-                                /*allowContextShift*/ true, m_ctx);
+                                /*allowContextShift*/ true, m_ctx, false,
+                                fakeReply.transform(std::mem_fn(&QString::toStdString)));
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -805,9 +811,9 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     m_timer->stop();
     qint64 elapsed = totalTime.elapsed();
     std::string trimmed = trim_whitespace(m_response);
-    if (trimmed != m_response) {
-        m_response = trimmed;
-        emit responseChanged(QString::fromStdString(m_response));
+    if (trimmed != m_trimmedResponse) {
+        m_trimmedResponse = trimmed;
+        emit responseChanged(QString::fromStdString(m_trimmedResponse));
     }

     SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
@@ -1078,6 +1084,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     QString response;
     stream >> response;
     m_response = response.toStdString();
+    m_trimmedResponse = trim_whitespace(m_response);
     QString nameResponse;
     stream >> nameResponse;
     m_nameResponse = nameResponse.toStdString();
@@ -1306,10 +1313,9 @@ void ChatLLM::processRestoreStateFromText()
         auto &response = *it++;
         Q_ASSERT(response.first != "Prompt: ");
-        auto responseText = response.second.toStdString();
         m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
-                                    /*allowContextShift*/ true, m_ctx, false, &responseText);
+                                    /*allowContextShift*/ true, m_ctx, false, response.second.toUtf8().constData());
     }

     if (!m_stopGenerating) {
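
fakeReply.transform(std::mem_fn(&QString::toStdString)) above converts the std::optional<QString> into the std::optional<std::string> the model API now takes. std::optional::transform is a C++23 monadic operation, which is presumably what motivates the CMAKE_CXX_STANDARD bump from 20 to 23 earlier in this commit. A self-contained sketch, with FakeQString as a stand-in for QString (compile with -std=c++23):

```cpp
#include <cstdio>
#include <functional>
#include <optional>
#include <string>

// Stand-in for QString, just enough to mirror the .toStdString() call.
struct FakeQString {
    std::string s;
    std::string toStdString() const { return s; }
};

int main() {
    std::optional<FakeQString> fakeReply = FakeQString{"hello"};

    // C++23 std::optional::transform: applies the member function when the
    // optional is engaged, and yields std::nullopt otherwise.
    std::optional<std::string> converted =
        fakeReply.transform(std::mem_fn(&FakeQString::toStdString));

    std::printf("%s\n", converted ? converted->c_str() : "(nullopt)");
}
```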

View File

@@ -116,7 +116,7 @@ public:
     void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }

-    QString response() const;
+    QString response(bool trim = true) const;

     ModelInfo modelInfo() const;
     void setModelInfo(const ModelInfo &info);
@@ -198,7 +198,7 @@ Q_SIGNALS:
 protected:
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
-        int32_t repeat_penalty_tokens);
+        int32_t repeat_penalty_tokens, std::optional<QString> fakeReply = {});
     bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
     bool handleNamePrompt(int32_t token);
@@ -221,6 +221,7 @@ private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);

     std::string m_response;
+    std::string m_trimmedResponse;
     std::string m_nameResponse;
     QString m_questionResponse;
     LLModelInfo m_llModelInfo;
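
The new trailing parameter defaults to {}, which value-initializes the optional to std::nullopt, so every existing caller of promptInternal compiles and behaves unchanged. A sketch of the pattern with a hypothetical free function:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical function mirroring the defaulted-parameter pattern.
bool promptInternal(const std::string &prompt,
                    std::optional<std::string> fakeReply = {}) {
    // '{}' default-constructs the optional, i.e. std::nullopt: call sites
    // that pass no argument behave exactly as before the new parameter.
    (void)prompt;
    return fakeReply.has_value();
}

int main() {
    assert(!promptInternal("hi"));          // defaults to nullopt
    assert(promptInternal("hi", "cached")); // engaged optional
}
```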

View File

@@ -20,24 +20,25 @@ class LocalDocsCollectionsModel : public QSortFilterProxyModel
     Q_OBJECT
     Q_PROPERTY(int count READ count NOTIFY countChanged)
+    Q_PROPERTY(int updatingCount READ updatingCount NOTIFY updatingCountChanged)

 public:
     explicit LocalDocsCollectionsModel(QObject *parent);
+    int count() const { return rowCount(); }
+    int updatingCount() const;

 public Q_SLOTS:
-    int count() const { return rowCount(); }
     void setCollections(const QList<QString> &collections);
-    int updatingCount() const;

 Q_SIGNALS:
     void countChanged();
     void updatingCountChanged();

-private Q_SLOT:
-    void maybeTriggerUpdatingCountChanged();
-
 protected:
     bool filterAcceptsRow(int sourceRow, const QModelIndex &sourceParent) const override;

+private Q_SLOTS:
+    void maybeTriggerUpdatingCountChanged();
+
 private:
     QList<QString> m_collections;
     int m_updatingCount = 0;

View File

@@ -18,10 +18,12 @@
 #include <QVector>
 #include <Qt>
 #include <QtGlobal>
+#include <QtQml>
+#include <utility>

 using namespace Qt::Literals::StringLiterals;

 struct ModelInfo {
     Q_GADGET
     Q_PROPERTY(QString id READ id WRITE setId)
@@ -523,7 +525,7 @@ private:
 protected:
     explicit ModelList();
-    ~ModelList() { for (auto *model: m_models) { delete model; } }
+    ~ModelList() override { for (auto *model: std::as_const(m_models)) { delete model; } }
     friend class MyModelList;
 };
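
Iterating via std::as_const is the usual Qt guard against accidental detach: ranged-for over a non-const lvalue of an implicitly shared container (QList, QVector) selects the mutable begin()/end(), which deep-copies a shared container before iteration. A sketch of the pattern; std::vector stands in for the Qt container here and never detaches, so in this exact form the benefit is purely stylistic:

```cpp
#include <utility>
#include <vector>

struct Model { /* ... */ };

void destroyAll(std::vector<Model*> &models)
{
    // std::as_const yields a const reference, so the loop uses the const
    // iterators. With QList this prevents an accidental detach; with
    // std::vector it simply documents that the loop does not mutate.
    for (auto *model : std::as_const(models))
        delete model;
    models.clear();
}
```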

View File

@@ -8,6 +8,7 @@
#include <QSettings>
#include <QString>
#include <QStringList>
#include <QTranslator>
#include <QVector>
#include <cstdint>

File diff suppressed because it is too large

View File

@@ -4,22 +4,29 @@
 #include "chatllm.h"
 #include "database.h"

-#include <QHttpServerRequest>
+#include <QHttpServer>
 #include <QHttpServerResponse>
-#include <QObject>
+#include <QJsonObject>
 #include <QList>
+#include <QObject>
 #include <QString>

+#include <memory>
+#include <optional>
+#include <utility>
+
 class Chat;
-class QHttpServer;
+class ChatRequest;
+class CompletionRequest;

 class Server : public ChatLLM
 {
     Q_OBJECT

 public:
-    Server(Chat *parent);
-    virtual ~Server();
+    explicit Server(Chat *chat);
+    ~Server() override = default;

 public Q_SLOTS:
     void start();
@@ -27,14 +34,17 @@ public Q_SLOTS:
 Q_SIGNALS:
     void requestServerNewPromptResponsePair(const QString &prompt);

+private:
+    auto handleCompletionRequest(const CompletionRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
+    auto handleChatRequest(const ChatRequest &request) -> std::pair<QHttpServerResponse, std::optional<QJsonObject>>;
+
 private Q_SLOTS:
-    QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
     void handleCollectionListChanged(const QList<QString> &collectionList) { m_collections = collectionList; }

 private:
     Chat *m_chat;
-    QHttpServer *m_server;
+    std::unique_ptr<QHttpServer> m_server;
     QList<ResultInfo> m_databaseResults;
     QList<QString> m_collections;
 };
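
The request handlers now take parsed request objects (CompletionRequest, ChatRequest) instead of a raw QHttpServerRequest plus an isChat flag, and return the HTTP response paired with an optional JSON object. The actual use of that pair lives in server.cpp, whose diff is suppressed above, so the sketch below only illustrates the shape of the return type with stand-in types; that the optional carries a success payload is an assumption:

```cpp
#include <optional>
#include <utility>

// Stand-ins for QHttpServerResponse and QJsonObject; illustrative only.
struct HttpResponse { int status; };
struct JsonObject  { int fields = 0; };

// Mirrors the new handler shape: an HTTP response, plus a JSON object that
// is only present in some outcomes (assumed here: on success).
std::pair<HttpResponse, std::optional<JsonObject>> handleRequest(bool ok)
{
    if (!ok)
        return {HttpResponse{400}, std::nullopt}; // error: response only
    return {HttpResponse{200}, JsonObject{1}};    // success: response + payload
}
```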