Trying to shrink the copy+paste code and do more code sharing between backend model impl.

Adam Treat
2023-06-01 23:15:58 -04:00
committed by AT
parent 031d7149a7
commit a41bd6ac0a
9 changed files with 41 additions and 96 deletions

@@ -216,7 +216,7 @@ void LLamaModel::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "LLAMA ERROR: Failed to process prompt\n";
             return;
         }
@@ -258,7 +258,7 @@ void LLamaModel::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "LLAMA ERROR: Failed to predict next token\n";
             return;
         }
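Both call sites above now go through a single evalTokens hook instead of calling llama_eval directly, which is what allows the surrounding prompt/predict loop to be shared across backends. The header side of the change is not part of this excerpt; the following is only a sketch of what the shared interface could look like, inferred from the call sites (the class name LLModel and the PromptContext field set are assumptions, not taken from this diff):

// Sketch only: a shared backend interface inferred from the call sites above.
// "LLModel" and the PromptContext fields are assumptions; the real header is not shown here.
#include <cstdint>
#include <functional>
#include <vector>

class LLModel {
public:
    struct PromptContext {
        std::vector<int32_t> tokens;  // token ids evaluated so far
        int32_t n_past = 0;           // number of tokens already in the model's context
        int32_t n_ctx = 2048;         // maximum context window
        int32_t n_batch = 8;          // tokens fed per evalTokens call
    };

    virtual ~LLModel() = default;

    // Re-evaluates promptCtx.tokens from scratch (e.g. after the context is trimmed);
    // with evalTokens available it can be implemented once in the base class.
    void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);

protected:
    // Feed one batch of token ids starting at ctx.n_past; returns false on failure.
    // Each concrete backend (LLamaModel in this diff) supplies its own implementation.
    virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) = 0;
};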
@@ -305,29 +305,9 @@ void LLamaModel::prompt(const std::string &prompt,
     }
 }
 
-void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
-            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
 }
 
 #if defined(_WIN32)
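With evalTokens in place, the recalculateContext body removed above no longer has to be repeated in every backend. One plausible shared version, written once in terms of evalTokens (a sketch under the assumed LLModel interface from the previous snippet; the size_t cast in std::min is added so the comparison compiles when n_batch is an int32_t):

// Sketch: the removed per-backend recalculateContext, hoisted into the assumed base class
// and expressed via evalTokens instead of llama_eval.
#include <algorithm>
#include <cassert>
#include <iostream>

void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
    size_t i = 0;
    promptCtx.n_past = 0;
    while (i < promptCtx.tokens.size()) {
        size_t batch_end = std::min(i + size_t(promptCtx.n_batch), promptCtx.tokens.size());
        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
        if (!evalTokens(promptCtx, batch)) {
            std::cerr << "ERROR: Failed to process prompt\n";
            goto stop_generating;
        }
        promptCtx.n_past += batch.size();
        if (!recalculate(true))     // caller-supplied callback; returning false aborts
            goto stop_generating;
        i = batch_end;
    }
    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));

stop_generating:
    recalculate(false);             // signal that recalculation has finished
}

Because evalTokens reports failure as a plain bool, the error handling and the goto-based early exit from the original body carry over unchanged; only the backend-specific llama_eval call is abstracted away.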