mirror of
				https://github.com/nomic-ai/gpt4all.git
				synced 2025-10-31 22:02:53 +00:00 
			
		
		
		
	* chore: boilerplate, refactor in future * chore: boilerplate * feat: can compile successfully * document .gyp file * add src, test and fix gyp * progress on prompting and some helper methods * add destructor and basic prompting work, prepare download function * download function done * download function edits and adding documentation * fix bindings memory issue and add tests and specs * add more documentation and readme * add npmignore * Update README.md Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com> * Update package.json - redundant scripts Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com> --------- Signed-off-by: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com>
		
			
				
	
	
		
			228 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			228 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <string>

#include <napi.h>

#include "llmodel_c.h"
#include "llmodel.h"
#include "gptj.h"
#include "llamamodel.h"
#include "mpt.h"
#include "stdcapture.h"
 | |
| 
 | |
| class NodeModelWrapper : public Napi::ObjectWrap<NodeModelWrapper> {
 | |
| public:
 | |
|   static Napi::Object Init(Napi::Env env, Napi::Object exports) {
 | |
|     Napi::Function func = DefineClass(env, "LLModel", {
 | |
|       InstanceMethod("type",  &NodeModelWrapper::getType),
 | |
|       InstanceMethod("name", &NodeModelWrapper::getName),
 | |
|       InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
 | |
|       InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt),
 | |
|       InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
 | |
|       InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount),
 | |
|     });
 | |
| 
 | |
|     Napi::FunctionReference* constructor = new Napi::FunctionReference();
 | |
|     *constructor = Napi::Persistent(func);
 | |
|     env.SetInstanceData(constructor);
 | |
| 
 | |
|     exports.Set("LLModel", func);
 | |
|     return exports;
 | |
|   }
 | |
| 
 | |
|   Napi::Value getType(const Napi::CallbackInfo& info) 
 | |
|   {
 | |
|     return Napi::String::New(info.Env(), type);
 | |
|   }
 | |
| 
 | |
|   NodeModelWrapper(const Napi::CallbackInfo& info) : Napi::ObjectWrap<NodeModelWrapper>(info) 
 | |
|   {
 | |
|     auto env = info.Env();
 | |
|     std::string weights_path = info[0].As<Napi::String>().Utf8Value();
 | |
| 
 | |
|     const char *c_weights_path = weights_path.c_str();
 | |
|     
 | |
|     inference_ = create_model_set_type(c_weights_path);
 | |
| 
 | |
|     auto success = llmodel_loadModel(inference_, c_weights_path);
 | |
|     if(!success) {
 | |
|         Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException(); 
 | |
|         return;
 | |
|     }
 | |
|     name = weights_path.substr(weights_path.find_last_of("/\\") + 1);
 | |
|     
 | |
|   };
 | |
|   ~NodeModelWrapper() {
 | |
|     // destroying the model manually causes exit code 3221226505, why?
 | |
|     // However, bindings seem to operate fine without destructing pointer
 | |
|     //llmodel_model_destroy(inference_);
 | |
|   }
 | |
| 
 | |
|   Napi::Value IsModelLoaded(const Napi::CallbackInfo& info) {
 | |
|     return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(inference_));
 | |
|   }
 | |
| 
 | |
|   Napi::Value StateSize(const Napi::CallbackInfo& info) {
 | |
|     // Implement the binding for the stateSize method
 | |
|     return Napi::Number::New(info.Env(), static_cast<int64_t>(llmodel_get_state_size(inference_)));
 | |
|   }
 | |
| 
 | |
| /**
 | |
|  * Generate a response using the model.
 | |
|  * @param model A pointer to the llmodel_model instance.
 | |
|  * @param prompt A string representing the input prompt.
 | |
|  * @param prompt_callback A callback function for handling the processing of prompt.
 | |
|  * @param response_callback A callback function for handling the generated response.
 | |
|  * @param recalculate_callback A callback function for handling recalculation requests.
 | |
|  * @param ctx A pointer to the llmodel_prompt_context structure.
 | |
|  */
 | |
|   Napi::Value Prompt(const Napi::CallbackInfo& info) {
 | |
| 
 | |
|     auto env = info.Env();
 | |
| 
 | |
|     std::string question;
 | |
|     if(info[0].IsString()) {
 | |
|         question = info[0].As<Napi::String>().Utf8Value();
 | |
|     } else {
 | |
|         Napi::Error::New(env, "invalid string argument").ThrowAsJavaScriptException();
 | |
|         return env.Undefined();
 | |
|     }
 | |
|     //defaults copied from python bindings
 | |
|     llmodel_prompt_context promptContext = {
 | |
|              .logits = nullptr,
 | |
|              .tokens = nullptr,
 | |
|              .n_past = 0,
 | |
|              .n_ctx = 1024,
 | |
|              .n_predict = 128,
 | |
|              .top_k = 40,
 | |
|              .top_p = 0.9f,
 | |
|              .temp = 0.72f,
 | |
|              .n_batch = 8,
 | |
|              .repeat_penalty = 1.0f,
 | |
|              .repeat_last_n = 10,
 | |
|              .context_erase = 0.5
 | |
|          };
 | |
|     if(info[1].IsObject())
 | |
|     {
 | |
|         auto inputObject = info[1].As<Napi::Object>();
 | |
|              
 | |
|         // Extract and assign the properties
 | |
|         if (inputObject.Has("logits") || inputObject.Has("tokens")) {
 | |
|             Napi::Error::New(env, "Invalid input: 'logits' or 'tokens' properties are not allowed").ThrowAsJavaScriptException();
 | |
|             return env.Undefined();
 | |
|         }
 | |
|              // Assign the remaining properties
 | |
|              if(inputObject.Has("n_past")) {
 | |
|                  promptContext.n_past = inputObject.Get("n_past").As<Napi::Number>().Int32Value();
 | |
|              }
 | |
|              if(inputObject.Has("n_ctx")) {
 | |
|                  promptContext.n_ctx = inputObject.Get("n_ctx").As<Napi::Number>().Int32Value();
 | |
|              }
 | |
|              if(inputObject.Has("n_predict")) {
 | |
|                  promptContext.n_predict = inputObject.Get("n_predict").As<Napi::Number>().Int32Value();
 | |
|              }
 | |
|              if(inputObject.Has("top_k")) {
 | |
|                  promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value();
 | |
|              }
 | |
|              if(inputObject.Has("top_p")) {
 | |
|                  promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue();
 | |
|              }
 | |
|              if(inputObject.Has("temp")) {
 | |
|                  promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
 | |
|              }
 | |
|              if(inputObject.Has("n_batch")) {
 | |
|                  promptContext.n_batch = inputObject.Get("n_batch").As<Napi::Number>().Int32Value();
 | |
|              }
 | |
|              if(inputObject.Has("repeat_penalty")) {
 | |
|                  promptContext.repeat_penalty = inputObject.Get("repeat_penalty").As<Napi::Number>().FloatValue();
 | |
|              }
 | |
|              if(inputObject.Has("repeat_last_n")) {
 | |
|                  promptContext.repeat_last_n = inputObject.Get("repeat_last_n").As<Napi::Number>().Int32Value();
 | |
|              }
 | |
|              if(inputObject.Has("context_erase")) {
 | |
|                  promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
 | |
|              }
 | |
|     }
 | |
|     //    custom callbacks are weird with the gpt4all c bindings: I need to turn Napi::Functions into  raw c function pointers,
 | |
|     //    but it doesn't seem like its possible? (TODO, is it possible?)
 | |
| 
 | |
|     //    if(info[1].IsFunction()) {
 | |
|     //        Napi::Callback cb = *info[1].As<Napi::Function>();
 | |
|     //    }
 | |
| 
 | |
| 
 | |
|     // For now, simple capture of stdout
 | |
|     // possible TODO: put this on a libuv async thread. (AsyncWorker)
 | |
|     CoutRedirect cr;
 | |
|     llmodel_prompt(inference_, question.c_str(), &prompt_callback, &response_callback, &recalculate_callback,  &promptContext);
 | |
|     return Napi::String::New(env, cr.getString());
 | |
|   }
 | |
| 
 | |
|   void SetThreadCount(const Napi::CallbackInfo& info) {
 | |
|     if(info[0].IsNumber()) {
 | |
|         llmodel_setThreadCount(inference_, info[0].As<Napi::Number>().Int64Value());
 | |
|     } else {
 | |
|         Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException(); 
 | |
|         return;
 | |
|     }
 | |
|   }
 | |
|   Napi::Value getName(const Napi::CallbackInfo& info) {
 | |
|     return Napi::String::New(info.Env(), name);
 | |
|   }
 | |
|   Napi::Value ThreadCount(const Napi::CallbackInfo& info) {
 | |
|     return Napi::Number::New(info.Env(), llmodel_threadCount(inference_));
 | |
|   }
 | |
| 
 | |
| private:
 | |
|   llmodel_model inference_;
 | |
|   std::string type;
 | |
|   std::string name;
 | |
| 
 | |
| 
 | |
|   //wrapper cb to capture output into stdout.then, CoutRedirect captures this 
 | |
|   // and writes it to a file
 | |
|   static bool response_callback(int32_t tid, const char* resp) 
 | |
|   {
 | |
|     if(tid != -1) {
 | |
|         std::cout<<std::string(resp);
 | |
|         return true;
 | |
|     }
 | |
|     return false;
 | |
|   }
 | |
| 
 | |
|   static bool prompt_callback(int32_t tid) { return true; }
 | |
|   static bool recalculate_callback(bool isrecalculating) { return  isrecalculating; }
 | |
|   // Had to use this instead of the c library in order 
 | |
|   // set the type of the model loaded.
 | |
|   // causes side effect: type is mutated;
 | |
|   llmodel_model create_model_set_type(const char* c_weights_path) 
 | |
|   {
 | |
| 
 | |
|     uint32_t magic;
 | |
|     llmodel_model model;
 | |
|     FILE *f = fopen(c_weights_path, "rb");
 | |
|     fread(&magic, sizeof(magic), 1, f);
 | |
| 
 | |
|     if (magic == 0x67676d6c) {
 | |
|         model = llmodel_gptj_create();  
 | |
|         type = "gptj";
 | |
|     }
 | |
|     else if (magic == 0x67676a74) {
 | |
|         model = llmodel_llama_create(); 
 | |
|         type = "llama";
 | |
|     }
 | |
|     else if (magic == 0x67676d6d) {
 | |
|         model = llmodel_mpt_create();   
 | |
|         type = "mpt";
 | |
|     }
 | |
|     else  {fprintf(stderr, "Invalid model file\n");}
 | |
|     fclose(f);
 | |
|     
 | |
|     return model;
 | |
|   }
 | |
| };
 | |
| 
 | |
| //Exports Bindings
 | |
| Napi::Object Init(Napi::Env env, Napi::Object exports) {
 | |
|   return NodeModelWrapper::Init(env, exports);
 | |
| }
 | |
| 
 | |
| NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init)
 |