mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-06 19:10:15 +00:00
chat: faster KV shift, continue generating, fix stop sequences (#2781)
* Don't stop generating at end of context * Use llama_kv_cache ops to shift context * Fix and improve reverse prompt detection * Replace prompt recalc callback with a flag to disallow context shift
This commit is contained in:
@@ -134,7 +134,7 @@ public:
|
||||
int32_t n_batch = 9;
|
||||
float repeat_penalty = 1.10f;
|
||||
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||
float contextErase = 0.75f; // percent of context to erase if we exceed the context window
|
||||
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
|
||||
};
|
||||
|
||||
using ProgressCallback = std::function<bool(float progress)>;
|
||||
@@ -159,7 +159,7 @@ public:
|
||||
const std::string &promptTemplate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
bool allowContextShift,
|
||||
PromptContext &ctx,
|
||||
bool special = false,
|
||||
std::string *fakeReply = nullptr);
|
||||
@@ -213,9 +213,11 @@ protected:
|
||||
// These are pure virtual because subclasses need to implement as the default implementation of
|
||||
// 'prompt' above calls these functions
|
||||
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
|
||||
virtual bool isSpecialToken(Token id) const = 0;
|
||||
virtual std::string tokenToString(Token id) const = 0;
|
||||
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
||||
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
|
||||
virtual void shiftContext(PromptContext &promptCtx) = 0;
|
||||
virtual int32_t contextLength() const = 0;
|
||||
virtual const std::vector<Token> &endTokens() const = 0;
|
||||
virtual bool shouldAddBOS() const = 0;
|
||||
@@ -232,10 +234,6 @@ protected:
|
||||
return -1;
|
||||
}
|
||||
|
||||
// This is a helper function called from the default implementation of 'prompt' but it can be
|
||||
// shared by all base classes so it isn't virtual
|
||||
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
|
||||
|
||||
const Implementation *m_implementation = nullptr;
|
||||
|
||||
ProgressCallback m_progressCallback;
|
||||
@@ -249,11 +247,11 @@ protected:
|
||||
|
||||
bool decodePrompt(std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
bool allowContextShift,
|
||||
PromptContext &promptCtx,
|
||||
std::vector<Token> embd_inp);
|
||||
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
bool allowContextShift,
|
||||
PromptContext &promptCtx);
|
||||
|
||||
Token m_tokenize_last_token = -1; // not serialized
|
||||
|
Reference in New Issue
Block a user