Repository: https://github.com/nomic-ai/gpt4all.git
Commit: a41bd6ac0a (parent: 031d7149a7)

Trying to shrink the copy+paste code and do more code sharing between backend model impl.
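At a high level, this change hoists the duplicated context-recalculation loop out of the GPTJ, LLamaModel, and MPT backends into the shared LLModel base class, which now drives a new virtual evalTokens() hook; each backend only supplies a thin wrapper around its own eval call. Below is a minimal compilable sketch of the resulting shape, assuming simplified PromptContext fields (only the ones the loop touches) and with the GPT-J wrapper body shown as a comment, as it appears later in the diff:

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Reduced sketch of PromptContext; the real struct carries more state.
    struct PromptContext {
        std::vector<int32_t> tokens;   // tokens already fed to the model
        std::vector<float> logits;     // logits produced by the last eval
        int32_t n_past = 0;            // how many tokens have been evaluated
        int32_t n_ctx = 0;             // size of the context window
        size_t n_batch = 1;            // batch size used when replaying tokens
    };

    class LLModel {
    public:
        virtual ~LLModel() = default;
        // Per-backend hook: evaluate one batch of tokens at position ctx.n_past.
        virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) = 0;
    protected:
        // Shared by every backend: replays ctx.tokens through evalTokens in batches.
        void recalculateContext(PromptContext &ctx, std::function<bool(bool)> recalculate);
    };

    class GPTJ : public LLModel {
    public:
        bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
        // Defined in the .cpp as a one-liner (see the hunk below):
        // { return gptj_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens,
        //                    ctx.logits, d_ptr->mem_per_token); }
    };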
@@ -944,8 +944,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -995,8 +994,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
-            d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -1042,30 +1040,9 @@ void GPTJ::prompt(const std::string &prompt,
     }
 }
 
-void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool GPTJ::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
-            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return gptj_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token);
 }
 
 #if defined(_WIN32)
@@ -25,13 +25,10 @@ public:
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 std::function<bool(bool)> recalculateCallback,
                 PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override;
-
 private:
     GPTJPrivate *d_ptr;
 };
@@ -216,7 +216,7 @@ void LLamaModel::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "LLAMA ERROR: Failed to process prompt\n";
             return;
         }
@@ -258,7 +258,7 @@ void LLamaModel::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "LLAMA ERROR: Failed to predict next token\n";
             return;
         }
@@ -305,29 +305,9 @@ void LLamaModel::prompt(const std::string &prompt,
     }
 }
 
-void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
-            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
 }
 
 #if defined(_WIN32)
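One detail worth noting in the LLaMA wrapper above: gptj_eval and mpt_eval return a bool that is true on success, while llama_eval follows the C convention of returning 0 on success, so its wrapper normalizes the result with `== 0`. A small self-contained illustration of that normalization, using made-up stand-in functions (fake_llama_eval and fake_gptj_eval are not part of the diff, they only mimic the two return conventions):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Stand-in for llama_eval's convention: returns 0 on success, nonzero on error.
    static int fake_llama_eval(const int32_t * /*tokens*/, int /*n*/) { return 0; }

    // Stand-in for gptj_eval/mpt_eval's convention: returns true on success.
    static bool fake_gptj_eval(const std::vector<int32_t> & /*tokens*/) { return true; }

    int main() {
        std::vector<int32_t> batch{1, 2, 3};
        // evalTokens() must report true on success for every backend, so each
        // wrapper adapts its library's convention accordingly.
        bool ok_llama = fake_llama_eval(batch.data(), int(batch.size())) == 0;  // int 0 -> true
        bool ok_gptj  = fake_gptj_eval(batch);                                  // already a bool
        std::printf("llama ok=%d, gptj ok=%d\n", ok_llama, ok_gptj);
        return 0;
    }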
@@ -25,13 +25,10 @@ public:
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 std::function<bool(bool)> recalculateCallback,
                 PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override;
-
 private:
     LLamaPrivate *d_ptr;
 };
@@ -1,6 +1,7 @@
 #include "llmodel.h"
 #include "dlhandle.h"
 
+#include <iostream>
 #include <string>
 #include <vector>
 #include <fstream>
@@ -95,6 +96,28 @@ const LLModel::Implementation* LLModel::implementation(std::ifstream& f, const s
     return nullptr;
 }
 
+void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
+        if (!evalTokens(promptCtx, batch)) {
+            std::cerr << "LLModel ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
+
+stop_generating:
+    recalculate(false);
+}
+
 LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) {
     //TODO: Auto-detect CUDA/OpenCL
     if (buildVariant == "auto") {
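The shared loop added above replays the saved tokens through the backend's evalTokens() in n_batch-sized chunks whenever the context has to be rebuilt, calling the recalculate callback between batches so the caller can report progress or abort. A small self-contained sketch of just the batching arithmetic, with illustrative values (20 saved tokens and n_batch = 9 are only example numbers; the real ones come from PromptContext):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int32_t> tokens(20);        // 20 previously evaluated tokens (illustrative)
        const size_t n_batch = 9;               // illustrative batch size
        int32_t n_past = 0;

        size_t i = 0;
        while (i < tokens.size()) {
            size_t batch_end = std::min(i + n_batch, tokens.size());
            size_t batch_size = batch_end - i;  // the slice handed to evalTokens()
            std::printf("evalTokens: %zu tokens at n_past=%d\n", batch_size, n_past);
            n_past += int32_t(batch_size);      // mirrors promptCtx.n_past += batch.size()
            i = batch_end;
        }
        std::printf("done, n_past=%d\n", n_past);  // prints 20: batches of 9, 9, and 2
        return 0;
    }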
@@ -64,6 +64,7 @@ public:
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 std::function<bool(bool)> recalculateCallback,
                 PromptContext &ctx) = 0;
+    virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) = 0;
     virtual void setThreadCount(int32_t /*n_threads*/) {}
     virtual int32_t threadCount() const { return 1; }
 
@@ -78,7 +79,6 @@ public:
 protected:
     const Implementation *m_implementation = nullptr;
 
-    virtual void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) = 0;
+    void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
 };
 #endif // LLMODEL_H
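With recalculateContext now implemented once in LLModel, a new backend after this change only has to override evalTokens for its eval path, alongside the existing prompt and thread-count methods. A hedged sketch of what such a backend's header could look like, using made-up names (NewBackend and NewBackendPrivate are hypothetical and not part of the diff; other pure-virtual members of LLModel, such as prompt(), are omitted here for brevity):

    // Hypothetical header for a new backend under the refactored interface.
    #include <cstdint>
    #include <vector>
    #include "llmodel.h"

    struct NewBackendPrivate;   // pimpl, in the style of GPTJPrivate/LLamaPrivate

    class NewBackend : public LLModel {
    public:
        // The only eval-related hook a backend still has to provide; the shared
        // LLModel::recalculateContext() calls it when replaying saved tokens.
        bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
        void setThreadCount(int32_t n_threads) override;
        int32_t threadCount() const override;

    private:
        NewBackendPrivate *d_ptr;
    };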
@@ -869,8 +869,7 @@ void MPT::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -920,8 +919,7 @@ void MPT::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
-            d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -971,30 +969,9 @@ void MPT::prompt(const std::string &prompt,
     }
 }
 
-void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool MPT::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<int> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
-            std::cerr << "MPT ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return mpt_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token);
 }
 
 #if defined(_WIN32)
@@ -25,13 +25,10 @@ public:
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 std::function<bool(bool)> recalculateCallback,
                 PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override;
-
 private:
     MPTPrivate *d_ptr;
 };
@@ -24,6 +24,7 @@ public:
                 std::function<bool(int32_t, const std::string&)> responseCallback,
                 std::function<bool(bool)> recalculateCallback,
                 PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override { return true; }
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
@@ -33,10 +34,6 @@ public:
     QList<QString> context() const { return m_context; }
     void setContext(const QList<QString> &context) { m_context = context; }
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override {}
-
 private Q_SLOTS:
     void handleFinished();
     void handleReadyRead();
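The hunks above are from the remote ChatGPT backend, which has no local tokens to evaluate: it satisfies the new pure-virtual hook with an inline stub that always reports success (`evalTokens(...) override { return true; }`), and its empty recalculateContext override becomes unnecessary once the base class owns that method. In effect, backends that do no local evaluation opt out of the shared replay loop simply by making evalTokens a successful no-op.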