typescript: async generator and token stream (#1897)

Signed-off-by: Tare Ebelo <75279482+TareHimself@users.noreply.github.com>
Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: jacob <jacoobes@sern.dev>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
TareHimself
2024-02-24 17:50:14 -05:00
committed by GitHub
parent ef518fae3e
commit a153cc5b25
14 changed files with 1517 additions and 955 deletions

View File

@@ -1,60 +1,146 @@
#include "prompt.h"
#include <future>
PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) {
if(_config.bHasTokenCallback){
_tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this);
}
}
TsfnContext::TsfnContext(Napi::Env env, const PromptWorkContext& pc)
: deferred_(Napi::Promise::Deferred::New(env)), pc(pc) {
}
namespace {
static std::string *res;
}
PromptWorker::~PromptWorker()
{
if(_config.bHasTokenCallback){
_tsfn.Release();
}
}
bool response_callback(int32_t token_id, const char *response) {
*res += response;
return token_id != -1;
}
bool recalculate_callback (bool isrecalculating) {
return isrecalculating;
};
bool prompt_callback (int32_t tid) {
return true;
};
void PromptWorker::Execute()
{
_config.mutex->lock();
// The thread entry point. This takes as its arguments the specific
// threadsafe-function context created inside the main thread.
void threadEntry(TsfnContext* context) {
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
res = &context->pc.res;
// Perform a call into JavaScript.
napi_status status =
context->tsfn.BlockingCall(&context->pc,
[](Napi::Env env, Napi::Function jsCallback, PromptWorkContext* pc) {
llmodel_prompt(
pc->inference_,
pc->question.c_str(),
&prompt_callback,
&response_callback,
&recalculate_callback,
&pc->prompt_params
);
});
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
if (status != napi_ok) {
Napi::Error::Fatal(
"ThreadEntry",
"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
// Release the thread-safe function. This decrements the internal thread
// count, and will perform finalization since the count will reach 0.
context->tsfn.Release();
}
auto ctx = &_config.context;
void FinalizerCallback(Napi::Env env,
void* finalizeData,
TsfnContext* context) {
// Resolve the Promise previously returned to JS
context->deferred_.Resolve(Napi::String::New(env, context->pc.res));
// Wait for the thread to finish executing before proceeding.
context->nativeThread.join();
delete context;
}
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
wrapper->promptContext.tokens.resize(ctx->n_past);
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Napi::Error::Fatal(
// "SUPRA",
// "About to prompt");
// Call the C++ prompt method
wrapper->llModel->prompt(
_config.prompt,
[](int32_t tid) { return true; },
[this](int32_t token_id, const std::string tok)
{
return ResponseCallback(token_id, tok);
},
[](bool isRecalculating)
{
return isRecalculating;
},
wrapper->promptContext);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->logits = wrapper->promptContext.logits.data();
ctx->logits_size = wrapper->promptContext.logits.size();
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock();
}
void PromptWorker::OnOK()
{
promise.Resolve(Napi::String::New(Env(), result));
}
void PromptWorker::OnError(const Napi::Error &e)
{
promise.Reject(e.Value());
}
Napi::Promise PromptWorker::GetPromise()
{
return promise.Promise();
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
if (token_id == -1)
{
return false;
}
if(!_config.bHasTokenCallback){
return true;
}
result += token;
std::promise<bool> promise;
auto info = new TokenCallbackInfo();
info->tokenId = token_id;
info->token = token;
info->total = result;
auto future = promise.get_future();
auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::String::New(env, value->token);
auto total = Napi::String::New(env,value->total);
auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
promise.set_value(jsResult);
// We're finished with the data.
delete value;
});
if (status != napi_ok) {
Napi::Error::Fatal(
"PromptWorkerResponseCallback",
"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
return isRecalculating;
}
bool PromptWorker::PromptCallback(int32_t tid)
{
return true;
}