chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
This commit is contained in:
Jared Van Bortel
2024-08-07 11:25:24 -04:00
committed by GitHub
parent 90de2d32f8
commit be66ec8ab5
16 changed files with 285 additions and 230 deletions

View File

@@ -134,7 +134,7 @@ public:
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize
float contextErase = 0.75f; // percent of context to erase if we exceed the context window
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
using ProgressCallback = std::function<bool(float progress)>;
@@ -159,7 +159,7 @@ public:
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);
@@ -213,9 +213,11 @@ protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@@ -232,10 +234,6 @@ protected:
return -1;
}
// This is a helper function called from the default implementation of 'prompt' but it can be
// shared by all base classes so it isn't virtual
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
@@ -249,11 +247,11 @@ protected:
bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx);
Token m_tokenize_last_token = -1; // not serialized