mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-07 19:40:21 +00:00
server: improve correctness of request parsing and responses (#2929)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
@@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
|
||||
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
|
||||
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD 23)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
set(BUILD_SHARED_LIBS ON)
|
||||
|
Submodule gpt4all-backend/deps/llama.cpp-mainline updated: 443665aec4...ced74fbad4
@@ -162,7 +162,7 @@ public:
|
||||
bool allowContextShift,
|
||||
PromptContext &ctx,
|
||||
bool special = false,
|
||||
std::string *fakeReply = nullptr);
|
||||
std::optional<std::string_view> fakeReply = {});
|
||||
|
||||
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
|
||||
|
||||
@@ -212,7 +212,7 @@ public:
|
||||
protected:
|
||||
// These are pure virtual because subclasses need to implement as the default implementation of
|
||||
// 'prompt' above calls these functions
|
||||
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
|
||||
virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
|
||||
virtual bool isSpecialToken(Token id) const = 0;
|
||||
virtual std::string tokenToString(Token id) const = 0;
|
||||
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
||||
@@ -249,7 +249,8 @@ protected:
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
bool allowContextShift,
|
||||
PromptContext &promptCtx,
|
||||
std::vector<Token> embd_inp);
|
||||
std::vector<Token> embd_inp,
|
||||
bool isResponse = false);
|
||||
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
bool allowContextShift,
|
||||
PromptContext &promptCtx);
|
||||
|
@@ -536,13 +536,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
|
||||
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
|
||||
}
|
||||
|
||||
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
|
||||
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
|
||||
{
|
||||
bool atStart = m_tokenize_last_token == -1;
|
||||
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
|
||||
std::vector<LLModel::Token> fres(str.length() + 4);
|
||||
int32_t fres_len = llama_tokenize_gpt4all(
|
||||
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
|
||||
d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
|
||||
/*parse_special*/ special, /*insert_space*/ insertSpace
|
||||
);
|
||||
fres.resize(fres_len);
|
||||
|
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
struct LLamaPrivate;
|
||||
@@ -52,7 +53,7 @@ private:
|
||||
bool m_supportsCompletion = false;
|
||||
|
||||
protected:
|
||||
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
|
||||
std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
|
||||
bool isSpecialToken(Token id) const override;
|
||||
std::string tokenToString(Token id) const override;
|
||||
Token sampleToken(PromptContext &ctx) const override;
|
||||
|
@@ -12,6 +12,7 @@
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
struct LLModelWrapper {
|
||||
@@ -130,13 +131,10 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
|
||||
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
|
||||
wrapper->promptContext.contextErase = ctx->context_erase;
|
||||
|
||||
std::string fake_reply_str;
|
||||
if (fake_reply) { fake_reply_str = fake_reply; }
|
||||
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
|
||||
|
||||
// Call the C++ prompt method
|
||||
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
|
||||
wrapper->promptContext, special, fake_reply_p);
|
||||
wrapper->promptContext, special,
|
||||
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
|
||||
|
||||
// Update the C context by giving access to the wrappers raw pointers to std::vector data
|
||||
// which involves no copies
|
||||
|
@@ -11,6 +11,7 @@
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
namespace ranges = std::ranges;
|
||||
@@ -45,7 +46,7 @@ void LLModel::prompt(const std::string &prompt,
|
||||
bool allowContextShift,
|
||||
PromptContext &promptCtx,
|
||||
bool special,
|
||||
std::string *fakeReply)
|
||||
std::optional<std::string_view> fakeReply)
|
||||
{
|
||||
if (!isModelLoaded()) {
|
||||
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
|
||||
@@ -129,11 +130,11 @@ void LLModel::prompt(const std::string &prompt,
|
||||
return; // error
|
||||
|
||||
// decode the assistant's reply, either generated or spoofed
|
||||
if (fakeReply == nullptr) {
|
||||
if (!fakeReply) {
|
||||
generateResponse(responseCallback, allowContextShift, promptCtx);
|
||||
} else {
|
||||
embd_inp = tokenize(promptCtx, *fakeReply, false);
|
||||
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
|
||||
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
|
||||
return; // error
|
||||
}
|
||||
|
||||
@@ -157,7 +158,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
bool allowContextShift,
|
||||
PromptContext &promptCtx,
|
||||
std::vector<Token> embd_inp) {
|
||||
std::vector<Token> embd_inp,
|
||||
bool isResponse) {
|
||||
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
|
||||
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
|
||||
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
|
||||
@@ -196,7 +198,9 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
|
||||
for (size_t t = 0; t < tokens; ++t) {
|
||||
promptCtx.tokens.push_back(batch.at(t));
|
||||
promptCtx.n_past += 1;
|
||||
if (!promptCallback(batch.at(t)))
|
||||
Token tok = batch.at(t);
|
||||
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
|
||||
if (!res)
|
||||
return false;
|
||||
}
|
||||
i = batch_end;
|
||||
|
Reference in New Issue
Block a user