mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-09 04:20:42 +00:00)
Trying to shrink the copy+paste code and do more code sharing between backend model implementations.
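The refactor replaces each backend's direct eval call (gptj_eval here) with a virtual evalTokens hook, so loops like recalculateContext can be written once against the hook instead of being copy-pasted into every model. Below is a minimal sketch of that pattern; the base-class name LLModel and the PromptContext member layout are assumptions for illustration, while the loop body mirrors the one deleted from gptj.cpp in the third hunk:

    // Sketch of the sharing pattern this commit moves toward. The class name
    // LLModel and the exact member layout are assumptions, not part of this diff.
    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <vector>

    class LLModel {
    public:
        struct PromptContext {
            std::vector<float> logits;     // model output from the last eval
            std::vector<int32_t> tokens;   // tokens evaluated so far
            int32_t n_past = 0;            // tokens already in the KV cache
            int32_t n_ctx = 2048;          // context window size
            int32_t n_batch = 8;           // tokens fed per eval call
        };

        virtual ~LLModel() = default;

        // Shared re-evaluation loop, written once against evalTokens. This is
        // essentially the body deleted from gptj.cpp below, with the direct
        // gptj_eval call swapped for the virtual hook.
        void recalculateContext(PromptContext &promptCtx,
                                std::function<bool(bool)> recalculate)
        {
            size_t i = 0;
            promptCtx.n_past = 0;
            while (i < promptCtx.tokens.size()) {
                size_t batch_end = std::min(i + promptCtx.n_batch,
                                            promptCtx.tokens.size());
                std::vector<int32_t> batch(promptCtx.tokens.begin() + i,
                                           promptCtx.tokens.begin() + batch_end);

                assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);

                if (!evalTokens(promptCtx, batch)) {
                    std::cerr << "ERROR: Failed to process prompt\n";
                    goto stop_generating;
                }
                promptCtx.n_past += batch.size();
                if (!recalculate(true))
                    goto stop_generating;
                i = batch_end;
            }
            assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));

        stop_generating:
            recalculate(false);
        }

    protected:
        // The only piece each backend (GPT-J, LLaMA, MPT, ...) must implement.
        virtual bool evalTokens(PromptContext &ctx,
                                const std::vector<int32_t> &tokens) = 0;
    };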
@@ -944,8 +944,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -995,8 +994,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
-            d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -1042,30 +1040,9 @@ void GPTJ::prompt(const std::string &prompt,
     }
 }
 
-void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool GPTJ::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
-            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return gptj_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token);
 }
 
 #if defined(_WIN32)
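With the shared loop hoisted out, a backend's entire eval surface shrinks to the one-line evalTokens override, exactly as the new GPTJ::evalTokens above forwards to gptj_eval. A toy, compilable illustration against the LLModel sketch near the top; ToyModel and fake_eval are hypothetical stand-ins, not part of this commit:

    // Hypothetical backend built on the LLModel sketch above; fake_eval
    // stands in for a real backend eval such as gptj_eval.
    static bool fake_eval(int32_t /*n_past*/, const std::vector<int32_t> &tokens,
                          std::vector<float> &logits)
    {
        logits.assign(tokens.size(), 0.0f); // pretend we produced logits
        return true;                        // report success
    }

    class ToyModel : public LLModel {
    protected:
        bool evalTokens(PromptContext &ctx,
                        const std::vector<int32_t> &tokens) override
        {
            // The whole per-backend surface: forward to the native eval call.
            return fake_eval(ctx.n_past, tokens, ctx.logits);
        }
    };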