mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-22 13:41:08 +00:00
llmodel: change tokenToString to not use string_view (#968)
fixes a definite use-after-free and likely avoids some other potential ones - std::string will convert to a std::string_view automatically but as soon as the std::string in question goes out of scope it is already freed and the string_view is pointing at freed memory - this is *mostly* fine if its returning a reference to the tokenizer's internal vocab table but it's, imo, too easy to return a reference to a dynamically constructed string with this as replit is doing (and unfortunately needs to do to convert the internal whitespace replacement symbol back to a space)
This commit is contained in:
parent
726dcbd43d
commit
88616fde7f
@ -907,7 +907,7 @@ LLModel::Token GPTJ::sampleToken(PromptContext &promptCtx) const
|
|||||||
d_ptr->rng);
|
d_ptr->rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view GPTJ::tokenToString(Token id) const
|
std::string GPTJ::tokenToString(Token id) const
|
||||||
{
|
{
|
||||||
return d_ptr->vocab.id_to_token[id];
|
return d_ptr->vocab.id_to_token[id];
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ private:
|
|||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
Token sampleToken(PromptContext &ctx) const override;
|
Token sampleToken(PromptContext &ctx) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string tokenToString(Token) const override;
|
||||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
int32_t contextLength() const override;
|
int32_t contextLength() const override;
|
||||||
const std::vector<Token>& endTokens() const override;
|
const std::vector<Token>& endTokens() const override;
|
||||||
|
@ -177,7 +177,7 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
|
|||||||
return fres;
|
return fres;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view LLamaModel::tokenToString(Token id) const
|
std::string LLamaModel::tokenToString(Token id) const
|
||||||
{
|
{
|
||||||
return llama_token_to_str(d_ptr->ctx, id);
|
return llama_token_to_str(d_ptr->ctx, id);
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ private:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string tokenToString(Token) const override;
|
||||||
Token sampleToken(PromptContext& ctx) const override;
|
Token sampleToken(PromptContext& ctx) const override;
|
||||||
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext& ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
int32_t contextLength() const override;
|
int32_t contextLength() const override;
|
||||||
|
@ -86,7 +86,7 @@ protected:
|
|||||||
// These are pure virtual because subclasses need to implement as the default implementation of
|
// These are pure virtual because subclasses need to implement as the default implementation of
|
||||||
// 'prompt' above calls these functions
|
// 'prompt' above calls these functions
|
||||||
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
|
virtual std::vector<Token> tokenize(PromptContext &, const std::string&) const = 0;
|
||||||
virtual std::string_view tokenToString(Token) const = 0;
|
virtual std::string tokenToString(Token) const = 0;
|
||||||
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
virtual Token sampleToken(PromptContext &ctx) const = 0;
|
||||||
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
|
virtual bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const = 0;
|
||||||
virtual int32_t contextLength() const = 0;
|
virtual int32_t contextLength() const = 0;
|
||||||
|
@ -121,7 +121,7 @@ void LLModel::prompt(const std::string &prompt,
|
|||||||
if (id == token) return;
|
if (id == token) return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string_view str = tokenToString(id);
|
const std::string str = tokenToString(id);
|
||||||
|
|
||||||
// Check if the provided str is part of our reverse prompts
|
// Check if the provided str is part of our reverse prompts
|
||||||
bool foundPartialReversePrompt = false;
|
bool foundPartialReversePrompt = false;
|
||||||
|
@ -820,7 +820,7 @@ std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &st
|
|||||||
return ::gpt_tokenize(d_ptr->vocab, str);
|
return ::gpt_tokenize(d_ptr->vocab, str);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view MPT::tokenToString(Token id) const
|
std::string MPT::tokenToString(Token id) const
|
||||||
{
|
{
|
||||||
return d_ptr->vocab.id_to_token[id];
|
return d_ptr->vocab.id_to_token[id];
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ private:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string tokenToString(Token) const override;
|
||||||
Token sampleToken(PromptContext &ctx) const override;
|
Token sampleToken(PromptContext &ctx) const override;
|
||||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
int32_t contextLength() const override;
|
int32_t contextLength() const override;
|
||||||
|
@ -910,7 +910,7 @@ std::vector<LLModel::Token> Replit::tokenize(PromptContext &, const std::string
|
|||||||
return replit_tokenizer_tokenize(d_ptr->vocab, str);
|
return replit_tokenizer_tokenize(d_ptr->vocab, str);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view Replit::tokenToString(LLModel::Token id) const
|
std::string Replit::tokenToString(LLModel::Token id) const
|
||||||
{
|
{
|
||||||
return replit_tokenizer_detokenize(d_ptr->vocab, {id});
|
return replit_tokenizer_detokenize(d_ptr->vocab, {id});
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ private:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override;
|
||||||
std::string_view tokenToString(Token) const override;
|
std::string tokenToString(Token) const override;
|
||||||
Token sampleToken(PromptContext &ctx) const override;
|
Token sampleToken(PromptContext &ctx) const override;
|
||||||
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
|
||||||
int32_t contextLength() const override;
|
int32_t contextLength() const override;
|
||||||
|
@ -39,7 +39,7 @@ protected:
|
|||||||
// them as they are only called from the default implementation of 'prompt' which we override and
|
// them as they are only called from the default implementation of 'prompt' which we override and
|
||||||
// completely replace
|
// completely replace
|
||||||
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
|
std::vector<Token> tokenize(PromptContext &, const std::string&) const override { return std::vector<Token>(); }
|
||||||
std::string_view tokenToString(Token) const override { return std::string_view(); }
|
std::string tokenToString(Token) const override { return std::string(); }
|
||||||
Token sampleToken(PromptContext &ctx) const override { return -1; }
|
Token sampleToken(PromptContext &ctx) const override { return -1; }
|
||||||
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
|
bool evalTokens(PromptContext &/*ctx*/, const std::vector<int32_t>& /*tokens*/) const override { return false; }
|
||||||
int32_t contextLength() const override { return -1; }
|
int32_t contextLength() const override { return -1; }
|
||||||
|
Loading…
Reference in New Issue
Block a user