server: improve correctness of request parsing and responses (#2929)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel
Date: 2024-09-09 10:48:57 -04:00
Committed by: GitHub
Parent: 1aae4ffe0a
Commit: 39005288c5
22 changed files with 790 additions and 328 deletions

View File

@@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)
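
A minimal sketch, not part of this commit, of how a translation unit could assert the new language level at compile time (assuming a compiler that reports __cplusplus accurately; MSVC only does so with /Zc:__cplusplus):

// Illustrative only: C++23 reports __cplusplus as 202302L (C++20 was 202002L).
static_assert(__cplusplus >= 202302L, "llmodel is now built as C++23");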

View File

@@ -162,7 +162,7 @@ public:
bool allowContextShift,
PromptContext &ctx,
bool special = false,
-std::string *fakeReply = nullptr);
+std::optional<std::string_view> fakeReply = {});
using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
@@ -212,7 +212,7 @@ public:
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
-virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
+virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
@@ -249,7 +249,8 @@ protected:
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
-std::vector<Token> embd_inp);
+std::vector<Token> embd_inp,
+bool isResponse = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);
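
A sketch of how a caller might exercise the new parameter. The leading arguments (prompt template and the two callbacks) are assumed from the rest of the prompt() declaration in this header, so treat the call shape as illustrative rather than a copied gpt4all call site:

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>

#include "llmodel.h"  // assumed header path; declares LLModel and LLModel::PromptContext

// Illustrative helper, not gpt4all code.
void replayAssistantTurn(LLModel &model, LLModel::PromptContext &ctx,
                         const std::string &userInput,
                         const std::string &promptTemplate,
                         const std::string &storedReply)
{
    auto onPrompt   = [](int32_t /*token*/) { return true; };
    auto onResponse = [](int32_t /*token*/, const std::string &/*piece*/) { return true; };

    // Normal generation: leave fakeReply at its default, an empty optional.
    model.prompt(userInput, promptTemplate, onPrompt, onResponse,
                 /*allowContextShift*/ true, ctx, /*special*/ false);

    // Spoofed reply: the stored assistant message is tokenized and decoded into
    // the context instead of being generated.
    model.prompt(userInput, promptTemplate, onPrompt, onResponse,
                 /*allowContextShift*/ true, ctx, /*special*/ false,
                 std::make_optional<std::string_view>(storedReply));
}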

View File

@@ -536,13 +536,13 @@ size_t LLamaModel::restoreState(const uint8_t *src)
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}
-std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
+std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
-d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
+d_ptr->model, str.data(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
/*parse_special*/ special, /*insert_space*/ insertSpace
);
fres.resize(fres_len);
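
std::string_view has no c_str() and its data() is not guaranteed to be NUL-terminated, so the tokenizer call now passes the pointer together with an explicit length. A small self-contained illustration of that pattern (not gpt4all code):

#include <cstddef>
#include <cstdio>
#include <string_view>

// A C-style consumer that, like llama_tokenize_gpt4all above, takes (pointer, length).
static void consume(const char *text, std::size_t len)
{
    std::printf("%.*s\n", static_cast<int>(len), text);
}

int main()
{
    std::string_view full = "hello world";
    std::string_view word = full.substr(0, 5);  // "hello", same underlying buffer

    // word.data() is not NUL-terminated after five characters, so the length has
    // to travel with the pointer; string_view offers no c_str() for this reason.
    consume(word.data(), word.size());
}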

View File

@@ -8,6 +8,7 @@
#include <memory>
#include <string>
+#include <string_view>
#include <vector>
struct LLamaPrivate;
@@ -52,7 +53,7 @@ private:
bool m_supportsCompletion = false;
protected:
-std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
+std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
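
Because the derived tokenize() is marked override, this header has to change in the same commit as the base class: leaving the old const std::string & parameter here would be a compile error rather than a silent second overload. A minimal illustration of that guarantee (not gpt4all code):

#include <string>
#include <string_view>

struct Base {
    virtual ~Base() = default;
    virtual int tokenize(std::string_view str) = 0;   // new base signature
};

struct Derived : Base {
    // int tokenize(const std::string &str) override;  // old signature: rejected,
    //                                                  // it no longer overrides anything
    int tokenize(std::string_view str) override { return static_cast<int>(str.size()); }
};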

View File

@@ -12,6 +12,7 @@
#include <memory>
#include <optional>
#include <string>
+#include <string_view>
#include <vector>
struct LLModelWrapper {
@@ -130,13 +131,10 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
-std::string fake_reply_str;
-if (fake_reply) { fake_reply_str = fake_reply; }
-auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
-wrapper->promptContext, special, fake_reply_p);
+wrapper->promptContext, special,
+fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
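
The wrapper used to copy fake_reply into a temporary std::string and pass its address; wrapping the const char * directly in an optional<string_view> drops that copy, and is safe only because the view is consumed before the C caller's buffer can go away. A standalone sketch of the same bridge pattern (assumed names, not the gpt4all API):

#include <cstdio>
#include <optional>
#include <string_view>

// Stand-in for the C++ side; takes the view by value and uses it immediately.
static void useReply(std::optional<std::string_view> fakeReply)
{
    if (fakeReply)
        std::printf("spoofed reply: %.*s\n",
                    static_cast<int>(fakeReply->size()), fakeReply->data());
    else
        std::printf("no spoofed reply, generate instead\n");
}

// Stand-in for the C entry point: fake_reply may legitimately be null.
static void c_entry_point(const char *fake_reply)
{
    // The string_view borrows fake_reply; this is fine only because useReply()
    // finishes before the caller's buffer can be freed.
    useReply(fake_reply ? std::make_optional<std::string_view>(fake_reply)
                        : std::nullopt);
}

int main()
{
    c_entry_point("previous assistant turn");
    c_entry_point(nullptr);
}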

View File

@@ -11,6 +11,7 @@
#include <sstream>
#include <stdexcept>
#include <string>
+#include <string_view>
#include <vector>
namespace ranges = std::ranges;
@@ -45,7 +46,7 @@ void LLModel::prompt(const std::string &prompt,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
-std::string *fakeReply)
+std::optional<std::string_view> fakeReply)
{
if (!isModelLoaded()) {
std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
@@ -129,11 +130,11 @@ void LLModel::prompt(const std::string &prompt,
return; // error
// decode the assistant's reply, either generated or spoofed
-if (fakeReply == nullptr) {
+if (!fakeReply) {
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(promptCtx, *fakeReply, false);
-if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
+if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
return; // error
}
@@ -157,7 +158,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx,
-std::vector<Token> embd_inp) {
+std::vector<Token> embd_inp,
+bool isResponse) {
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
@@ -196,7 +198,9 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
for (size_t t = 0; t < tokens; ++t) {
promptCtx.tokens.push_back(batch.at(t));
promptCtx.n_past += 1;
-if (!promptCallback(batch.at(t)))
+Token tok = batch.at(t);
+bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
+if (!res)
return false;
}
i = batch_end;
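
The new dispatch also explains the two callback shapes: prompt tokens are reported by id only, while response tokens (including spoofed ones when isResponse is set) arrive with their decoded text, so tokenToString() is only paid for on the response path. Hypothetical callbacks matching the std::function signatures above:

#include <cstdint>
#include <cstdio>
#include <string>

// Prompt-side callback: ids only, typically used for progress or cancellation.
static bool onPromptToken(int32_t token)
{
    std::printf("prompt token %d\n", static_cast<int>(token));
    return true;   // returning false cancels decoding
}

// Response-side callback: also receives the decoded piece, which is why the loop
// above calls tokenToString() only when routing through this path.
static bool onResponseToken(int32_t token, const std::string &piece)
{
    std::printf("response token %d -> \"%s\"\n", static_cast<int>(token), piece.c_str());
    return true;
}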