Enable more warning flags, and fix more warnings (#3065)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel authored on 2024-10-18 12:11:03 -04:00, committed by GitHub
parent eed92fd5b2
commit c3357b7625
16 changed files with 27 additions and 69 deletions

@@ -213,7 +213,7 @@ public:
 protected:
     // These are pure virtual because subclasses need to implement them, as the
     // default implementation of 'prompt' above calls these functions
-    virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
+    virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
     virtual bool isSpecialToken(Token id) const = 0;
     virtual std::string tokenToString(Token id) const = 0;
     virtual void initSampler(PromptContext &ctx) = 0;
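This hunk is the interface side of the change: the PromptContext parameter of tokenize() was evidently no longer used by implementations, exactly the kind of dead parameter the newly enabled warnings (e.g. -Wunused-parameter) surface, so it is dropped from the pure-virtual signature. Below is a minimal self-contained sketch of the resulting override surface; Token and the interface name are stand-ins for the real declarations in llmodel.h, not code from this commit.

    #include <cstdint>
    #include <string_view>
    #include <vector>

    using Token = int32_t; // stand-in for LLModel::Token

    // Mirrors the pure-virtual surface after this change: tokenize() is driven
    // purely by its arguments, with no PromptContext threaded through.
    class ModelIface {
    public:
        virtual ~ModelIface() = default;
    protected:
        virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
        virtual bool isSpecialToken(Token id) const = 0;
    };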

@@ -511,7 +511,7 @@ size_t LLamaModel::restoreState(std::span<const uint8_t> src)
     return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
 }
 
-std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
+std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
 {
     bool atStart = m_tokenize_last_token == -1;
     bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
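On the implementation side, the state tokenize() needs, namely whether we are at the very start of the sequence or immediately after a special token, lives in the member m_tokenize_last_token rather than in the PromptContext, which is why the context parameter could go. For reference, here is a hedged sketch of the usual two-pass llama_tokenize() pattern an implementation like this sits on top of, assuming the llama.cpp C API of this era; the wrapper name and the flag mapping are illustrative, not code from this commit.

    #include <cstdint>
    #include <string_view>
    #include <vector>
    #include <llama.h> // llama.cpp C API (assumed available on the include path)

    // Sketch: query the required buffer size, then tokenize for real.
    // llama_tokenize() returns the negated required token count when the
    // output buffer is too small, which the first call exploits.
    static std::vector<llama_token> tokenizeText(const llama_model *model,
                                                 std::string_view str,
                                                 bool addSpecial, bool parseSpecial)
    {
        int32_t needed = -llama_tokenize(model, str.data(), int32_t(str.size()),
                                         /*tokens*/ nullptr, /*n_tokens_max*/ 0,
                                         addSpecial, parseSpecial);
        std::vector<llama_token> tokens(needed);
        llama_tokenize(model, str.data(), int32_t(str.size()),
                       tokens.data(), int32_t(tokens.size()), addSpecial, parseSpecial);
        return tokens;
    }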

@@ -54,7 +54,7 @@ private:
     bool m_supportsCompletion = false;
 
 protected:
-    std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
+    std::vector<Token> tokenize(std::string_view str, bool special) override;
     bool isSpecialToken(Token id) const override;
    std::string tokenToString(Token id) const override;
    void initSampler(PromptContext &ctx) override;
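The concrete class's header must change in lockstep with the base class, and `override` is what enforces that at compile time: a leftover declaration with the old parameter list would no longer match any base-class virtual and would be rejected outright instead of silently hiding it. A tiny self-contained illustration with placeholder names:

    #include <string_view>
    #include <vector>

    struct Base {
        virtual ~Base() = default;
        virtual std::vector<int> tokenize(std::string_view str, bool special = false) = 0;
    };

    struct Derived : Base {
        // Keeping the old signature here would be a hard error rather than a
        // silently hidden virtual, because `override` demands an exact match:
        // std::vector<int> tokenize(Ctx &, std::string_view, bool) override; // error
        std::vector<int> tokenize(std::string_view str, bool special = false) override
        {
            return {}; // placeholder body so the sketch compiles
        }
    };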

@@ -90,41 +90,33 @@ void LLModel::prompt(const std::string &prompt,
         }
     }
 
-    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
-
     // tokenize the user prompt
     std::vector<Token> embd_inp;
    if (placeholders.empty()) {
         // this is unusual, but well-defined
         std::cerr << __func__ << ": prompt template has no placeholder\n";
-        embd_inp = tokenize(promptCtx, promptTemplate, true);
+        embd_inp = tokenize(promptTemplate, true);
     } else {
         // template: beginning of user prompt
         const auto &phUser = placeholders[0];
         std::string userPrefix(phUser.prefix());
-        if (!userPrefix.empty()) {
-            embd_inp = tokenize(promptCtx, userPrefix, true);
-            promptCtx.n_past += embd_inp.size();
-        }
+        if (!userPrefix.empty())
+            embd_inp = tokenize(userPrefix, true);
 
         // user input (shouldn't have special token processing)
-        auto tokens = tokenize(promptCtx, prompt, special);
+        auto tokens = tokenize(prompt, special);
         embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-        promptCtx.n_past += tokens.size();
 
         // template: end of user prompt + start of assistant prompt
         size_t start = phUser.position() + phUser.length();
         size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
         auto userToAsst = promptTemplate.substr(start, end - start);
         if (!userToAsst.empty()) {
-            tokens = tokenize(promptCtx, userToAsst, true);
+            tokens = tokenize(userToAsst, true);
             embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-            promptCtx.n_past += tokens.size();
         }
     }
 
-    promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
-
     // decode the user prompt
     if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
         return; // error
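Two things happen in the hunk above: every tokenize() call loses its promptCtx argument, and with it goes the bookkeeping that temporarily faked n_past during tokenization and restored it afterwards; n_past is now advanced only inside decodePrompt(). The slicing of the template around its placeholders is unchanged, and is easy to see in isolation. A self-contained sketch of that slicing follows; the %1/%2 regex and the template literal are assumptions standing in for the matching done earlier in prompt().

    #include <iostream>
    #include <regex>
    #include <string>
    #include <vector>

    int main()
    {
        // GPT4All-style template: %1 is the user prompt, %2 the assistant reply.
        std::string promptTemplate = "### User:\n%1\n### Assistant:\n%2";

        // Stand-in for the placeholder search done earlier in prompt().
        std::vector<std::smatch> placeholders;
        std::regex re("%[12]");
        for (auto it = std::sregex_iterator(promptTemplate.begin(), promptTemplate.end(), re);
             it != std::sregex_iterator(); ++it)
            placeholders.push_back(*it);

        // template: beginning of user prompt
        const auto &phUser = placeholders[0];
        std::string userPrefix(phUser.prefix()); // "### User:\n"

        // template: end of user prompt + start of assistant prompt
        size_t start = phUser.position() + phUser.length();
        size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
        std::string userToAsst = promptTemplate.substr(start, end - start); // "\n### Assistant:\n"

        std::cout << "prefix=[" << userPrefix << "] between=[" << userToAsst << "]\n";
    }

The two extracted fragments are exactly what the real code passes to tokenize(..., true) on either side of the user's raw prompt, which is tokenized with the caller-supplied special flag instead.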
@@ -133,7 +125,7 @@ void LLModel::prompt(const std::string &prompt,
     if (!fakeReply) {
         generateResponse(responseCallback, allowContextShift, promptCtx);
     } else {
-        embd_inp = tokenize(promptCtx, *fakeReply, false);
+        embd_inp = tokenize(*fakeReply, false);
         if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
             return; // error
     }
@@ -148,7 +140,7 @@ void LLModel::prompt(const std::string &prompt,
         asstSuffix = "\n\n"; // default to a blank line, good for e.g. Alpaca
     }
     if (!asstSuffix.empty()) {
-        embd_inp = tokenize(promptCtx, asstSuffix, true);
+        embd_inp = tokenize(asstSuffix, true);
         decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
     }
 }