chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
This commit is contained in:
Jared Van Bortel
2024-08-07 11:25:24 -04:00
committed by GitHub
parent 90de2d32f8
commit be66ec8ab5
16 changed files with 285 additions and 230 deletions

View File

@@ -69,7 +69,7 @@ public:
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &ctx,
bool special,
std::string *fakeReply) override;
@@ -97,38 +97,57 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
// Placeholder override: discards every argument and unconditionally throws
// std::logic_error("not implemented"). It never returns a token vector.
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
{
    static_cast<void>(ctx);
    static_cast<void>(str);
    static_cast<void>(special);
    throw std::logic_error{"not implemented"};
}
std::string tokenToString(Token id) const override {
// Placeholder override: ignores the token id and unconditionally throws
// std::logic_error("not implemented"). It never returns a value.
bool isSpecialToken(Token id) const override
{
    static_cast<void>(id);
    throw std::logic_error{"not implemented"};
}
Token sampleToken(PromptContext &ctx) const override {
// Placeholder override: ignores the token id and unconditionally throws
// std::logic_error("not implemented"). It never returns a string.
std::string tokenToString(Token id) const override
{
    static_cast<void>(id);
    throw std::logic_error{"not implemented"};
}
// Placeholder override: ignores the prompt context and unconditionally throws
// std::logic_error("not implemented"). It never returns a token.
Token sampleToken(PromptContext &ctx) const override
{
    static_cast<void>(ctx);
    throw std::logic_error{"not implemented"};
}
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
// Placeholder override: discards both arguments and unconditionally throws
// std::logic_error("not implemented"). It never returns a value.
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
{
    static_cast<void>(ctx);
    static_cast<void>(tokens);
    throw std::logic_error{"not implemented"};
}
int32_t contextLength() const override {
// Placeholder override: ignores the prompt context and unconditionally throws
// std::logic_error("not implemented").
void shiftContext(PromptContext &promptCtx) override
{
    static_cast<void>(promptCtx);
    throw std::logic_error{"not implemented"};
}
const std::vector<Token> &endTokens() const override {
// Placeholder override: unconditionally throws std::logic_error("not implemented").
// It never returns a length.
int32_t contextLength() const override
{
    throw std::logic_error{"not implemented"};
}
bool shouldAddBOS() const override {
// Placeholder override: unconditionally throws std::logic_error("not implemented").
// It never returns a reference.
const std::vector<Token> &endTokens() const override
{
    throw std::logic_error{"not implemented"};
}
// Placeholder override: unconditionally throws std::logic_error("not implemented").
// It never returns a value.
bool shouldAddBOS() const override
{
    throw std::logic_error{"not implemented"};
}