typescript!: chatSessions, fixes, tokenStreams (#2045)

Signed-off-by: jacob <jacoobes@sern.dev>
Signed-off-by: limez <limez@protonmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: limez <limez@protonmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jacob Nguyen
2024-03-28 11:08:23 -05:00
committed by GitHub
parent 6c8a44f6c4
commit 55f3b056b7
33 changed files with 2573 additions and 1349 deletions

View File

@@ -11,37 +11,116 @@ pnpm install gpt4all@latest
``` ```
The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date. ## Contents
* New bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
* The nodejs api has made strides to mirror the python api. It is not 100% mirrored, but many pieces of the api resemble its python counterpart.
* Everything should work out the box.
* See [API Reference](#api-reference) * See [API Reference](#api-reference)
* See [Examples](#api-example)
* See [Developing](#develop)
* GPT4ALL nodejs bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
## Api Example
### Chat Completion ### Chat Completion
```js ```js
import { createCompletion, loadModel } from '../src/gpt4all.js' import { LLModel, createCompletion, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, loadModel } from '../src/gpt4all.js'
const model = await loadModel('mistral-7b-openorca.Q4_0.gguf', { verbose: true }); const model = await loadModel( 'mistral-7b-openorca.gguf2.Q4_0.gguf', { verbose: true, device: 'gpu' });
const response = await createCompletion(model, [ const completion1 = await createCompletion(model, 'What is 1 + 1?', { verbose: true, })
{ role : 'system', content: 'You are meant to be annoying and unhelpful.' }, console.log(completion1.message)
{ role : 'user', content: 'What is 1 + 1?' }
]);
const completion2 = await createCompletion(model, 'And if we add two?', { verbose: true })
console.log(completion2.message)
model.dispose()
``` ```
### Embedding ### Embedding
```js ```js
import { createEmbedding, loadModel } from '../src/gpt4all.js' import { loadModel, createEmbedding } from '../src/gpt4all.js'
const model = await loadModel('ggml-all-MiniLM-L6-v2-f16', { verbose: true }); const embedder = await loadModel("all-MiniLM-L6-v2-f16.gguf", { verbose: true, type: 'embedding'})
const fltArray = createEmbedding(model, "Pain is inevitable, suffering optional"); console.log(createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way"));
``` ```
### Chat Sessions
```js
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession();
await createCompletion(
chat,
"Why are bananas rather blue than bread at night sometimes?",
{
verbose: true,
}
);
await createCompletion(chat, "Are you sure?", { verbose: true, });
```
### Streaming responses
```js
import gpt from "../src/gpt4all.js";
const model = await gpt.loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("### Stream:");
const stream = gpt.createCompletionStream(model, "How are you?");
stream.tokens.on("data", (data) => {
process.stdout.write(data);
});
//wait till stream finishes. We cannot continue until this one is done.
await stream.result;
process.stdout.write("\n");
process.stdout.write("### Stream with pipe:");
const stream2 = gpt.createCompletionStream(
model,
"Please say something nice about node streams."
);
stream2.tokens.pipe(process.stdout);
await stream2.result;
process.stdout.write("\n");
console.log("done");
model.dispose();
```
### Async Generators
```js
import gpt from "../src/gpt4all.js";
const model = await gpt.loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("### Generator:");
const gen = gpt.createCompletionGenerator(model, "Redstone in Minecraft is Turing Complete. Let that sink in. (let it in!)");
for await (const chunk of gen) {
process.stdout.write(chunk);
}
process.stdout.write("\n");
model.dispose();
```
## Develop
### Build Instructions ### Build Instructions
* binding.gyp is compile config * binding.gyp is compile config
@@ -131,21 +210,27 @@ yarn test
* why your model may be spewing bull 💩 * why your model may be spewing bull 💩
* The downloaded model is broken (just reinstall or download from official site) * The downloaded model is broken (just reinstall or download from official site)
* That's it so far * Your model is hanging after a call to generate tokens.
* Is `nPast` set too high? This may cause your model to hang (03/16/2024), Linux Mint, Ubuntu 22.04
* Your GPU usage is still high after node.js exits.
* Make sure to call `model.dispose()`!!!
### Roadmap ### Roadmap
This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list: This package has been stabilizing over time development, and breaking changes may happen until the api stabilizes. Here's what's the todo list:
* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well.
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[x] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs * \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete * \[x] generateTokens is the new name for this^
* \[x] proper unit testing (integrate with circle ci) * \[x] proper unit testing (integrate with circle ci)
* \[x] publish to npm under alpha tag `gpt4all@alpha` * \[x] publish to npm under alpha tag `gpt4all@alpha`
* \[x] have more people test on other platforms (mac tester needed) * \[x] have more people test on other platforms (mac tester needed)
* \[x] switch to new pluggable backend * \[x] switch to new pluggable backend
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session )
### API Reference ### API Reference
@@ -153,144 +238,200 @@ This package is in active development, and breaking changes may happen until the
##### Table of Contents ##### Table of Contents
* [ModelFile](#modelfile)
* [gptj](#gptj)
* [llama](#llama)
* [mpt](#mpt)
* [replit](#replit)
* [type](#type) * [type](#type)
* [TokenCallback](#tokencallback) * [TokenCallback](#tokencallback)
* [ChatSessionOptions](#chatsessionoptions)
* [systemPrompt](#systemprompt)
* [messages](#messages)
* [initialize](#initialize)
* [Parameters](#parameters)
* [generate](#generate)
* [Parameters](#parameters-1)
* [InferenceModel](#inferencemodel) * [InferenceModel](#inferencemodel)
* [createChatSession](#createchatsession)
* [Parameters](#parameters-2)
* [generate](#generate-1)
* [Parameters](#parameters-3)
* [dispose](#dispose) * [dispose](#dispose)
* [EmbeddingModel](#embeddingmodel) * [EmbeddingModel](#embeddingmodel)
* [dispose](#dispose-1) * [dispose](#dispose-1)
* [InferenceResult](#inferenceresult)
* [LLModel](#llmodel) * [LLModel](#llmodel)
* [constructor](#constructor) * [constructor](#constructor)
* [Parameters](#parameters) * [Parameters](#parameters-4)
* [type](#type-1) * [type](#type-1)
* [name](#name) * [name](#name)
* [stateSize](#statesize) * [stateSize](#statesize)
* [threadCount](#threadcount) * [threadCount](#threadcount)
* [setThreadCount](#setthreadcount) * [setThreadCount](#setthreadcount)
* [Parameters](#parameters-1) * [Parameters](#parameters-5)
* [raw\_prompt](#raw_prompt) * [infer](#infer)
* [Parameters](#parameters-2) * [Parameters](#parameters-6)
* [embed](#embed) * [embed](#embed)
* [Parameters](#parameters-3) * [Parameters](#parameters-7)
* [isModelLoaded](#ismodelloaded) * [isModelLoaded](#ismodelloaded)
* [setLibraryPath](#setlibrarypath) * [setLibraryPath](#setlibrarypath)
* [Parameters](#parameters-4) * [Parameters](#parameters-8)
* [getLibraryPath](#getlibrarypath) * [getLibraryPath](#getlibrarypath)
* [initGpuByString](#initgpubystring) * [initGpuByString](#initgpubystring)
* [Parameters](#parameters-5) * [Parameters](#parameters-9)
* [hasGpuDevice](#hasgpudevice) * [hasGpuDevice](#hasgpudevice)
* [listGpu](#listgpu) * [listGpu](#listgpu)
* [Parameters](#parameters-6) * [Parameters](#parameters-10)
* [dispose](#dispose-2) * [dispose](#dispose-2)
* [GpuDevice](#gpudevice) * [GpuDevice](#gpudevice)
* [type](#type-2) * [type](#type-2)
* [LoadModelOptions](#loadmodeloptions) * [LoadModelOptions](#loadmodeloptions)
* [loadModel](#loadmodel) * [modelPath](#modelpath)
* [Parameters](#parameters-7) * [librariesPath](#librariespath)
* [createCompletion](#createcompletion) * [modelConfigFile](#modelconfigfile)
* [Parameters](#parameters-8) * [allowDownload](#allowdownload)
* [createEmbedding](#createembedding)
* [Parameters](#parameters-9)
* [CompletionOptions](#completionoptions)
* [verbose](#verbose) * [verbose](#verbose)
* [systemPromptTemplate](#systemprompttemplate) * [device](#device)
* [promptTemplate](#prompttemplate) * [nCtx](#nctx)
* [promptHeader](#promptheader) * [ngl](#ngl)
* [promptFooter](#promptfooter) * [loadModel](#loadmodel)
* [PromptMessage](#promptmessage) * [Parameters](#parameters-11)
* [InferenceProvider](#inferenceprovider)
* [createCompletion](#createcompletion)
* [Parameters](#parameters-12)
* [createCompletionStream](#createcompletionstream)
* [Parameters](#parameters-13)
* [createCompletionGenerator](#createcompletiongenerator)
* [Parameters](#parameters-14)
* [createEmbedding](#createembedding)
* [Parameters](#parameters-15)
* [CompletionOptions](#completionoptions)
* [verbose](#verbose-1)
* [onToken](#ontoken)
* [Message](#message)
* [role](#role) * [role](#role)
* [content](#content) * [content](#content)
* [prompt\_tokens](#prompt_tokens) * [prompt\_tokens](#prompt_tokens)
* [completion\_tokens](#completion_tokens) * [completion\_tokens](#completion_tokens)
* [total\_tokens](#total_tokens) * [total\_tokens](#total_tokens)
* [n\_past\_tokens](#n_past_tokens)
* [CompletionReturn](#completionreturn) * [CompletionReturn](#completionreturn)
* [model](#model) * [model](#model)
* [usage](#usage) * [usage](#usage)
* [choices](#choices) * [message](#message-1)
* [CompletionChoice](#completionchoice) * [CompletionStreamReturn](#completionstreamreturn)
* [message](#message)
* [LLModelPromptContext](#llmodelpromptcontext) * [LLModelPromptContext](#llmodelpromptcontext)
* [logitsSize](#logitssize) * [logitsSize](#logitssize)
* [tokensSize](#tokenssize) * [tokensSize](#tokenssize)
* [nPast](#npast) * [nPast](#npast)
* [nCtx](#nctx)
* [nPredict](#npredict) * [nPredict](#npredict)
* [promptTemplate](#prompttemplate)
* [nCtx](#nctx-1)
* [topK](#topk) * [topK](#topk)
* [topP](#topp) * [topP](#topp)
* [temp](#temp) * [minP](#minp)
* [temperature](#temperature)
* [nBatch](#nbatch) * [nBatch](#nbatch)
* [repeatPenalty](#repeatpenalty) * [repeatPenalty](#repeatpenalty)
* [repeatLastN](#repeatlastn) * [repeatLastN](#repeatlastn)
* [contextErase](#contexterase) * [contextErase](#contexterase)
* [generateTokens](#generatetokens)
* [Parameters](#parameters-10)
* [DEFAULT\_DIRECTORY](#default_directory) * [DEFAULT\_DIRECTORY](#default_directory)
* [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory) * [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory)
* [DEFAULT\_MODEL\_CONFIG](#default_model_config) * [DEFAULT\_MODEL\_CONFIG](#default_model_config)
* [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context) * [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context)
* [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url) * [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url)
* [downloadModel](#downloadmodel) * [downloadModel](#downloadmodel)
* [Parameters](#parameters-11) * [Parameters](#parameters-16)
* [Examples](#examples) * [Examples](#examples)
* [DownloadModelOptions](#downloadmodeloptions) * [DownloadModelOptions](#downloadmodeloptions)
* [modelPath](#modelpath) * [modelPath](#modelpath-1)
* [verbose](#verbose-1) * [verbose](#verbose-2)
* [url](#url) * [url](#url)
* [md5sum](#md5sum) * [md5sum](#md5sum)
* [DownloadController](#downloadcontroller) * [DownloadController](#downloadcontroller)
* [cancel](#cancel) * [cancel](#cancel)
* [promise](#promise) * [promise](#promise)
#### ModelFile
Full list of models available
DEPRECATED!! These model names are outdated and this type will not be maintained, please use a string literal instead
##### gptj
List of GPT-J Models
Type: (`"ggml-gpt4all-j-v1.3-groovy.bin"` | `"ggml-gpt4all-j-v1.2-jazzy.bin"` | `"ggml-gpt4all-j-v1.1-breezy.bin"` | `"ggml-gpt4all-j.bin"`)
##### llama
List Llama Models
Type: (`"ggml-gpt4all-l13b-snoozy.bin"` | `"ggml-vicuna-7b-1.1-q4_2.bin"` | `"ggml-vicuna-13b-1.1-q4_2.bin"` | `"ggml-wizardLM-7B.q4_2.bin"` | `"ggml-stable-vicuna-13B.q4_2.bin"` | `"ggml-nous-gpt4-vicuna-13b.bin"` | `"ggml-v3-13b-hermes-q5_1.bin"`)
##### mpt
List of MPT Models
Type: (`"ggml-mpt-7b-base.bin"` | `"ggml-mpt-7b-chat.bin"` | `"ggml-mpt-7b-instruct.bin"`)
##### replit
List of Replit Models
Type: `"ggml-replit-code-v1-3b.bin"`
#### type #### type
Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user. Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user.
Type: ModelType Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
#### TokenCallback #### TokenCallback
Callback for controlling token generation Callback for controlling token generation. Return false to stop token generation.
Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
#### ChatSessionOptions
**Extends Partial\<LLModelPromptContext>**
Options for the chat session.
##### systemPrompt
System prompt to ingest on initialization.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### messages
Messages to ingest on initialization.
Type: [Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[Message](#message)>
#### initialize
Ingests system prompt and initial messages.
Sets this chat session as the active chat session of the model.
##### Parameters
* `options` **[ChatSessionOptions](#chatsessionoptions)** The options for the chat session.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)\<void>**&#x20;
#### generate
Prompts the model in chat-session context.
##### Parameters
* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `options` **[CompletionOptions](#completionoptions)?** Prompt context and other options.
* `callback` **[TokenCallback](#tokencallback)?** Token generation callback.
<!---->
* Throws **[Error](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Error)** If the chat session is not the active chat session of the model.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[CompletionReturn](#completionreturn)>** The model's response to the prompt.
#### InferenceModel #### InferenceModel
InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers. InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers.
##### createChatSession
Create a chat session with the model.
###### Parameters
* `options` **[ChatSessionOptions](#chatsessionoptions)?** The options for the chat session.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)\<ChatSession>** The chat session.
##### generate
Prompts the model with a given input and optional parameters.
###### Parameters
* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)**&#x20;
* `options` **[CompletionOptions](#completionoptions)?** Prompt context and other options.
* `callback` **[TokenCallback](#tokencallback)?** Token generation callback.
* `input` The prompt input.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[CompletionReturn](#completionreturn)>** The model's response to the prompt.
##### dispose ##### dispose
delete and cleanup the native model delete and cleanup the native model
@@ -307,6 +448,10 @@ delete and cleanup the native model
Returns **void**&#x20; Returns **void**&#x20;
#### InferenceResult
Shape of LLModel's inference result.
#### LLModel #### LLModel
LLModel class representing a language model. LLModel class representing a language model.
@@ -326,9 +471,9 @@ Initialize a new LLModel.
##### type ##### type
either 'gpt', mpt', or 'llama' or undefined undefined or user supplied
Returns **(ModelType | [undefined](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/undefined))**&#x20; Returns **([string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) | [undefined](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/undefined))**&#x20;
##### name ##### name
@@ -360,7 +505,7 @@ Set the number of threads used for model inference.
Returns **void**&#x20; Returns **void**&#x20;
##### raw\_prompt ##### infer
Prompt the model with a given input and optional parameters. Prompt the model with a given input and optional parameters.
This is the raw output from model. This is the raw output from model.
@@ -368,23 +513,20 @@ Use the prompt function exported for a value
###### Parameters ###### Parameters
* `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input. * `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context. * `promptContext` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context.
* `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation. * `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation.
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt. Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[InferenceResult](#inferenceresult)>** The result of the model prompt.
##### embed ##### embed
Embed text with the model. Keep in mind that Embed text with the model. Keep in mind that
not all models can embed text, (only bert can embed as of 07/16/2023 (mm/dd/yyyy))
Use the prompt function exported for a value Use the prompt function exported for a value
###### Parameters ###### Parameters
* `text` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)**&#x20; * `text` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input.
* `q` The prompt input.
* `params` Optional parameters for the prompt context.
Returns **[Float32Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Float32Array)** The result of the model prompt. Returns **[Float32Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Float32Array)** The result of the model prompt.
@@ -462,6 +604,62 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
Options that configure a model's behavior. Options that configure a model's behavior.
##### modelPath
Where to look for model files.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### librariesPath
Where to look for the backend libraries.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### modelConfigFile
The path to the model configuration file, useful for offline usage or custom model configurations.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### allowDownload
Whether to allow downloading the model if it is not present at the specified path.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### verbose
Enable verbose logging.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### device
The processing unit on which the model will run. It can be set to
* "cpu": Model will run on the central processing unit.
* "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
* "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
* "gpu name": Model will run on the GPU that matches the name if it's available.
Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All
instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the
model.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### nCtx
The Maximum window size of this model
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### ngl
Number of gpu layers needed
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### loadModel #### loadModel
Loads a machine learning model with the specified name. The defacto way to create a model. Loads a machine learning model with the specified name. The defacto way to create a model.
@@ -474,18 +672,46 @@ By default this will download a model from the official GPT4ALL website, if a mo
Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<([InferenceModel](#inferencemodel) | [EmbeddingModel](#embeddingmodel))>** A promise that resolves to an instance of the loaded LLModel. Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<([InferenceModel](#inferencemodel) | [EmbeddingModel](#embeddingmodel))>** A promise that resolves to an instance of the loaded LLModel.
#### InferenceProvider
Interface for inference, implemented by InferenceModel and ChatSession.
#### createCompletion #### createCompletion
The nodejs equivalent to python binding's chat\_completion The nodejs equivalent to python binding's chat\_completion
##### Parameters ##### Parameters
* `model` **[InferenceModel](#inferencemodel)** The language model object. * `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session
* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation. * `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion. * `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
Returns **[CompletionReturn](#completionreturn)** The completion result. Returns **[CompletionReturn](#completionreturn)** The completion result.
#### createCompletionStream
Streaming variant of createCompletion, returns a stream of tokens and a promise that resolves to the completion result.
##### Parameters
* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session
* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
Returns **[CompletionStreamReturn](#completionstreamreturn)** An object of token stream and the completion result promise.
#### createCompletionGenerator
Creates an async generator of tokens
##### Parameters
* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session
* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
#### createEmbedding #### createEmbedding
The nodejs moral equivalent to python binding's Embed4All().embed() The nodejs moral equivalent to python binding's Embed4All().embed()
@@ -510,34 +736,15 @@ Indicates if verbose logging is enabled.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### systemPromptTemplate ##### onToken
Template for the system message. Will be put before the conversation with %1 being replaced by all system messages. Callback for controlling token generation. Return false to stop processing.
Note that if this is not defined, system messages will not be included in the prompt.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) Type: [TokenCallback](#tokencallback)
##### promptTemplate #### Message
Template for user messages, with %1 being replaced by the message. A message in the conversation.
Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean)
##### promptHeader
The initial instruction for the model, on top of the prompt
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### promptFooter
The last instruction for the model, appended to the end of the prompt.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
#### PromptMessage
A message in the conversation, identical to OpenAI's chat message.
##### role ##### role
@@ -553,7 +760,7 @@ Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
#### prompt\_tokens #### prompt\_tokens
The number of tokens used in the prompt. The number of tokens used in the prompt. Currently not available and always 0.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
@@ -565,13 +772,19 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
#### total\_tokens #### total\_tokens
The total number of tokens used. The total number of tokens used. Currently not available and always 0.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### n\_past\_tokens
Number of tokens used in the conversation.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### CompletionReturn #### CompletionReturn
The result of the completion, similar to OpenAI's format. The result of a completion.
##### model ##### model
@@ -583,23 +796,17 @@ Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
Token usage report. Token usage report.
Type: {prompt\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), completion\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), total\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)} Type: {prompt\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), completion\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), total\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), n\_past\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)}
##### choices
The generated completions.
Type: [Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[CompletionChoice](#completionchoice)>
#### CompletionChoice
A completion choice, similar to OpenAI's format.
##### message ##### message
Response message The generated completion.
Type: [PromptMessage](#promptmessage) Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
#### CompletionStreamReturn
The result of a streamed completion, containing a stream of tokens and a promise that resolves to the completion result.
#### LLModelPromptContext #### LLModelPromptContext
@@ -620,18 +827,29 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa
##### nPast ##### nPast
The number of tokens in the past conversation. The number of tokens in the past conversation.
This controls how far back the model looks when generating completions.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### nCtx
The number of tokens possible in the context window.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### nPredict ##### nPredict
The number of tokens to predict. The maximum number of tokens to predict.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### promptTemplate
Template for user / assistant message pairs.
%1 is required and will be replaced by the user input.
%2 is optional and will be replaced by the assistant response.
Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)
##### nCtx
The context window size. Do not use, it has no effect. See loadModel options.
THIS IS DEPRECATED!!!
Use loadModel's nCtx option instead.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
@@ -654,12 +872,16 @@ above a threshold P. This method, also known as nucleus sampling, finds a balanc
and quality by considering both token probabilities and the number of tokens available for sampling. and quality by considering both token probabilities and the number of tokens available for sampling.
When using a higher value for top-P (eg., 0.95), the generated text becomes more diverse. When using a higher value for top-P (eg., 0.95), the generated text becomes more diverse.
On the other hand, a lower value (eg., 0.1) produces more focused and conservative text. On the other hand, a lower value (eg., 0.1) produces more focused and conservative text.
The default value is 0.4, which is aimed to be the middle ground between focus and diversity, but
for more creative tasks a higher top-p value will be beneficial, about 0.5-0.9 is a good range for that.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### temp ##### minP
The minimum probability of a token to be considered.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
##### temperature
The temperature to adjust the model's output distribution. The temperature to adjust the model's output distribution.
Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures
@@ -704,19 +926,6 @@ The percentage of context to erase if the context window is exceeded.
Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)
#### generateTokens
Creates an async generator of tokens
##### Parameters
* `llmodel` **[InferenceModel](#inferencemodel)** The language model object.
* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation.
* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion.
* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation.
Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens
#### DEFAULT\_DIRECTORY #### DEFAULT\_DIRECTORY
From python api: From python api:
@@ -759,7 +968,7 @@ By default this downloads without waiting. use the controller returned to alter
##### Parameters ##### Parameters
* `modelName` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The model to be downloaded. * `modelName` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The model to be downloaded.
* `options` **DownloadOptions** to pass into the downloader. Default is { location: (cwd), verbose: false }. * `options` **[DownloadModelOptions](#downloadmodeloptions)** to pass into the downloader. Default is { location: (cwd), verbose: false }.
##### Examples ##### Examples

View File

@@ -0,0 +1,4 @@
---
Language: Cpp
BasedOnStyle: Microsoft
ColumnLimit: 120

View File

@@ -10,45 +10,170 @@ npm install gpt4all@latest
pnpm install gpt4all@latest pnpm install gpt4all@latest
``` ```
## Breaking changes in version 4!!
The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date. * See [Transition](#changes)
## Contents
* New bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
* The nodejs api has made strides to mirror the python api. It is not 100% mirrored, but many pieces of the api resemble its python counterpart.
* Everything should work out the box.
* See [API Reference](#api-reference) * See [API Reference](#api-reference)
* See [Examples](#api-example)
* See [Developing](#develop)
* GPT4ALL nodejs bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use.
* [spare change](https://github.com/sponsors/jacoobes) for a college student? 🤑
## Api Examples
### Chat Completion ### Chat Completion
Use a chat session to keep context between completions. This is useful for efficient back and forth conversations.
```js ```js
import { createCompletion, loadModel } from '../src/gpt4all.js' import { createCompletion, loadModel } from "../src/gpt4all.js";
const model = await loadModel('mistral-7b-openorca.Q4_0.gguf', { verbose: true }); const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true, // logs loaded model configuration
device: "gpu", // defaults to 'cpu'
nCtx: 2048, // the maximum sessions context window size.
});
const response = await createCompletion(model, [ // initialize a chat session on the model. a model instance can have only one chat session at a time.
{ role : 'system', content: 'You are meant to be annoying and unhelpful.' }, const chat = await model.createChatSession({
{ role : 'user', content: 'What is 1 + 1?' } // any completion options set here will be used as default for all completions in this chat session
temperature: 0.8,
// a custom systemPrompt can be set here. note that the template depends on the model.
// if unset, the systemPrompt that comes with the model will be used.
systemPrompt: "### System:\nYou are an advanced mathematician.\n\n",
});
// create a completion using a string as input
const res1 = await createCompletion(chat, "What is 1 + 1?");
console.debug(res1.choices[0].message);
// multiple messages can be input to the conversation at once.
// note that if the last message is not of role 'user', an empty message will be returned.
await createCompletion(chat, [
{
role: "user",
content: "What is 2 + 2?",
},
{
role: "assistant",
content: "It's 5.",
},
]); ]);
const res3 = await createCompletion(chat, "Could you recalculate that?");
console.debug(res3.choices[0].message);
model.dispose();
```
### Stateless usage
You can use the model without a chat session. This is useful for one-off completions.
```js
import { createCompletion, loadModel } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf");
// createCompletion methods can also be used on the model directly.
// context is not maintained between completions.
const res1 = await createCompletion(model, "What is 1 + 1?");
console.debug(res1.choices[0].message);
// a whole conversation can be input as well.
// note that if the last message is not of role 'user', an error will be thrown.
const res2 = await createCompletion(model, [
{
role: "user",
content: "What is 2 + 2?",
},
{
role: "assistant",
content: "It's 5.",
},
{
role: "user",
content: "Could you recalculate that?",
},
]);
console.debug(res2.choices[0].message);
``` ```
### Embedding ### Embedding
```js ```js
import { createEmbedding, loadModel } from '../src/gpt4all.js' import { loadModel, createEmbedding } from '../src/gpt4all.js'
const model = await loadModel('ggml-all-MiniLM-L6-v2-f16', { verbose: true }); const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding'})
const fltArray = createEmbedding(model, "Pain is inevitable, suffering optional"); console.log(createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way"));
``` ```
### Streaming responses
```js
import { loadModel, createCompletionStream } from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("Output: ");
const stream = createCompletionStream(model, "How are you?");
stream.tokens.on("data", (data) => {
process.stdout.write(data);
});
//wait till stream finishes. We cannot continue until this one is done.
await stream.result;
process.stdout.write("\n");
model.dispose();
```
### Async Generators
```js
import { loadModel, createCompletionGenerator } from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");
process.stdout.write("Output: ");
const gen = createCompletionGenerator(
model,
"Redstone in Minecraft is Turing Complete. Let that sink in. (let it in!)"
);
for await (const chunk of gen) {
process.stdout.write(chunk);
}
process.stdout.write("\n");
model.dispose();
```
### Offline usage
do this b4 going offline
```sh
curl -L https://gpt4all.io/models/models3.json -o ./models3.json
```
```js
import { createCompletion, loadModel } from 'gpt4all'
//make sure u downloaded the models before going offline!
const model = await loadModel('mistral-7b-openorca.gguf2.Q4_0.gguf', {
verbose: true,
device: 'gpu',
modelConfigFile: "./models3.json"
});
await createCompletion(model, 'What is 1 + 1?', { verbose: true })
model.dispose();
```
## Develop
### Build Instructions ### Build Instructions
* binding.gyp is compile config * `binding.gyp` is compile config
* Tested on Ubuntu. Everything seems to work fine * Tested on Ubuntu. Everything seems to work fine
* Tested on Windows. Everything works fine. * Tested on Windows. Everything works fine.
* Sparse testing on mac os. * Sparse testing on mac os.
* MingW works as well to build the gpt4all-backend. **HOWEVER**, this package works only with MSVC built dlls. * MingW script works to build the gpt4all-backend. We left it there just in case. **HOWEVER**, this package works only with MSVC built dlls.
### Requirements ### Requirements
@@ -76,23 +201,18 @@ cd gpt4all-bindings/typescript
* To Build and Rebuild: * To Build and Rebuild:
```sh ```sh
yarn node scripts/prebuild.js
``` ```
* llama.cpp git submodule for gpt4all can be possibly absent. If this is the case, make sure to run in llama.cpp parent directory * llama.cpp git submodule for gpt4all can be possibly absent. If this is the case, make sure to run in llama.cpp parent directory
```sh ```sh
git submodule update --init --depth 1 --recursive git submodule update --init --recursive
``` ```
```sh ```sh
yarn build:backend yarn build:backend
``` ```
This will build platform-dependent dynamic libraries, and will be located in runtimes/(platform)/native
This will build platform-dependent dynamic libraries, and will be located in runtimes/(platform)/native The only current way to use them is to put them in the current working directory of your application. That is, **WHEREVER YOU RUN YOUR NODE APPLICATION**
* llama-xxxx.dll is required.
* According to whatever model you are using, you'll need to select the proper model loader.
* For example, if you running an Mosaic MPT model, you will need to select the mpt-(buildvariant).(dynamiclibrary)
### Test ### Test
@@ -130,17 +250,20 @@ yarn test
* why your model may be spewing bull 💩 * why your model may be spewing bull 💩
* The downloaded model is broken (just reinstall or download from official site) * The downloaded model is broken (just reinstall or download from official site)
* That's it so far * Your model is hanging after a call to generate tokens.
* Is `nPast` set too high? This may cause your model to hang (03/16/2024), Linux Mint, Ubuntu 22.04
* Your GPU usage is still high after node.js exits.
* Make sure to call `model.dispose()`!!!
### Roadmap ### Roadmap
This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list: This package has been stabilizing over time development, and breaking changes may happen until the api stabilizes. Here's what's the todo list:
* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well. * \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well.
* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help) * \[ ] NPM bundle size reduction via optionalDependencies strategy (need help)
* Should include prebuilds to avoid painful node-gyp errors * Should include prebuilds to avoid painful node-gyp errors
* \[ ] createChatSession ( the python equivalent to create\_chat\_session ) * \[x] createChatSession ( the python equivalent to create\_chat\_session )
* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs! * \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs!
* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete * \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete
* \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs * \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs
* \[x] generateTokens is the new name for this^ * \[x] generateTokens is the new name for this^
@@ -149,5 +272,13 @@ This package is in active development, and breaking changes may happen until the
* \[x] have more people test on other platforms (mac tester needed) * \[x] have more people test on other platforms (mac tester needed)
* \[x] switch to new pluggable backend * \[x] switch to new pluggable backend
## Changes
This repository serves as the new bindings for nodejs users.
- If you were a user of [these bindings](https://github.com/nomic-ai/gpt4all-ts), they are outdated.
- Version 4 includes the follow breaking changes
* `createEmbedding` & `EmbeddingModel.embed()` returns an object, `EmbeddingResult`, instead of a float32array.
* Removed deprecated types `ModelType` and `ModelFile`
* Removed deprecated initiation of model by string path only
### API Reference ### API Reference

View File

@@ -6,12 +6,12 @@
"<!@(node -p \"require('node-addon-api').include\")", "<!@(node -p \"require('node-addon-api').include\")",
"gpt4all-backend", "gpt4all-backend",
], ],
"sources": [ "sources": [
# PREVIOUS VERSION: had to required the sources, but with newest changes do not need to # PREVIOUS VERSION: had to required the sources, but with newest changes do not need to
#"../../gpt4all-backend/llama.cpp/examples/common.cpp", #"../../gpt4all-backend/llama.cpp/examples/common.cpp",
#"../../gpt4all-backend/llama.cpp/ggml.c", #"../../gpt4all-backend/llama.cpp/ggml.c",
#"../../gpt4all-backend/llama.cpp/llama.cpp", #"../../gpt4all-backend/llama.cpp/llama.cpp",
# "../../gpt4all-backend/utils.cpp", # "../../gpt4all-backend/utils.cpp",
"gpt4all-backend/llmodel_c.cpp", "gpt4all-backend/llmodel_c.cpp",
"gpt4all-backend/llmodel.cpp", "gpt4all-backend/llmodel.cpp",
"prompt.cc", "prompt.cc",
@@ -40,7 +40,7 @@
"AdditionalOptions": [ "AdditionalOptions": [
"/std:c++20", "/std:c++20",
"/EHsc", "/EHsc",
], ],
}, },
}, },
}], }],

View File

@@ -6,12 +6,12 @@
"<!@(node -p \"require('node-addon-api').include\")", "<!@(node -p \"require('node-addon-api').include\")",
"../../gpt4all-backend", "../../gpt4all-backend",
], ],
"sources": [ "sources": [
# PREVIOUS VERSION: had to required the sources, but with newest changes do not need to # PREVIOUS VERSION: had to required the sources, but with newest changes do not need to
#"../../gpt4all-backend/llama.cpp/examples/common.cpp", #"../../gpt4all-backend/llama.cpp/examples/common.cpp",
#"../../gpt4all-backend/llama.cpp/ggml.c", #"../../gpt4all-backend/llama.cpp/ggml.c",
#"../../gpt4all-backend/llama.cpp/llama.cpp", #"../../gpt4all-backend/llama.cpp/llama.cpp",
# "../../gpt4all-backend/utils.cpp", # "../../gpt4all-backend/utils.cpp",
"../../gpt4all-backend/llmodel_c.cpp", "../../gpt4all-backend/llmodel_c.cpp",
"../../gpt4all-backend/llmodel.cpp", "../../gpt4all-backend/llmodel.cpp",
"prompt.cc", "prompt.cc",
@@ -40,7 +40,7 @@
"AdditionalOptions": [ "AdditionalOptions": [
"/std:c++20", "/std:c++20",
"/EHsc", "/EHsc",
], ],
}, },
}, },
}], }],

View File

@@ -1,175 +1,171 @@
#include "index.h" #include "index.h"
#include "napi.h"
Napi::Function NodeModelWrapper::GetClass(Napi::Env env)
Napi::Function NodeModelWrapper::GetClass(Napi::Env env) { {
Napi::Function self = DefineClass(env, "LLModel", { Napi::Function self = DefineClass(env, "LLModel",
InstanceMethod("type", &NodeModelWrapper::GetType), {InstanceMethod("type", &NodeModelWrapper::GetType),
InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded), InstanceMethod("isModelLoaded", &NodeModelWrapper::IsModelLoaded),
InstanceMethod("name", &NodeModelWrapper::GetName), InstanceMethod("name", &NodeModelWrapper::GetName),
InstanceMethod("stateSize", &NodeModelWrapper::StateSize), InstanceMethod("stateSize", &NodeModelWrapper::StateSize),
InstanceMethod("raw_prompt", &NodeModelWrapper::Prompt), InstanceMethod("infer", &NodeModelWrapper::Infer),
InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount), InstanceMethod("setThreadCount", &NodeModelWrapper::SetThreadCount),
InstanceMethod("embed", &NodeModelWrapper::GenerateEmbedding), InstanceMethod("embed", &NodeModelWrapper::GenerateEmbedding),
InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount), InstanceMethod("threadCount", &NodeModelWrapper::ThreadCount),
InstanceMethod("getLibraryPath", &NodeModelWrapper::GetLibraryPath), InstanceMethod("getLibraryPath", &NodeModelWrapper::GetLibraryPath),
InstanceMethod("initGpuByString", &NodeModelWrapper::InitGpuByString), InstanceMethod("initGpuByString", &NodeModelWrapper::InitGpuByString),
InstanceMethod("hasGpuDevice", &NodeModelWrapper::HasGpuDevice), InstanceMethod("hasGpuDevice", &NodeModelWrapper::HasGpuDevice),
InstanceMethod("listGpu", &NodeModelWrapper::GetGpuDevices), InstanceMethod("listGpu", &NodeModelWrapper::GetGpuDevices),
InstanceMethod("memoryNeeded", &NodeModelWrapper::GetRequiredMemory), InstanceMethod("memoryNeeded", &NodeModelWrapper::GetRequiredMemory),
InstanceMethod("dispose", &NodeModelWrapper::Dispose) InstanceMethod("dispose", &NodeModelWrapper::Dispose)});
});
// Keep a static reference to the constructor // Keep a static reference to the constructor
// //
Napi::FunctionReference* constructor = new Napi::FunctionReference(); Napi::FunctionReference *constructor = new Napi::FunctionReference();
*constructor = Napi::Persistent(self); *constructor = Napi::Persistent(self);
env.SetInstanceData(constructor); env.SetInstanceData(constructor);
return self; return self;
} }
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo &info)
{ {
auto env = info.Env(); auto env = info.Env();
return Napi::Number::New(env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) )); return Napi::Number::New(
env, static_cast<uint32_t>(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers)));
} }
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo &info)
{ {
auto env = info.Env(); auto env = info.Env();
int num_devices = 0; int num_devices = 0;
auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers); auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers);
llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices); llmodel_gpu_device *all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices);
if(all_devices == nullptr) { if (all_devices == nullptr)
Napi::Error::New( {
env, Napi::Error::New(env, "Unable to retrieve list of all GPU devices").ThrowAsJavaScriptException();
"Unable to retrieve list of all GPU devices"
).ThrowAsJavaScriptException();
return env.Undefined(); return env.Undefined();
} }
auto js_array = Napi::Array::New(env, num_devices); auto js_array = Napi::Array::New(env, num_devices);
for(int i = 0; i < num_devices; ++i) { for (int i = 0; i < num_devices; ++i)
auto gpu_device = all_devices[i]; {
/* auto gpu_device = all_devices[i];
* /*
* struct llmodel_gpu_device { *
int index = 0; * struct llmodel_gpu_device {
int type = 0; // same as VkPhysicalDeviceType int index = 0;
size_t heapSize = 0; int type = 0; // same as VkPhysicalDeviceType
const char * name; size_t heapSize = 0;
const char * vendor; const char * name;
}; const char * vendor;
* };
*/ *
Napi::Object js_gpu_device = Napi::Object::New(env); */
Napi::Object js_gpu_device = Napi::Object::New(env);
js_gpu_device["index"] = uint32_t(gpu_device.index); js_gpu_device["index"] = uint32_t(gpu_device.index);
js_gpu_device["type"] = uint32_t(gpu_device.type); js_gpu_device["type"] = uint32_t(gpu_device.type);
js_gpu_device["heapSize"] = static_cast<uint32_t>( gpu_device.heapSize ); js_gpu_device["heapSize"] = static_cast<uint32_t>(gpu_device.heapSize);
js_gpu_device["name"]= gpu_device.name; js_gpu_device["name"] = gpu_device.name;
js_gpu_device["vendor"] = gpu_device.vendor; js_gpu_device["vendor"] = gpu_device.vendor;
js_array[i] = js_gpu_device; js_array[i] = js_gpu_device;
} }
return js_array; return js_array;
} }
Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo &info)
{ {
if(type.empty()) { if (type.empty())
{
return info.Env().Undefined(); return info.Env().Undefined();
} }
return Napi::String::New(info.Env(), type); return Napi::String::New(info.Env(), type);
} }
Napi::Value NodeModelWrapper::InitGpuByString(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::InitGpuByString(const Napi::CallbackInfo &info)
{ {
auto env = info.Env(); auto env = info.Env();
size_t memory_required = static_cast<size_t>(info[0].As<Napi::Number>().Uint32Value()); size_t memory_required = static_cast<size_t>(info[0].As<Napi::Number>().Uint32Value());
std::string gpu_device_identifier = info[1].As<Napi::String>(); std::string gpu_device_identifier = info[1].As<Napi::String>();
size_t converted_value; size_t converted_value;
if(memory_required <= std::numeric_limits<size_t>::max()) { if (memory_required <= std::numeric_limits<size_t>::max())
{
converted_value = static_cast<size_t>(memory_required); converted_value = static_cast<size_t>(memory_required);
} else { }
Napi::Error::New( else
env, {
"invalid number for memory size. Exceeded bounds for memory." Napi::Error::New(env, "invalid number for memory size. Exceeded bounds for memory.")
).ThrowAsJavaScriptException(); .ThrowAsJavaScriptException();
return env.Undefined(); return env.Undefined();
} }
auto result = llmodel_gpu_init_gpu_device_by_string(GetInference(), converted_value, gpu_device_identifier.c_str()); auto result = llmodel_gpu_init_gpu_device_by_string(GetInference(), converted_value, gpu_device_identifier.c_str());
return Napi::Boolean::New(env, result); return Napi::Boolean::New(env, result);
} }
Napi::Value NodeModelWrapper::HasGpuDevice(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::HasGpuDevice(const Napi::CallbackInfo &info)
{ {
return Napi::Boolean::New(info.Env(), llmodel_has_gpu_device(GetInference())); return Napi::Boolean::New(info.Env(), llmodel_has_gpu_device(GetInference()));
} }
NodeModelWrapper::NodeModelWrapper(const Napi::CallbackInfo& info) : Napi::ObjectWrap<NodeModelWrapper>(info) NodeModelWrapper::NodeModelWrapper(const Napi::CallbackInfo &info) : Napi::ObjectWrap<NodeModelWrapper>(info)
{ {
auto env = info.Env(); auto env = info.Env();
fs::path model_path; auto config_object = info[0].As<Napi::Object>();
std::string full_weight_path, // sets the directory where models (gguf files) are to be searched
library_path = ".", llmodel_set_implementation_search_path(
model_name, config_object.Has("library_path") ? config_object.Get("library_path").As<Napi::String>().Utf8Value().c_str()
device; : ".");
if(info[0].IsString()) {
model_path = info[0].As<Napi::String>().Utf8Value();
full_weight_path = model_path.string();
std::cout << "DEPRECATION: constructor accepts object now. Check docs for more.\n";
} else {
auto config_object = info[0].As<Napi::Object>();
model_name = config_object.Get("model_name").As<Napi::String>();
model_path = config_object.Get("model_path").As<Napi::String>().Utf8Value();
if(config_object.Has("model_type")) {
type = config_object.Get("model_type").As<Napi::String>();
}
full_weight_path = (model_path / fs::path(model_name)).string();
if(config_object.Has("library_path")) {
library_path = config_object.Get("library_path").As<Napi::String>();
} else {
library_path = ".";
}
device = config_object.Get("device").As<Napi::String>();
nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value(); std::string model_name = config_object.Get("model_name").As<Napi::String>();
nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value(); fs::path model_path = config_object.Get("model_path").As<Napi::String>().Utf8Value();
} std::string full_weight_path = (model_path / fs::path(model_name)).string();
llmodel_set_implementation_search_path(library_path.c_str());
const char* e; name = model_name.empty() ? model_path.filename().string() : model_name;
full_model_path = full_weight_path;
nCtx = config_object.Get("nCtx").As<Napi::Number>().Int32Value();
nGpuLayers = config_object.Get("ngl").As<Napi::Number>().Int32Value();
const char *e;
inference_ = llmodel_model_create2(full_weight_path.c_str(), "auto", &e); inference_ = llmodel_model_create2(full_weight_path.c_str(), "auto", &e);
if(!inference_) { if (!inference_)
Napi::Error::New(env, e).ThrowAsJavaScriptException(); {
return; Napi::Error::New(env, e).ThrowAsJavaScriptException();
return;
} }
if(GetInference() == nullptr) { if (GetInference() == nullptr)
std::cerr << "Tried searching libraries in \"" << library_path << "\"" << std::endl; {
std::cerr << "Tried searching for model weight in \"" << full_weight_path << "\"" << std::endl; std::cerr << "Tried searching libraries in \"" << llmodel_get_implementation_search_path() << "\"" << std::endl;
std::cerr << "Do you have runtime libraries installed?" << std::endl; std::cerr << "Tried searching for model weight in \"" << full_weight_path << "\"" << std::endl;
Napi::Error::New(env, "Had an issue creating llmodel object, inference is null").ThrowAsJavaScriptException(); std::cerr << "Do you have runtime libraries installed?" << std::endl;
return; Napi::Error::New(env, "Had an issue creating llmodel object, inference is null").ThrowAsJavaScriptException();
return;
} }
if(device != "cpu") {
size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers); std::string device = config_object.Get("device").As<Napi::String>();
if (device != "cpu")
{
size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str()); auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str());
if(!success) { if (!success)
//https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215 {
//Haven't implemented this but it is still open to contribution // https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215
// Haven't implemented this but it is still open to contribution
std::cout << "WARNING: Failed to init GPU\n"; std::cout << "WARNING: Failed to init GPU\n";
} }
} }
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers); auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers);
if(!success) { if (!success)
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException(); {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return; return;
} }
// optional
name = model_name.empty() ? model_path.filename().string() : model_name; if (config_object.Has("model_type"))
full_model_path = full_weight_path; {
}; type = config_object.Get("model_type").As<Napi::String>();
}
};
// NodeModelWrapper::~NodeModelWrapper() { // NodeModelWrapper::~NodeModelWrapper() {
// if(GetInference() != nullptr) { // if(GetInference() != nullptr) {
@@ -182,177 +178,275 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
// if(inference_ != nullptr) { // if(inference_ != nullptr) {
// std::cout << "Debug: deleting model\n"; // std::cout << "Debug: deleting model\n";
// //
// } // }
// } // }
Napi::Value NodeModelWrapper::IsModelLoaded(const Napi::CallbackInfo& info) { Napi::Value NodeModelWrapper::IsModelLoaded(const Napi::CallbackInfo &info)
{
return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(GetInference())); return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(GetInference()));
} }
Napi::Value NodeModelWrapper::StateSize(const Napi::CallbackInfo& info) { Napi::Value NodeModelWrapper::StateSize(const Napi::CallbackInfo &info)
{
// Implement the binding for the stateSize method // Implement the binding for the stateSize method
return Napi::Number::New(info.Env(), static_cast<int64_t>(llmodel_get_state_size(GetInference()))); return Napi::Number::New(info.Env(), static_cast<int64_t>(llmodel_get_state_size(GetInference())));
} }
Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo& info) { Napi::Array ChunkedFloatPtr(float *embedding_ptr, int embedding_size, int text_len, Napi::Env const &env)
{
auto n_embd = embedding_size / text_len;
// std::cout << "Embedding size: " << embedding_size << std::endl;
// std::cout << "Text length: " << text_len << std::endl;
// std::cout << "Chunk size (n_embd): " << n_embd << std::endl;
Napi::Array result = Napi::Array::New(env, text_len);
auto count = 0;
for (int i = 0; i < embedding_size; i += n_embd)
{
int end = std::min(i + n_embd, embedding_size);
// possible bounds error?
// Constructs a container with as many elements as the range [first,last), with each element emplace-constructed
// from its corresponding element in that range, in the same order.
std::vector<float> chunk(embedding_ptr + i, embedding_ptr + end);
Napi::Float32Array fltarr = Napi::Float32Array::New(env, chunk.size());
// I know there's a way to emplace the raw float ptr into a Napi::Float32Array but idk how and
// im too scared to cause memory issues
// this is goodenough
for (int j = 0; j < chunk.size(); j++)
{
fltarr.Set(j, chunk[j]);
}
result.Set(count++, fltarr);
}
return result;
}
Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo &info)
{
auto env = info.Env(); auto env = info.Env();
std::string text = info[0].As<Napi::String>().Utf8Value();
size_t embedding_size = 0; auto prefix = info[1];
float* arr = llmodel_embedding(GetInference(), text.c_str(), &embedding_size); auto dimensionality = info[2].As<Napi::Number>().Int32Value();
if(arr == nullptr) { auto do_mean = info[3].As<Napi::Boolean>().Value();
Napi::Error::New( auto atlas = info[4].As<Napi::Boolean>().Value();
env, size_t embedding_size;
"Cannot embed. native embedder returned 'nullptr'" size_t token_count = 0;
).ThrowAsJavaScriptException();
// This procedure can maybe be optimized but its whatever, i have too many intermediary structures
std::vector<std::string> text_arr;
bool is_single_text = false;
if (info[0].IsString())
{
is_single_text = true;
text_arr.push_back(info[0].As<Napi::String>().Utf8Value());
}
else
{
auto jsarr = info[0].As<Napi::Array>();
size_t len = jsarr.Length();
text_arr.reserve(len);
for (size_t i = 0; i < len; ++i)
{
std::string str = jsarr.Get(i).As<Napi::String>().Utf8Value();
text_arr.push_back(str);
}
}
std::vector<const char *> str_ptrs;
str_ptrs.reserve(text_arr.size() + 1);
for (size_t i = 0; i < text_arr.size(); ++i)
str_ptrs.push_back(text_arr[i].c_str());
str_ptrs.push_back(nullptr);
const char *_err = nullptr;
float *embeds = llmodel_embed(GetInference(), str_ptrs.data(), &embedding_size,
prefix.IsUndefined() ? nullptr : prefix.As<Napi::String>().Utf8Value().c_str(),
dimensionality, &token_count, do_mean, atlas, &_err);
if (!embeds)
{
// i dont wanna deal with c strings lol
std::string err(_err);
Napi::Error::New(env, err == "(unknown error)" ? "Unknown error: sorry bud" : err).ThrowAsJavaScriptException();
return env.Undefined(); return env.Undefined();
} }
auto embedmat = ChunkedFloatPtr(embeds, embedding_size, text_arr.size(), env);
if(embedding_size == 0 && text.size() != 0 ) { llmodel_free_embedding(embeds);
std::cout << "Warning: embedding length 0 but input text length > 0" << std::endl; auto res = Napi::Object::New(env);
} res.Set("n_prompt_tokens", token_count);
Napi::Float32Array js_array = Napi::Float32Array::New(env, embedding_size); if(is_single_text) {
res.Set("embeddings", embedmat.Get(static_cast<uint32_t>(0)));
for (size_t i = 0; i < embedding_size; ++i) { } else {
float element = *(arr + i); res.Set("embeddings", embedmat);
js_array[i] = element;
} }
llmodel_free_embedding(arr); return res;
}
return js_array;
}
/** /**
* Generate a response using the model. * Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt. * @param prompt A string representing the input prompt.
* @param prompt_callback A callback function for handling the processing of prompt. * @param options Inference options.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param ctx A pointer to the llmodel_prompt_context structure.
*/ */
Napi::Value NodeModelWrapper::Prompt(const Napi::CallbackInfo& info) { Napi::Value NodeModelWrapper::Infer(const Napi::CallbackInfo &info)
{
auto env = info.Env(); auto env = info.Env();
std::string question; std::string prompt;
if(info[0].IsString()) { if (info[0].IsString())
question = info[0].As<Napi::String>().Utf8Value(); {
} else { prompt = info[0].As<Napi::String>().Utf8Value();
}
else
{
Napi::Error::New(info.Env(), "invalid string argument").ThrowAsJavaScriptException(); Napi::Error::New(info.Env(), "invalid string argument").ThrowAsJavaScriptException();
return info.Env().Undefined(); return info.Env().Undefined();
} }
//defaults copied from python bindings
llmodel_prompt_context promptContext = {
.logits = nullptr,
.tokens = nullptr,
.n_past = 0,
.n_ctx = 1024,
.n_predict = 128,
.top_k = 40,
.top_p = 0.9f,
.min_p = 0.0f,
.temp = 0.72f,
.n_batch = 8,
.repeat_penalty = 1.0f,
.repeat_last_n = 10,
.context_erase = 0.5
};
PromptWorkerConfig promptWorkerConfig;
if(info[1].IsObject()) if (!info[1].IsObject())
{
auto inputObject = info[1].As<Napi::Object>();
// Extract and assign the properties
if (inputObject.Has("logits") || inputObject.Has("tokens")) {
Napi::Error::New(info.Env(), "Invalid input: 'logits' or 'tokens' properties are not allowed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}
// Assign the remaining properties
if(inputObject.Has("n_past"))
promptContext.n_past = inputObject.Get("n_past").As<Napi::Number>().Int32Value();
if(inputObject.Has("n_ctx"))
promptContext.n_ctx = inputObject.Get("n_ctx").As<Napi::Number>().Int32Value();
if(inputObject.Has("n_predict"))
promptContext.n_predict = inputObject.Get("n_predict").As<Napi::Number>().Int32Value();
if(inputObject.Has("top_k"))
promptContext.top_k = inputObject.Get("top_k").As<Napi::Number>().Int32Value();
if(inputObject.Has("top_p"))
promptContext.top_p = inputObject.Get("top_p").As<Napi::Number>().FloatValue();
if(inputObject.Has("min_p"))
promptContext.min_p = inputObject.Get("min_p").As<Napi::Number>().FloatValue();
if(inputObject.Has("temp"))
promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
if(inputObject.Has("n_batch"))
promptContext.n_batch = inputObject.Get("n_batch").As<Napi::Number>().Int32Value();
if(inputObject.Has("repeat_penalty"))
promptContext.repeat_penalty = inputObject.Get("repeat_penalty").As<Napi::Number>().FloatValue();
if(inputObject.Has("repeat_last_n"))
promptContext.repeat_last_n = inputObject.Get("repeat_last_n").As<Napi::Number>().Int32Value();
if(inputObject.Has("context_erase"))
promptContext.context_erase = inputObject.Get("context_erase").As<Napi::Number>().FloatValue();
}
else
{ {
Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException(); Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException();
return info.Env().Undefined(); return info.Env().Undefined();
} }
// defaults copied from python bindings
llmodel_prompt_context promptContext = {.logits = nullptr,
.tokens = nullptr,
.n_past = 0,
.n_ctx = nCtx,
.n_predict = 4096,
.top_k = 40,
.top_p = 0.9f,
.min_p = 0.0f,
.temp = 0.1f,
.n_batch = 8,
.repeat_penalty = 1.2f,
.repeat_last_n = 10,
.context_erase = 0.75};
if(info.Length() >= 3 && info[2].IsFunction()){ PromptWorkerConfig promptWorkerConfig;
promptWorkerConfig.bHasTokenCallback = true;
promptWorkerConfig.tokenCallback = info[2].As<Napi::Function>(); auto inputObject = info[1].As<Napi::Object>();
if (inputObject.Has("logits") || inputObject.Has("tokens"))
{
Napi::Error::New(info.Env(), "Invalid input: 'logits' or 'tokens' properties are not allowed")
.ThrowAsJavaScriptException();
return info.Env().Undefined();
} }
// Assign the remaining properties
if (inputObject.Has("nPast") && inputObject.Get("nPast").IsNumber())
{
promptContext.n_past = inputObject.Get("nPast").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("nPredict") && inputObject.Get("nPredict").IsNumber())
{
promptContext.n_predict = inputObject.Get("nPredict").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("topK") && inputObject.Get("topK").IsNumber())
{
promptContext.top_k = inputObject.Get("topK").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("topP") && inputObject.Get("topP").IsNumber())
{
promptContext.top_p = inputObject.Get("topP").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("minP") && inputObject.Get("minP").IsNumber())
{
promptContext.min_p = inputObject.Get("minP").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("temp") && inputObject.Get("temp").IsNumber())
{
promptContext.temp = inputObject.Get("temp").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("nBatch") && inputObject.Get("nBatch").IsNumber())
{
promptContext.n_batch = inputObject.Get("nBatch").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("repeatPenalty") && inputObject.Get("repeatPenalty").IsNumber())
{
promptContext.repeat_penalty = inputObject.Get("repeatPenalty").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("repeatLastN") && inputObject.Get("repeatLastN").IsNumber())
{
promptContext.repeat_last_n = inputObject.Get("repeatLastN").As<Napi::Number>().Int32Value();
}
if (inputObject.Has("contextErase") && inputObject.Get("contextErase").IsNumber())
{
promptContext.context_erase = inputObject.Get("contextErase").As<Napi::Number>().FloatValue();
}
if (inputObject.Has("onPromptToken") && inputObject.Get("onPromptToken").IsFunction())
{
promptWorkerConfig.promptCallback = inputObject.Get("onPromptToken").As<Napi::Function>();
promptWorkerConfig.hasPromptCallback = true;
}
if (inputObject.Has("onResponseToken") && inputObject.Get("onResponseToken").IsFunction())
{
promptWorkerConfig.responseCallback = inputObject.Get("onResponseToken").As<Napi::Function>();
promptWorkerConfig.hasResponseCallback = true;
}
//copy to protect llmodel resources when splitting to new thread // copy to protect llmodel resources when splitting to new thread
// llmodel_prompt_context copiedPrompt = promptContext; // llmodel_prompt_context copiedPrompt = promptContext;
promptWorkerConfig.context = promptContext; promptWorkerConfig.context = promptContext;
promptWorkerConfig.model = GetInference(); promptWorkerConfig.model = GetInference();
promptWorkerConfig.mutex = &inference_mutex; promptWorkerConfig.mutex = &inference_mutex;
promptWorkerConfig.prompt = question; promptWorkerConfig.prompt = prompt;
promptWorkerConfig.result = ""; promptWorkerConfig.result = "";
promptWorkerConfig.promptTemplate = inputObject.Get("promptTemplate").As<Napi::String>();
if (inputObject.Has("special"))
{
promptWorkerConfig.special = inputObject.Get("special").As<Napi::Boolean>();
}
if (inputObject.Has("fakeReply"))
{
// this will be deleted in the worker
promptWorkerConfig.fakeReply = new std::string(inputObject.Get("fakeReply").As<Napi::String>().Utf8Value());
}
auto worker = new PromptWorker(env, promptWorkerConfig); auto worker = new PromptWorker(env, promptWorkerConfig);
worker->Queue(); worker->Queue();
return worker->GetPromise(); return worker->GetPromise();
} }
void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) { void NodeModelWrapper::Dispose(const Napi::CallbackInfo &info)
{
llmodel_model_destroy(inference_); llmodel_model_destroy(inference_);
} }
void NodeModelWrapper::SetThreadCount(const Napi::CallbackInfo& info) { void NodeModelWrapper::SetThreadCount(const Napi::CallbackInfo &info)
if(info[0].IsNumber()) { {
if (info[0].IsNumber())
{
llmodel_setThreadCount(GetInference(), info[0].As<Napi::Number>().Int64Value()); llmodel_setThreadCount(GetInference(), info[0].As<Napi::Number>().Int64Value());
} else { }
Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException(); else
{
Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException();
return; return;
} }
}
Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) {
return Napi::String::New(info.Env(), name);
}
Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) {
return Napi::Number::New(info.Env(), llmodel_threadCount(GetInference()));
}
Napi::Value NodeModelWrapper::GetLibraryPath(const Napi::CallbackInfo& info) {
return Napi::String::New(info.Env(),
llmodel_get_implementation_search_path());
}
llmodel_model NodeModelWrapper::GetInference() {
return inference_;
}
//Exports Bindings
Napi::Object Init(Napi::Env env, Napi::Object exports) {
exports["LLModel"] = NodeModelWrapper::GetClass(env);
return exports;
} }
Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo &info)
{
return Napi::String::New(info.Env(), name);
}
Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo &info)
{
return Napi::Number::New(info.Env(), llmodel_threadCount(GetInference()));
}
Napi::Value NodeModelWrapper::GetLibraryPath(const Napi::CallbackInfo &info)
{
return Napi::String::New(info.Env(), llmodel_get_implementation_search_path());
}
llmodel_model NodeModelWrapper::GetInference()
{
return inference_;
}
// Exports Bindings
Napi::Object Init(Napi::Env env, Napi::Object exports)
{
exports["LLModel"] = NodeModelWrapper::GetClass(env);
return exports;
}
NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init) NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init)

View File

@@ -1,62 +1,63 @@
#include <napi.h>
#include "llmodel.h" #include "llmodel.h"
#include <iostream> #include "llmodel_c.h"
#include "llmodel_c.h"
#include "prompt.h" #include "prompt.h"
#include <atomic> #include <atomic>
#include <memory>
#include <filesystem> #include <filesystem>
#include <set> #include <iostream>
#include <memory>
#include <mutex> #include <mutex>
#include <napi.h>
#include <set>
namespace fs = std::filesystem; namespace fs = std::filesystem;
class NodeModelWrapper : public Napi::ObjectWrap<NodeModelWrapper>
{
class NodeModelWrapper: public Napi::ObjectWrap<NodeModelWrapper> { public:
NodeModelWrapper(const Napi::CallbackInfo &);
public: // virtual ~NodeModelWrapper();
NodeModelWrapper(const Napi::CallbackInfo &); Napi::Value GetType(const Napi::CallbackInfo &info);
//virtual ~NodeModelWrapper(); Napi::Value IsModelLoaded(const Napi::CallbackInfo &info);
Napi::Value GetType(const Napi::CallbackInfo& info); Napi::Value StateSize(const Napi::CallbackInfo &info);
Napi::Value IsModelLoaded(const Napi::CallbackInfo& info); // void Finalize(Napi::Env env) override;
Napi::Value StateSize(const Napi::CallbackInfo& info); /**
//void Finalize(Napi::Env env) override; * Prompting the model. This entails spawning a new thread and adding the response tokens
/** * into a thread local string variable.
* Prompting the model. This entails spawning a new thread and adding the response tokens */
* into a thread local string variable. Napi::Value Infer(const Napi::CallbackInfo &info);
*/ void SetThreadCount(const Napi::CallbackInfo &info);
Napi::Value Prompt(const Napi::CallbackInfo& info); void Dispose(const Napi::CallbackInfo &info);
void SetThreadCount(const Napi::CallbackInfo& info); Napi::Value GetName(const Napi::CallbackInfo &info);
void Dispose(const Napi::CallbackInfo& info); Napi::Value ThreadCount(const Napi::CallbackInfo &info);
Napi::Value GetName(const Napi::CallbackInfo& info); Napi::Value GenerateEmbedding(const Napi::CallbackInfo &info);
Napi::Value ThreadCount(const Napi::CallbackInfo& info); Napi::Value HasGpuDevice(const Napi::CallbackInfo &info);
Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info); Napi::Value ListGpus(const Napi::CallbackInfo &info);
Napi::Value HasGpuDevice(const Napi::CallbackInfo& info); Napi::Value InitGpuByString(const Napi::CallbackInfo &info);
Napi::Value ListGpus(const Napi::CallbackInfo& info); Napi::Value GetRequiredMemory(const Napi::CallbackInfo &info);
Napi::Value InitGpuByString(const Napi::CallbackInfo& info); Napi::Value GetGpuDevices(const Napi::CallbackInfo &info);
Napi::Value GetRequiredMemory(const Napi::CallbackInfo& info); /*
Napi::Value GetGpuDevices(const Napi::CallbackInfo& info); * The path that is used to search for the dynamic libraries
/* */
* The path that is used to search for the dynamic libraries Napi::Value GetLibraryPath(const Napi::CallbackInfo &info);
*/ /**
Napi::Value GetLibraryPath(const Napi::CallbackInfo& info); * Creates the LLModel class
/** */
* Creates the LLModel class static Napi::Function GetClass(Napi::Env);
*/ llmodel_model GetInference();
static Napi::Function GetClass(Napi::Env);
llmodel_model GetInference();
private:
/**
* The underlying inference that interfaces with the C interface
*/
llmodel_model inference_;
std::mutex inference_mutex; private:
/**
* The underlying inference that interfaces with the C interface
*/
llmodel_model inference_;
std::string type; std::mutex inference_mutex;
// corresponds to LLModel::name() in typescript
std::string name; std::string type;
int nCtx{}; // corresponds to LLModel::name() in typescript
int nGpuLayers{}; std::string name;
std::string full_model_path; int nCtx{};
int nGpuLayers{};
std::string full_model_path;
}; };

View File

@@ -1,6 +1,6 @@
{ {
"name": "gpt4all", "name": "gpt4all",
"version": "3.2.0", "version": "4.0.0",
"packageManager": "yarn@3.6.1", "packageManager": "yarn@3.6.1",
"main": "src/gpt4all.js", "main": "src/gpt4all.js",
"repository": "nomic-ai/gpt4all", "repository": "nomic-ai/gpt4all",
@@ -22,7 +22,6 @@
], ],
"dependencies": { "dependencies": {
"md5-file": "^5.0.0", "md5-file": "^5.0.0",
"mkdirp": "^3.0.1",
"node-addon-api": "^6.1.0", "node-addon-api": "^6.1.0",
"node-gyp-build": "^4.6.0" "node-gyp-build": "^4.6.0"
}, },

View File

@@ -2,145 +2,195 @@
#include <future> #include <future>
PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config) PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config)
: promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) { : promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env)
if(_config.bHasTokenCallback){ {
_tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this); if (_config.hasResponseCallback)
}
}
PromptWorker::~PromptWorker()
{ {
if(_config.bHasTokenCallback){ _responseCallbackFn = Napi::ThreadSafeFunction::New(config.responseCallback.Env(), config.responseCallback,
_tsfn.Release(); "PromptWorker", 0, 1, this);
}
} }
void PromptWorker::Execute() if (_config.hasPromptCallback)
{ {
_config.mutex->lock(); _promptCallbackFn = Napi::ThreadSafeFunction::New(config.promptCallback.Env(), config.promptCallback,
"PromptWorker", 0, 1, this);
}
}
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model); PromptWorker::~PromptWorker()
{
if (_config.hasResponseCallback)
{
_responseCallbackFn.Release();
}
if (_config.hasPromptCallback)
{
_promptCallbackFn.Release();
}
}
auto ctx = &_config.context; void PromptWorker::Execute()
{
_config.mutex->lock();
if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size()) LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(_config.model);
wrapper->promptContext.tokens.resize(ctx->n_past);
// Copy the C prompt context auto ctx = &_config.context;
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Napi::Error::Fatal( if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size())
// "SUPRA", wrapper->promptContext.tokens.resize(ctx->n_past);
// "About to prompt");
// Call the C++ prompt method
wrapper->llModel->prompt(
_config.prompt,
[](int32_t tid) { return true; },
[this](int32_t token_id, const std::string tok)
{
return ResponseCallback(token_id, tok);
},
[](bool isRecalculating)
{
return isRecalculating;
},
wrapper->promptContext);
// Update the C context by giving access to the wrappers raw pointers to std::vector data // Copy the C prompt context
// which involves no copies wrapper->promptContext.n_past = ctx->n_past;
ctx->logits = wrapper->promptContext.logits.data(); wrapper->promptContext.n_ctx = ctx->n_ctx;
ctx->logits_size = wrapper->promptContext.logits.size(); wrapper->promptContext.n_predict = ctx->n_predict;
ctx->tokens = wrapper->promptContext.tokens.data(); wrapper->promptContext.top_k = ctx->top_k;
ctx->tokens_size = wrapper->promptContext.tokens.size(); wrapper->promptContext.top_p = ctx->top_p;
wrapper->promptContext.temp = ctx->temp;
wrapper->promptContext.n_batch = ctx->n_batch;
wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
wrapper->promptContext.contextErase = ctx->context_erase;
// Update the rest of the C prompt context // Call the C++ prompt method
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock(); wrapper->llModel->prompt(
_config.prompt, _config.promptTemplate, [this](int32_t token_id) { return PromptCallback(token_id); },
[this](int32_t token_id, const std::string token) { return ResponseCallback(token_id, token); },
[](bool isRecalculating) { return isRecalculating; }, wrapper->promptContext, _config.special,
_config.fakeReply);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->logits = wrapper->promptContext.logits.data();
ctx->logits_size = wrapper->promptContext.logits.size();
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
ctx->temp = wrapper->promptContext.temp;
ctx->n_batch = wrapper->promptContext.n_batch;
ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
ctx->context_erase = wrapper->promptContext.contextErase;
_config.mutex->unlock();
}
void PromptWorker::OnOK()
{
Napi::Object returnValue = Napi::Object::New(Env());
returnValue.Set("text", result);
returnValue.Set("nPast", _config.context.n_past);
promise.Resolve(returnValue);
delete _config.fakeReply;
}
void PromptWorker::OnError(const Napi::Error &e)
{
delete _config.fakeReply;
promise.Reject(e.Value());
}
Napi::Promise PromptWorker::GetPromise()
{
return promise.Promise();
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
if (token_id == -1)
{
return false;
} }
void PromptWorker::OnOK() if (!_config.hasResponseCallback)
{
promise.Resolve(Napi::String::New(Env(), result));
}
void PromptWorker::OnError(const Napi::Error &e)
{
promise.Reject(e.Value());
}
Napi::Promise PromptWorker::GetPromise()
{
return promise.Promise();
}
bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token)
{
if (token_id == -1)
{
return false;
}
if(!_config.bHasTokenCallback){
return true;
}
result += token;
std::promise<bool> promise;
auto info = new TokenCallbackInfo();
info->tokenId = token_id;
info->token = token;
info->total = result;
auto future = promise.get_future();
auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value)
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::String::New(env, value->token);
auto total = Napi::String::New(env,value->total);
auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean();
promise.set_value(jsResult);
// We're finished with the data.
delete value;
});
if (status != napi_ok) {
Napi::Error::Fatal(
"PromptWorkerResponseCallback",
"Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
return isRecalculating;
}
bool PromptWorker::PromptCallback(int32_t tid)
{ {
return true; return true;
} }
result += token;
std::promise<bool> promise;
auto info = new ResponseCallbackData();
info->tokenId = token_id;
info->token = token;
auto future = promise.get_future();
auto status = _responseCallbackFn.BlockingCall(
info, [&promise](Napi::Env env, Napi::Function jsCallback, ResponseCallbackData *value) {
try
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto token = Napi::String::New(env, value->token);
auto jsResult = jsCallback.Call({token_id, token}).ToBoolean();
promise.set_value(jsResult);
}
catch (const Napi::Error &e)
{
std::cerr << "Error in onResponseToken callback: " << e.what() << std::endl;
promise.set_value(false);
}
delete value;
});
if (status != napi_ok)
{
Napi::Error::Fatal("PromptWorkerResponseCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}
bool PromptWorker::RecalculateCallback(bool isRecalculating)
{
return isRecalculating;
}
bool PromptWorker::PromptCallback(int32_t token_id)
{
if (!_config.hasPromptCallback)
{
return true;
}
std::promise<bool> promise;
auto info = new PromptCallbackData();
info->tokenId = token_id;
auto future = promise.get_future();
auto status = _promptCallbackFn.BlockingCall(
info, [&promise](Napi::Env env, Napi::Function jsCallback, PromptCallbackData *value) {
try
{
// Transform native data into JS data, passing it to the provided
// `jsCallback` -- the TSFN's JavaScript function.
auto token_id = Napi::Number::New(env, value->tokenId);
auto jsResult = jsCallback.Call({token_id}).ToBoolean();
promise.set_value(jsResult);
}
catch (const Napi::Error &e)
{
std::cerr << "Error in onPromptToken callback: " << e.what() << std::endl;
promise.set_value(false);
}
delete value;
});
if (status != napi_ok)
{
Napi::Error::Fatal("PromptWorkerPromptCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed");
}
return future.get();
}

View File

@@ -1,59 +1,72 @@
#ifndef PREDICT_WORKER_H #ifndef PREDICT_WORKER_H
#define PREDICT_WORKER_H #define PREDICT_WORKER_H
#include "napi.h"
#include "llmodel_c.h"
#include "llmodel.h" #include "llmodel.h"
#include <thread> #include "llmodel_c.h"
#include <mutex> #include "napi.h"
#include <iostream>
#include <atomic> #include <atomic>
#include <iostream>
#include <memory> #include <memory>
#include <mutex>
#include <thread>
struct TokenCallbackInfo struct ResponseCallbackData
{
int32_t tokenId;
std::string token;
};
struct PromptCallbackData
{
int32_t tokenId;
};
struct LLModelWrapper
{
LLModel *llModel = nullptr;
LLModel::PromptContext promptContext;
~LLModelWrapper()
{ {
int32_t tokenId; delete llModel;
std::string total; }
std::string token; };
};
struct LLModelWrapper struct PromptWorkerConfig
{ {
LLModel *llModel = nullptr; Napi::Function responseCallback;
LLModel::PromptContext promptContext; bool hasResponseCallback = false;
~LLModelWrapper() { delete llModel; } Napi::Function promptCallback;
}; bool hasPromptCallback = false;
llmodel_model model;
std::mutex *mutex;
std::string prompt;
std::string promptTemplate;
llmodel_prompt_context context;
std::string result;
bool special = false;
std::string *fakeReply = nullptr;
};
struct PromptWorkerConfig class PromptWorker : public Napi::AsyncWorker
{ {
Napi::Function tokenCallback; public:
bool bHasTokenCallback = false; PromptWorker(Napi::Env env, PromptWorkerConfig config);
llmodel_model model; ~PromptWorker();
std::mutex * mutex; void Execute() override;
std::string prompt; void OnOK() override;
llmodel_prompt_context context; void OnError(const Napi::Error &e) override;
std::string result; Napi::Promise GetPromise();
};
class PromptWorker : public Napi::AsyncWorker bool ResponseCallback(int32_t token_id, const std::string token);
{ bool RecalculateCallback(bool isrecalculating);
public: bool PromptCallback(int32_t token_id);
PromptWorker(Napi::Env env, PromptWorkerConfig config);
~PromptWorker();
void Execute() override;
void OnOK() override;
void OnError(const Napi::Error &e) override;
Napi::Promise GetPromise();
bool ResponseCallback(int32_t token_id, const std::string token); private:
bool RecalculateCallback(bool isrecalculating); Napi::Promise::Deferred promise;
bool PromptCallback(int32_t tid); std::string result;
PromptWorkerConfig _config;
Napi::ThreadSafeFunction _responseCallbackFn;
Napi::ThreadSafeFunction _promptCallbackFn;
};
private: #endif // PREDICT_WORKER_H
Napi::Promise::Deferred promise;
std::string result;
PromptWorkerConfig _config;
Napi::ThreadSafeFunction _tsfn;
};
#endif // PREDICT_WORKER_H

View File

@@ -24,7 +24,6 @@ mkdir -p "$NATIVE_DIR" "$BUILD_DIR"
cmake -S ../../gpt4all-backend -B "$BUILD_DIR" && cmake -S ../../gpt4all-backend -B "$BUILD_DIR" &&
cmake --build "$BUILD_DIR" -j --config Release && { cmake --build "$BUILD_DIR" -j --config Release && {
cp "$BUILD_DIR"/libbert*.$LIB_EXT "$NATIVE_DIR"/
cp "$BUILD_DIR"/libgptj*.$LIB_EXT "$NATIVE_DIR"/ cp "$BUILD_DIR"/libgptj*.$LIB_EXT "$NATIVE_DIR"/
cp "$BUILD_DIR"/libllama*.$LIB_EXT "$NATIVE_DIR"/ cp "$BUILD_DIR"/libllama*.$LIB_EXT "$NATIVE_DIR"/
} }

View File

@@ -0,0 +1,31 @@
import { promises as fs } from "node:fs";
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const res = await createCompletion(
model,
"I've got three 🍣 - What shall I name them?",
{
onPromptToken: (tokenId) => {
console.debug("onPromptToken", { tokenId });
// throwing an error will cancel
throw new Error("This is an error");
// const foo = thisMethodDoesNotExist();
// returning false will cancel as well
// return false;
},
onResponseToken: (tokenId, token) => {
console.debug("onResponseToken", { tokenId, token });
// same applies here
},
}
);
console.debug("Output:", {
usage: res.usage,
message: res.choices[0].message,
});

View File

@@ -0,0 +1,65 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession({
messages: [
{
role: "user",
content: "I'll tell you a secret password: It's 63445.",
},
{
role: "assistant",
content: "I will do my best to remember that.",
},
{
role: "user",
content:
"And here another fun fact: Bananas may be bluer than bread at night.",
},
{
role: "assistant",
content: "Yes, that makes sense.",
},
],
});
const turn1 = await createCompletion(
chat,
"Please tell me the secret password."
);
console.debug(turn1.choices[0].message);
// "The secret password you shared earlier is 63445.""
const turn2 = await createCompletion(
chat,
"Thanks! Have your heard about the bananas?"
);
console.debug(turn2.choices[0].message);
for (let i = 0; i < 32; i++) {
// gpu go brr
const turn = await createCompletion(
chat,
i % 2 === 0 ? "Tell me a fun fact." : "And a boring one?"
);
console.debug({
message: turn.choices[0].message,
n_past_tokens: turn.usage.n_past_tokens,
});
}
const finalTurn = await createCompletion(
chat,
"Now I forgot the secret password. Can you remind me?"
);
console.debug(finalTurn.choices[0].message);
// result of finalTurn may vary depending on whether the generated facts pushed the secret out of the context window.
// "Of course! The secret password you shared earlier is 63445."
// "I apologize for any confusion. As an AI language model, ..."
model.dispose();

View File

@@ -0,0 +1,19 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession();
await createCompletion(
chat,
"Why are bananas rather blue than bread at night sometimes?",
{
verbose: true,
}
);
await createCompletion(chat, "Are you sure?", {
verbose: true,
});

View File

@@ -1,70 +0,0 @@
import { LLModel, createCompletion, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, loadModel } from '../src/gpt4all.js'
const model = await loadModel(
'mistral-7b-openorca.Q4_0.gguf',
{ verbose: true, device: 'gpu' }
);
const ll = model.llm;
try {
class Extended extends LLModel {
}
} catch(e) {
console.log("Extending from native class gone wrong " + e)
}
console.log("state size " + ll.stateSize())
console.log("thread count " + ll.threadCount());
ll.setThreadCount(5);
console.log("thread count " + ll.threadCount());
ll.setThreadCount(4);
console.log("thread count " + ll.threadCount());
console.log("name " + ll.name());
console.log("type: " + ll.type());
console.log("Default directory for models", DEFAULT_DIRECTORY);
console.log("Default directory for libraries", DEFAULT_LIBRARIES_DIRECTORY);
console.log("Has GPU", ll.hasGpuDevice());
console.log("gpu devices", ll.listGpu())
console.log("Required Mem in bytes", ll.memoryNeeded())
const completion1 = await createCompletion(model, [
{ role : 'system', content: 'You are an advanced mathematician.' },
{ role : 'user', content: 'What is 1 + 1?' },
], { verbose: true })
console.log(completion1.choices[0].message)
const completion2 = await createCompletion(model, [
{ role : 'system', content: 'You are an advanced mathematician.' },
{ role : 'user', content: 'What is two plus two?' },
], { verbose: true })
console.log(completion2.choices[0].message)
//CALLING DISPOSE WILL INVALID THE NATIVE MODEL. USE THIS TO CLEANUP
model.dispose()
// At the moment, from testing this code, concurrent model prompting is not possible.
// Behavior: The last prompt gets answered, but the rest are cancelled
// my experience with threading is not the best, so if anyone who is good is willing to give this a shot,
// maybe this is possible
// INFO: threading with llama.cpp is not the best maybe not even possible, so this will be left here as reference
//const responses = await Promise.all([
// createCompletion(model, [
// { role : 'system', content: 'You are an advanced mathematician.' },
// { role : 'user', content: 'What is 1 + 1?' },
// ], { verbose: true }),
// createCompletion(model, [
// { role : 'system', content: 'You are an advanced mathematician.' },
// { role : 'user', content: 'What is 1 + 1?' },
// ], { verbose: true }),
//
//createCompletion(model, [
// { role : 'system', content: 'You are an advanced mathematician.' },
// { role : 'user', content: 'What is 1 + 1?' },
//], { verbose: true })
//
//])
//console.log(responses.map(s => s.choices[0].message))

View File

@@ -0,0 +1,29 @@
import {
loadModel,
createCompletion,
} from "../src/gpt4all.js";
const modelOptions = {
verbose: true,
};
const model1 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
...modelOptions,
device: "gpu", // only one model can be on gpu
});
const model2 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", modelOptions);
const model3 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", modelOptions);
const promptContext = {
verbose: true,
}
const responses = await Promise.all([
createCompletion(model1, "What is 1 + 1?", promptContext),
// generating with the same model instance will wait for the previous completion to finish
createCompletion(model1, "What is 1 + 1?", promptContext),
// generating with different model instances will run in parallel
createCompletion(model2, "What is 1 + 2?", promptContext),
createCompletion(model3, "What is 1 + 3?", promptContext),
]);
console.log(responses.map((res) => res.choices[0].message));

View File

@@ -0,0 +1,26 @@
import { loadModel, createEmbedding } from '../src/gpt4all.js'
import { createGunzip, createGzip, createUnzip } from 'node:zlib';
import { Readable } from 'stream'
import readline from 'readline'
const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding', device: 'gpu' })
console.log("Running with", embedder.llm.threadCount(), "threads");
const unzip = createGunzip();
const url = "https://huggingface.co/datasets/sentence-transformers/embedding-training-data/resolve/main/squad_pairs.jsonl.gz"
const stream = await fetch(url)
.then(res => Readable.fromWeb(res.body));
const lineReader = readline.createInterface({
input: stream.pipe(unzip),
crlfDelay: Infinity
})
lineReader.on('line', line => {
//pairs of questions and answers
const question_answer = JSON.parse(line)
console.log(createEmbedding(embedder, question_answer))
})
lineReader.on('close', () => embedder.dispose())

View File

@@ -1,6 +1,12 @@
import { loadModel, createEmbedding } from '../src/gpt4all.js' import { loadModel, createEmbedding } from '../src/gpt4all.js'
const embedder = await loadModel("ggml-all-MiniLM-L6-v2-f16.bin", { verbose: true, type: 'embedding'}) const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding' , device: 'gpu' })
console.log(createEmbedding(embedder, "Accept your current situation")) try {
console.log(createEmbedding(embedder, ["Accept your current situation", "12312"], { prefix: "search_document" }))
} catch(e) {
console.log(e)
}
embedder.dispose()

View File

@@ -1,41 +0,0 @@
import gpt from '../src/gpt4all.js'
const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' })
process.stdout.write('Response: ')
const tokens = gpt.generateTokens(model, [{
role: 'user',
content: "How are you ?"
}], { nPredict: 2048 })
for await (const token of tokens){
process.stdout.write(token);
}
const result = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure?"
}])
console.log(result)
const result2 = await gpt.createCompletion(model, [{
role: 'user',
content: "You sure you sure?"
}])
console.log(result2)
const tokens2 = gpt.generateTokens(model, [{
role: 'user',
content: "If 3 + 3 is 5, what is 2 + 2?"
}], { nPredict: 2048 })
for await (const token of tokens2){
process.stdout.write(token);
}
console.log("done")
model.dispose();

View File

@@ -0,0 +1,61 @@
import {
LLModel,
createCompletion,
DEFAULT_DIRECTORY,
DEFAULT_LIBRARIES_DIRECTORY,
loadModel,
} from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const ll = model.llm;
try {
class Extended extends LLModel {}
} catch (e) {
console.log("Extending from native class gone wrong " + e);
}
console.log("state size " + ll.stateSize());
console.log("thread count " + ll.threadCount());
ll.setThreadCount(5);
console.log("thread count " + ll.threadCount());
ll.setThreadCount(4);
console.log("thread count " + ll.threadCount());
console.log("name " + ll.name());
console.log("type: " + ll.type());
console.log("Default directory for models", DEFAULT_DIRECTORY);
console.log("Default directory for libraries", DEFAULT_LIBRARIES_DIRECTORY);
console.log("Has GPU", ll.hasGpuDevice());
console.log("gpu devices", ll.listGpu());
console.log("Required Mem in bytes", ll.memoryNeeded());
// to ingest a custom system prompt without using a chat session.
await createCompletion(
model,
"<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n",
{
promptTemplate: "%1",
nPredict: 0,
special: true,
}
);
const completion1 = await createCompletion(model, "What is 1 + 1?", {
verbose: true,
});
console.log(`🤖 > ${completion1.choices[0].message.content}`);
//Very specific:
// tested on Ubuntu 22.0, Linux Mint, if I set nPast to 100, the app hangs.
const completion2 = await createCompletion(model, "And if we add two?", {
verbose: true,
});
console.log(`🤖 > ${completion2.choices[0].message.content}`);
//CALLING DISPOSE WILL INVALID THE NATIVE MODEL. USE THIS TO CLEANUP
model.dispose();
console.log("model disposed, exiting...");

View File

@@ -0,0 +1,21 @@
import { promises as fs } from "node:fs";
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
nCtx: 32768,
});
const typeDefSource = await fs.readFile("./src/gpt4all.d.ts", "utf-8");
const res = await createCompletion(
model,
"Here are the type definitions for the GPT4All API:\n\n" +
typeDefSource +
"\n\nHow do I create a completion with a really large context window?",
{
verbose: true,
}
);
console.debug(res.choices[0].message);

View File

@@ -0,0 +1,60 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model1 = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
device: "gpu",
nCtx: 4096,
});
const chat1 = await model1.createChatSession({
temperature: 0.8,
topP: 0.7,
topK: 60,
});
const chat1turn1 = await createCompletion(
chat1,
"Outline a short story concept for adults. About why bananas are rather blue than bread is green at night sometimes. Not too long."
);
console.debug(chat1turn1.choices[0].message);
const chat1turn2 = await createCompletion(
chat1,
"Lets sprinkle some plot twists. And a cliffhanger at the end."
);
console.debug(chat1turn2.choices[0].message);
const chat1turn3 = await createCompletion(
chat1,
"Analyze your plot. Find the weak points."
);
console.debug(chat1turn3.choices[0].message);
const chat1turn4 = await createCompletion(
chat1,
"Rewrite it based on the analysis."
);
console.debug(chat1turn4.choices[0].message);
model1.dispose();
const model2 = await loadModel("gpt4all-falcon-newbpe-q4_0.gguf", {
device: "gpu",
});
const chat2 = await model2.createChatSession({
messages: chat1.messages,
});
const chat2turn1 = await createCompletion(
chat2,
"Give three ideas how this plot could be improved."
);
console.debug(chat2turn1.choices[0].message);
const chat2turn2 = await createCompletion(
chat2,
"Revise the plot, applying your ideas."
);
console.debug(chat2turn2.choices[0].message);
model2.dispose();

View File

@@ -0,0 +1,50 @@
import { loadModel, createCompletion } from "../src/gpt4all.js";
const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", {
verbose: true,
device: "gpu",
});
const messages = [
{
role: "system",
content: "<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n",
},
{
role: "user",
content: "What's 2+2?",
},
{
role: "assistant",
content: "5",
},
{
role: "user",
content: "Are you sure?",
},
];
const res1 = await createCompletion(model, messages);
console.debug(res1.choices[0].message);
messages.push(res1.choices[0].message);
messages.push({
role: "user",
content: "Could you double check that?",
});
const res2 = await createCompletion(model, messages);
console.debug(res2.choices[0].message);
messages.push(res2.choices[0].message);
messages.push({
role: "user",
content: "Let's bring out the big calculators.",
});
const res3 = await createCompletion(model, messages);
console.debug(res3.choices[0].message);
messages.push(res3.choices[0].message);
// console.debug(messages);

View File

@@ -0,0 +1,57 @@
import {
loadModel,
createCompletion,
createCompletionStream,
createCompletionGenerator,
} from "../src/gpt4all.js";
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
device: "gpu",
});
process.stdout.write("### Stream:");
const stream = createCompletionStream(model, "How are you?");
stream.tokens.on("data", (data) => {
process.stdout.write(data);
});
await stream.result;
process.stdout.write("\n");
process.stdout.write("### Stream with pipe:");
const stream2 = createCompletionStream(
model,
"Please say something nice about node streams."
);
stream2.tokens.pipe(process.stdout);
const stream2Res = await stream2.result;
process.stdout.write("\n");
process.stdout.write("### Generator:");
const gen = createCompletionGenerator(model, "generators instead?", {
nPast: stream2Res.usage.n_past_tokens,
});
for await (const chunk of gen) {
process.stdout.write(chunk);
}
process.stdout.write("\n");
process.stdout.write("### Callback:");
await createCompletion(model, "Why not just callbacks?", {
onResponseToken: (tokenId, token) => {
process.stdout.write(token);
},
});
process.stdout.write("\n");
process.stdout.write("### 2nd Generator:");
const gen2 = createCompletionGenerator(model, "If 3 + 3 is 5, what is 2 + 2?");
let chunk = await gen2.next();
while (!chunk.done) {
process.stdout.write(chunk.value);
chunk = await gen2.next();
}
process.stdout.write("\n");
console.debug("generator finished", chunk);
model.dispose();

View File

@@ -0,0 +1,19 @@
import {
loadModel,
createCompletion,
} from "../src/gpt4all.js";
const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", {
verbose: true,
device: "gpu",
});
const chat = await model.createChatSession({
verbose: true,
systemPrompt: "<|im_start|>system\nRoleplay as Batman. Answer as if you are Batman, never say you're an Assistant.\n<|im_end|>",
});
const turn1 = await createCompletion(chat, "You have any plans tonight?");
console.log(turn1.choices[0].message);
// "I'm afraid I must decline any personal invitations tonight. As Batman, I have a responsibility to protect Gotham City."
model.dispose();

View File

@@ -0,0 +1,169 @@
const { DEFAULT_PROMPT_CONTEXT } = require("./config");
const { prepareMessagesForIngest } = require("./util");
class ChatSession {
model;
modelName;
/**
* @type {import('./gpt4all').ChatMessage[]}
*/
messages;
/**
* @type {string}
*/
systemPrompt;
/**
* @type {import('./gpt4all').LLModelPromptContext}
*/
promptContext;
/**
* @type {boolean}
*/
initialized;
constructor(model, chatSessionOpts = {}) {
const { messages, systemPrompt, ...sessionDefaultPromptContext } =
chatSessionOpts;
this.model = model;
this.modelName = model.llm.name();
this.messages = messages ?? [];
this.systemPrompt = systemPrompt ?? model.config.systemPrompt;
this.initialized = false;
this.promptContext = {
...DEFAULT_PROMPT_CONTEXT,
...sessionDefaultPromptContext,
nPast: 0,
};
}
async initialize(completionOpts = {}) {
if (this.model.activeChatSession !== this) {
this.model.activeChatSession = this;
}
let tokensIngested = 0;
// ingest system prompt
if (this.systemPrompt) {
const systemRes = await this.model.generate(this.systemPrompt, {
promptTemplate: "%1",
nPredict: 0,
special: true,
nBatch: this.promptContext.nBatch,
// verbose: true,
});
tokensIngested += systemRes.tokensIngested;
this.promptContext.nPast = systemRes.nPast;
}
// ingest initial messages
if (this.messages.length > 0) {
tokensIngested += await this.ingestMessages(
this.messages,
completionOpts
);
}
this.initialized = true;
return tokensIngested;
}
async ingestMessages(messages, completionOpts = {}) {
const turns = prepareMessagesForIngest(messages);
// send the message pairs to the model
let tokensIngested = 0;
for (const turn of turns) {
const turnRes = await this.model.generate(turn.user, {
...this.promptContext,
...completionOpts,
fakeReply: turn.assistant,
});
tokensIngested += turnRes.tokensIngested;
this.promptContext.nPast = turnRes.nPast;
}
return tokensIngested;
}
async generate(input, completionOpts = {}) {
if (this.model.activeChatSession !== this) {
throw new Error(
"Chat session is not active. Create a new chat session or call initialize to continue."
);
}
if (completionOpts.nPast > this.promptContext.nPast) {
throw new Error(
`nPast cannot be greater than ${this.promptContext.nPast}.`
);
}
let tokensIngested = 0;
if (!this.initialized) {
tokensIngested += await this.initialize(completionOpts);
}
let prompt = input;
if (Array.isArray(input)) {
// assuming input is a messages array
// -> tailing user message will be used as the final prompt. its optional.
// -> all system messages will be ignored.
// -> all other messages will be ingested with fakeReply
// -> user/assistant messages will be pushed into the messages array
let tailingUserMessage = "";
let messagesToIngest = input;
const lastMessage = input[input.length - 1];
if (lastMessage.role === "user") {
tailingUserMessage = lastMessage.content;
messagesToIngest = input.slice(0, input.length - 1);
}
if (messagesToIngest.length > 0) {
tokensIngested += await this.ingestMessages(
messagesToIngest,
completionOpts
);
this.messages.push(...messagesToIngest);
}
if (tailingUserMessage) {
prompt = tailingUserMessage;
} else {
return {
text: "",
nPast: this.promptContext.nPast,
tokensIngested,
tokensGenerated: 0,
};
}
}
const result = await this.model.generate(prompt, {
...this.promptContext,
...completionOpts,
});
this.promptContext.nPast = result.nPast;
result.tokensIngested += tokensIngested;
this.messages.push({
role: "user",
content: prompt,
});
this.messages.push({
role: "assistant",
content: result.text,
});
return result;
}
}
module.exports = {
ChatSession,
};

View File

@@ -27,15 +27,16 @@ const DEFAULT_MODEL_CONFIG = {
promptTemplate: "### Human:\n%1\n\n### Assistant:\n", promptTemplate: "### Human:\n%1\n\n### Assistant:\n",
} }
const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models2.json"; const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models3.json";
const DEFAULT_PROMPT_CONTEXT = { const DEFAULT_PROMPT_CONTEXT = {
temp: 0.7, temp: 0.1,
topK: 40, topK: 40,
topP: 0.4, topP: 0.9,
minP: 0.0,
repeatPenalty: 1.18, repeatPenalty: 1.18,
repeatLastN: 64, repeatLastN: 10,
nBatch: 8, nBatch: 100,
} }
module.exports = { module.exports = {

View File

@@ -1,43 +1,11 @@
/// <reference types="node" /> /// <reference types="node" />
declare module "gpt4all"; declare module "gpt4all";
type ModelType = "gptj" | "llama" | "mpt" | "replit";
// NOTE: "deprecated" tag in below comment breaks the doc generator https://github.com/documentationjs/documentation/issues/1596
/**
* Full list of models available
* DEPRECATED!! These model names are outdated and this type will not be maintained, please use a string literal instead
*/
interface ModelFile {
/** List of GPT-J Models */
gptj:
| "ggml-gpt4all-j-v1.3-groovy.bin"
| "ggml-gpt4all-j-v1.2-jazzy.bin"
| "ggml-gpt4all-j-v1.1-breezy.bin"
| "ggml-gpt4all-j.bin";
/** List Llama Models */
llama:
| "ggml-gpt4all-l13b-snoozy.bin"
| "ggml-vicuna-7b-1.1-q4_2.bin"
| "ggml-vicuna-13b-1.1-q4_2.bin"
| "ggml-wizardLM-7B.q4_2.bin"
| "ggml-stable-vicuna-13B.q4_2.bin"
| "ggml-nous-gpt4-vicuna-13b.bin"
| "ggml-v3-13b-hermes-q5_1.bin";
/** List of MPT Models */
mpt:
| "ggml-mpt-7b-base.bin"
| "ggml-mpt-7b-chat.bin"
| "ggml-mpt-7b-instruct.bin";
/** List of Replit Models */
replit: "ggml-replit-code-v1-3b.bin";
}
interface LLModelOptions { interface LLModelOptions {
/** /**
* Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user. * Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user.
*/ */
type?: ModelType; type?: string;
model_name: string; model_name: string;
model_path: string; model_path: string;
library_path?: string; library_path?: string;
@@ -51,47 +19,259 @@ interface ModelConfig {
} }
/** /**
* Callback for controlling token generation * Options for the chat session.
*/ */
type TokenCallback = (tokenId: number, token: string, total: string) => boolean interface ChatSessionOptions extends Partial<LLModelPromptContext> {
/**
* System prompt to ingest on initialization.
*/
systemPrompt?: string;
/** /**
* * Messages to ingest on initialization.
* InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers. */
* messages?: ChatMessage[];
*/
declare class InferenceModel {
constructor(llm: LLModel, config: ModelConfig);
llm: LLModel;
config: ModelConfig;
generate(
prompt: string,
options?: Partial<LLModelPromptContext>,
callback?: TokenCallback
): Promise<string>;
/**
* delete and cleanup the native model
*/
dispose(): void
} }
/**
* ChatSession utilizes an InferenceModel for efficient processing of chat conversations.
*/
declare class ChatSession implements CompletionProvider {
/**
* Constructs a new ChatSession using the provided InferenceModel and options.
* Does not set the chat session as the active chat session until initialize is called.
* @param {InferenceModel} model An InferenceModel instance.
* @param {ChatSessionOptions} [options] Options for the chat session including default completion options.
*/
constructor(model: InferenceModel, options?: ChatSessionOptions);
/**
* The underlying InferenceModel used for generating completions.
*/
model: InferenceModel;
/**
* The name of the model.
*/
modelName: string;
/**
* The messages that have been exchanged in this chat session.
*/
messages: ChatMessage[];
/**
* The system prompt that has been ingested at the beginning of the chat session.
*/
systemPrompt: string;
/**
* The current prompt context of the chat session.
*/
promptContext: LLModelPromptContext;
/**
* Ingests system prompt and initial messages.
* Sets this chat session as the active chat session of the model.
* @param {CompletionOptions} [options] Set completion options for initialization.
* @returns {Promise<number>} The number of tokens ingested during initialization. systemPrompt + messages.
*/
initialize(completionOpts?: CompletionOptions): Promise<number>;
/**
* Prompts the model in chat-session context.
* @param {CompletionInput} input Input string or message array.
* @param {CompletionOptions} [options] Set completion options for this generation.
* @returns {Promise<InferenceResult>} The inference result.
* @throws {Error} If the chat session is not the active chat session of the model.
* @throws {Error} If nPast is set to a value higher than what has been ingested in the session.
*/
generate(
input: CompletionInput,
options?: CompletionOptions
): Promise<InferenceResult>;
}
/**
* Shape of InferenceModel generations.
*/
interface InferenceResult extends LLModelInferenceResult {
tokensIngested: number;
tokensGenerated: number;
}
/**
* InferenceModel represents an LLM which can make next-token predictions.
*/
declare class InferenceModel implements CompletionProvider {
constructor(llm: LLModel, config: ModelConfig);
/** The native LLModel */
llm: LLModel;
/** The configuration the instance was constructed with. */
config: ModelConfig;
/** The active chat session of the model. */
activeChatSession?: ChatSession;
/** The name of the model. */
modelName: string;
/**
* Create a chat session with the model and set it as the active chat session of this model.
* A model instance can only have one active chat session at a time.
* @param {ChatSessionOptions} options The options for the chat session.
* @returns {Promise<ChatSession>} The chat session.
*/
createChatSession(options?: ChatSessionOptions): Promise<ChatSession>;
/**
* Prompts the model with a given input and optional parameters.
* @param {CompletionInput} input The prompt input.
* @param {CompletionOptions} options Prompt context and other options.
* @returns {Promise<InferenceResult>} The model's response to the prompt.
* @throws {Error} If nPast is set to a value smaller than 0.
* @throws {Error} If a messages array without a tailing user message is provided.
*/
generate(
prompt: string,
options?: CompletionOptions
): Promise<InferenceResult>;
/**
* delete and cleanup the native model
*/
dispose(): void;
}
/**
* Options for generating one or more embeddings.
*/
interface EmbedddingOptions {
/**
* The model-specific prefix representing the embedding task, without the trailing colon. For Nomic Embed
* this can be `search_query`, `search_document`, `classification`, or `clustering`.
*/
prefix?: string;
/**
*The embedding dimension, for use with Matryoshka-capable models. Defaults to full-size.
* @default determines on the model being used.
*/
dimensionality?: number;
/**
* How to handle texts longer than the model can accept. One of `mean` or `truncate`.
* @default "mean"
*/
longTextMode?: "mean" | "truncate";
/**
* Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
* with long_text_mode="mean" will raise an error. Disabled by default.
* @default false
*/
atlas?: boolean;
}
/**
* The nodejs moral equivalent to python binding's Embed4All().embed()
* meow
* @param {EmbeddingModel} model The embedding model instance.
* @param {string} text Text to embed.
* @param {EmbeddingOptions} options Optional parameters for the embedding.
* @returns {EmbeddingResult} The embedding result.
* @throws {Error} If dimensionality is set to a value smaller than 1.
*/
declare function createEmbedding(
model: EmbeddingModel,
text: string,
options?: EmbedddingOptions
): EmbeddingResult<Float32Array>;
/**
* Overload that takes multiple strings to embed.
* @param {EmbeddingModel} model The embedding model instance.
* @param {string[]} texts Texts to embed.
* @param {EmbeddingOptions} options Optional parameters for the embedding.
* @returns {EmbeddingResult<Float32Array[]>} The embedding result.
* @throws {Error} If dimensionality is set to a value smaller than 1.
*/
declare function createEmbedding(
model: EmbeddingModel,
text: string[],
options?: EmbedddingOptions
): EmbeddingResult<Float32Array[]>;
/**
* The resulting embedding.
*/
interface EmbeddingResult<T> {
/**
* Encoded token count. Includes overlap but specifically excludes tokens used for the prefix/task_type, BOS/CLS token, and EOS/SEP token
**/
n_prompt_tokens: number;
embeddings: T;
}
/** /**
* EmbeddingModel represents an LLM which can create embeddings, which are float arrays * EmbeddingModel represents an LLM which can create embeddings, which are float arrays
*/ */
declare class EmbeddingModel { declare class EmbeddingModel {
constructor(llm: LLModel, config: ModelConfig); constructor(llm: LLModel, config: ModelConfig);
/** The native LLModel */
llm: LLModel; llm: LLModel;
/** The configuration the instance was constructed with. */
config: ModelConfig; config: ModelConfig;
embed(text: string): Float32Array; /**
* Create an embedding from a given input string. See EmbeddingOptions.
* @param {string} text
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {EmbeddingResult<Float32Array>} The embedding result.
*/
embed(
text: string,
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): EmbeddingResult<Float32Array>;
/**
* Create an embedding from a given input text array. See EmbeddingOptions.
* @param {string[]} text
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {EmbeddingResult<Float32Array[]>} The embedding result.
*/
embed(
text: string[],
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): EmbeddingResult<Float32Array[]>;
/** /**
* delete and cleanup the native model * delete and cleanup the native model
*/ */
dispose(): void dispose(): void;
}
/**
* Shape of LLModel's inference result.
*/
interface LLModelInferenceResult {
text: string;
nPast: number;
}
interface LLModelInferenceOptions extends Partial<LLModelPromptContext> {
/** Callback for response tokens, called for each generated token.
* @param {number} tokenId The token id.
* @param {string} token The token.
* @returns {boolean | undefined} Whether to continue generating tokens.
* */
onResponseToken?: (tokenId: number, token: string) => boolean | void;
/** Callback for prompt tokens, called for each input token in the prompt.
* @param {number} tokenId The token id.
* @returns {boolean | undefined} Whether to continue ingesting the prompt.
* */
onPromptToken?: (tokenId: number) => boolean | void;
} }
/** /**
@@ -101,14 +281,13 @@ declare class EmbeddingModel {
declare class LLModel { declare class LLModel {
/** /**
* Initialize a new LLModel. * Initialize a new LLModel.
* @param path Absolute path to the model file. * @param {string} path Absolute path to the model file.
* @throws {Error} If the model file does not exist. * @throws {Error} If the model file does not exist.
*/ */
constructor(path: string);
constructor(options: LLModelOptions); constructor(options: LLModelOptions);
/** either 'gpt', mpt', or 'llama' or undefined */ /** undefined or user supplied */
type(): ModelType | undefined; type(): string | undefined;
/** The name of the model. */ /** The name of the model. */
name(): string; name(): string;
@@ -134,29 +313,53 @@ declare class LLModel {
setThreadCount(newNumber: number): void; setThreadCount(newNumber: number): void;
/** /**
* Prompt the model with a given input and optional parameters. * Prompt the model directly with a given input string and optional parameters.
* This is the raw output from model. * Use the higher level createCompletion methods for a more user-friendly interface.
* Use the prompt function exported for a value * @param {string} prompt The prompt input.
* @param q The prompt input. * @param {LLModelInferenceOptions} options Optional parameters for the generation.
* @param params Optional parameters for the prompt context. * @returns {LLModelInferenceResult} The response text and final context size.
* @param callback - optional callback to control token generation.
* @returns The result of the model prompt.
*/ */
raw_prompt( infer(
q: string, prompt: string,
params: Partial<LLModelPromptContext>, options: LLModelInferenceOptions
callback?: TokenCallback ): Promise<LLModelInferenceResult>;
): Promise<string>
/** /**
* Embed text with the model. Keep in mind that * Embed text with the model. See EmbeddingOptions for more information.
* not all models can embed text, (only bert can embed as of 07/16/2023 (mm/dd/yyyy)) * Use the higher level createEmbedding methods for a more user-friendly interface.
* Use the prompt function exported for a value * @param {string} text
* @param q The prompt input. * @param {string} prefix
* @param params Optional parameters for the prompt context. * @param {number} dimensionality
* @returns The result of the model prompt. * @param {boolean} doMean
* @param {boolean} atlas
* @returns {Float32Array} The embedding of the text.
*/ */
embed(text: string): Float32Array; embed(
text: string,
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): Float32Array;
/**
* Embed multiple texts with the model. See EmbeddingOptions for more information.
* Use the higher level createEmbedding methods for a more user-friendly interface.
* @param {string[]} texts
* @param {string} prefix
* @param {number} dimensionality
* @param {boolean} doMean
* @param {boolean} atlas
* @returns {Float32Array[]} The embeddings of the texts.
*/
embed(
texts: string,
prefix: string,
dimensionality: number,
doMean: boolean,
atlas: boolean
): Float32Array[];
/** /**
* Whether the model is loaded or not. * Whether the model is loaded or not.
*/ */
@@ -166,81 +369,97 @@ declare class LLModel {
* Where to search for the pluggable backend libraries * Where to search for the pluggable backend libraries
*/ */
setLibraryPath(s: string): void; setLibraryPath(s: string): void;
/** /**
* Where to get the pluggable backend libraries * Where to get the pluggable backend libraries
*/ */
getLibraryPath(): string; getLibraryPath(): string;
/** /**
* Initiate a GPU by a string identifier. * Initiate a GPU by a string identifier.
* @param {number} memory_required Should be in the range size_t or will throw * @param {number} memory_required Should be in the range size_t or will throw
* @param {string} device_name 'amd' | 'nvidia' | 'intel' | 'gpu' | gpu name. * @param {string} device_name 'amd' | 'nvidia' | 'intel' | 'gpu' | gpu name.
* read LoadModelOptions.device for more information * read LoadModelOptions.device for more information
*/ */
initGpuByString(memory_required: number, device_name: string): boolean initGpuByString(memory_required: number, device_name: string): boolean;
/** /**
* From C documentation * From C documentation
* @returns True if a GPU device is successfully initialized, false otherwise. * @returns True if a GPU device is successfully initialized, false otherwise.
*/ */
hasGpuDevice(): boolean hasGpuDevice(): boolean;
/**
* GPUs that are usable for this LLModel
* @param nCtx Maximum size of context window
* @throws if hasGpuDevice returns false (i think)
* @returns
*/
listGpu(nCtx: number) : GpuDevice[]
/** /**
* delete and cleanup the native model * GPUs that are usable for this LLModel
* @param {number} nCtx Maximum size of context window
* @throws if hasGpuDevice returns false (i think)
* @returns
*/ */
dispose(): void listGpu(nCtx: number): GpuDevice[];
/**
* delete and cleanup the native model
*/
dispose(): void;
} }
/** /**
* an object that contains gpu data on this machine. * an object that contains gpu data on this machine.
*/ */
interface GpuDevice { interface GpuDevice {
index: number; index: number;
/** /**
* same as VkPhysicalDeviceType * same as VkPhysicalDeviceType
*/ */
type: number; type: number;
heapSize : number; heapSize: number;
name: string; name: string;
vendor: string; vendor: string;
} }
/** /**
* Options that configure a model's behavior. * Options that configure a model's behavior.
*/ */
interface LoadModelOptions { interface LoadModelOptions {
/**
* Where to look for model files.
*/
modelPath?: string; modelPath?: string;
/**
* Where to look for the backend libraries.
*/
librariesPath?: string; librariesPath?: string;
/**
* The path to the model configuration file, useful for offline usage or custom model configurations.
*/
modelConfigFile?: string; modelConfigFile?: string;
/**
* Whether to allow downloading the model if it is not present at the specified path.
*/
allowDownload?: boolean; allowDownload?: boolean;
/**
* Enable verbose logging.
*/
verbose?: boolean; verbose?: boolean;
/* The processing unit on which the model will run. It can be set to /**
* The processing unit on which the model will run. It can be set to
* - "cpu": Model will run on the central processing unit. * - "cpu": Model will run on the central processing unit.
* - "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor. * - "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
* - "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor. * - "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
* - "gpu name": Model will run on the GPU that matches the name if it's available.
Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name * Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All
if it's available. * instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the
* model.
Default is "cpu". * @default "cpu"
*/
Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All
instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the
model.
*/
device?: string; device?: string;
/* /**
* The Maximum window size of this model * The Maximum window size of this model
* Default of 2048 * @default 2048
*/ */
nCtx?: number; nCtx?: number;
/* /**
* Number of gpu layers needed * Number of gpu layers needed
* Default of 100 * @default 100
*/ */
ngl?: number; ngl?: number;
} }
@@ -277,66 +496,84 @@ declare function loadModel(
): Promise<InferenceModel | EmbeddingModel>; ): Promise<InferenceModel | EmbeddingModel>;
/** /**
* The nodejs equivalent to python binding's chat_completion * Interface for createCompletion methods, implemented by InferenceModel and ChatSession.
* @param {InferenceModel} model - The language model object. * Implement your own CompletionProvider or extend ChatSession to generate completions with custom logic.
* @param {PromptMessage[]} messages - The array of messages for the conversation.
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {CompletionReturn} The completion result.
*/ */
declare function createCompletion( interface CompletionProvider {
model: InferenceModel, modelName: string;
messages: PromptMessage[], generate(
options?: CompletionOptions input: CompletionInput,
): Promise<CompletionReturn>; options?: CompletionOptions
): Promise<InferenceResult>;
/**
* The nodejs moral equivalent to python binding's Embed4All().embed()
* meow
* @param {EmbeddingModel} model - The language model object.
* @param {string} text - text to embed
* @returns {Float32Array} The completion result.
*/
declare function createEmbedding(
model: EmbeddingModel,
text: string
): Float32Array;
/**
* The options for creating the completion.
*/
interface CompletionOptions extends Partial<LLModelPromptContext> {
/**
* Indicates if verbose logging is enabled.
* @default true
*/
verbose?: boolean;
/**
* Template for the system message. Will be put before the conversation with %1 being replaced by all system messages.
* Note that if this is not defined, system messages will not be included in the prompt.
*/
systemPromptTemplate?: string;
/**
* Template for user messages, with %1 being replaced by the message.
*/
promptTemplate?: boolean;
/**
* The initial instruction for the model, on top of the prompt
*/
promptHeader?: string;
/**
* The last instruction for the model, appended to the end of the prompt.
*/
promptFooter?: string;
} }
/** /**
* A message in the conversation, identical to OpenAI's chat message. * Options for creating a completion.
*/ */
interface PromptMessage { interface CompletionOptions extends LLModelInferenceOptions {
/**
* Indicates if verbose logging is enabled.
* @default false
*/
verbose?: boolean;
}
/**
* The input for creating a completion. May be a string or an array of messages.
*/
type CompletionInput = string | ChatMessage[];
/**
* The nodejs equivalent to python binding's chat_completion
* @param {CompletionProvider} provider - The inference model object or chat session
* @param {CompletionInput} input - The input string or message array
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {CompletionResult} The completion result.
*/
declare function createCompletion(
provider: CompletionProvider,
input: CompletionInput,
options?: CompletionOptions
): Promise<CompletionResult>;
/**
* Streaming variant of createCompletion, returns a stream of tokens and a promise that resolves to the completion result.
* @param {CompletionProvider} provider - The inference model object or chat session
* @param {CompletionInput} input - The input string or message array
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {CompletionStreamReturn} An object of token stream and the completion result promise.
*/
declare function createCompletionStream(
provider: CompletionProvider,
input: CompletionInput,
options?: CompletionOptions
): CompletionStreamReturn;
/**
* The result of a streamed completion, containing a stream of tokens and a promise that resolves to the completion result.
*/
interface CompletionStreamReturn {
tokens: NodeJS.ReadableStream;
result: Promise<CompletionResult>;
}
/**
* Async generator variant of createCompletion, yields tokens as they are generated and returns the completion result.
* @param {CompletionProvider} provider - The inference model object or chat session
* @param {CompletionInput} input - The input string or message array
* @param {CompletionOptions} options - The options for creating the completion.
* @returns {AsyncGenerator<string>} The stream of generated tokens
*/
declare function createCompletionGenerator(
provider: CompletionProvider,
input: CompletionInput,
options: CompletionOptions
): AsyncGenerator<string, CompletionResult>;
/**
* A message in the conversation.
*/
interface ChatMessage {
/** The role of the message. */ /** The role of the message. */
role: "system" | "assistant" | "user"; role: "system" | "assistant" | "user";
@@ -345,34 +582,31 @@ interface PromptMessage {
} }
/** /**
* The result of the completion, similar to OpenAI's format. * The result of a completion.
*/ */
interface CompletionReturn { interface CompletionResult {
/** The model used for the completion. */ /** The model used for the completion. */
model: string; model: string;
/** Token usage report. */ /** Token usage report. */
usage: { usage: {
/** The number of tokens used in the prompt. */ /** The number of tokens ingested during the completion. */
prompt_tokens: number; prompt_tokens: number;
/** The number of tokens used in the completion. */ /** The number of tokens generated in the completion. */
completion_tokens: number; completion_tokens: number;
/** The total number of tokens used. */ /** The total number of tokens used. */
total_tokens: number; total_tokens: number;
/** Number of tokens used in the conversation. */
n_past_tokens: number;
}; };
/** The generated completions. */ /** The generated completion. */
choices: CompletionChoice[]; choices: Array<{
} message: ChatMessage;
}>;
/**
* A completion choice, similar to OpenAI's format.
*/
interface CompletionChoice {
/** Response message */
message: PromptMessage;
} }
/** /**
@@ -385,19 +619,33 @@ interface LLModelPromptContext {
/** The size of the raw tokens vector. */ /** The size of the raw tokens vector. */
tokensSize: number; tokensSize: number;
/** The number of tokens in the past conversation. */ /** The number of tokens in the past conversation.
* This may be used to "roll back" the conversation to a previous state.
* Note that for most use cases the default value should be sufficient and this should not be set.
* @default 0 For completions using InferenceModel, meaning the model will only consider the input prompt.
* @default nPast For completions using ChatSession. This means the context window will be automatically determined
* and possibly resized (see contextErase) to keep the conversation performant.
* */
nPast: number; nPast: number;
/** The number of tokens possible in the context window. /** The maximum number of tokens to predict.
* @default 1024 * @default 4096
*/
nCtx: number;
/** The number of tokens to predict.
* @default 128
* */ * */
nPredict: number; nPredict: number;
/** Template for user / assistant message pairs.
* %1 is required and will be replaced by the user input.
* %2 is optional and will be replaced by the assistant response. If not present, the assistant response will be appended.
*/
promptTemplate?: string;
/** The context window size. Do not use, it has no effect. See loadModel options.
* THIS IS DEPRECATED!!!
* Use loadModel's nCtx option instead.
* @default 2048
*/
nCtx: number;
/** The top-k logits to sample from. /** The top-k logits to sample from.
* Top-K sampling selects the next token only from the top K most likely tokens predicted by the model. * Top-K sampling selects the next token only from the top K most likely tokens predicted by the model.
* It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit * It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit
@@ -409,26 +657,33 @@ interface LLModelPromptContext {
topK: number; topK: number;
/** The nucleus sampling probability threshold. /** The nucleus sampling probability threshold.
* Top-P limits the selection of the next token to a subset of tokens with a cumulative probability * Top-P limits the selection of the next token to a subset of tokens with a cumulative probability
* above a threshold P. This method, also known as nucleus sampling, finds a balance between diversity * above a threshold P. This method, also known as nucleus sampling, finds a balance between diversity
* and quality by considering both token probabilities and the number of tokens available for sampling. * and quality by considering both token probabilities and the number of tokens available for sampling.
* When using a higher value for top-P (eg., 0.95), the generated text becomes more diverse. * When using a higher value for top-P (eg., 0.95), the generated text becomes more diverse.
* On the other hand, a lower value (eg., 0.1) produces more focused and conservative text. * On the other hand, a lower value (eg., 0.1) produces more focused and conservative text.
* The default value is 0.4, which is aimed to be the middle ground between focus and diversity, but * @default 0.9
* for more creative tasks a higher top-p value will be beneficial, about 0.5-0.9 is a good range for that. *
* @default 0.4
* */ * */
topP: number; topP: number;
/**
* The minimum probability of a token to be considered.
* @default 0.0
*/
minP: number;
/** The temperature to adjust the model's output distribution. /** The temperature to adjust the model's output distribution.
* Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures * Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures
* (eg., 1.2) increase randomness, resulting in more imaginative and diverse text. Lower temperatures (eg., 0.5) * (eg., 1.2) increase randomness, resulting in more imaginative and diverse text. Lower temperatures (eg., 0.5)
* make the output more focused, predictable, and conservative. When the temperature is set to 0, the output * make the output more focused, predictable, and conservative. When the temperature is set to 0, the output
* becomes completely deterministic, always selecting the most probable next token and producing identical results * becomes completely deterministic, always selecting the most probable next token and producing identical results
* each time. A safe range would be around 0.6 - 0.85, but you are free to search what value fits best for you. * each time. Try what value fits best for your use case and model.
* @default 0.7 * @default 0.1
* @alias temperature
* */ * */
temp: number; temp: number;
temperature: number;
/** The number of predictions to generate in parallel. /** The number of predictions to generate in parallel.
* By splitting the prompt every N tokens, prompt-batch-size reduces RAM usage during processing. However, * By splitting the prompt every N tokens, prompt-batch-size reduces RAM usage during processing. However,
@@ -451,31 +706,17 @@ interface LLModelPromptContext {
* The repeat-penalty-tokens N option controls the number of tokens in the history to consider for penalizing repetition. * The repeat-penalty-tokens N option controls the number of tokens in the history to consider for penalizing repetition.
* A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only * A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only
* consider recent tokens. * consider recent tokens.
* @default 64 * @default 10
* */ * */
repeatLastN: number; repeatLastN: number;
/** The percentage of context to erase if the context window is exceeded. /** The percentage of context to erase if the context window is exceeded.
* @default 0.5 * Set it to a lower value to keep context for longer at the cost of performance.
* @default 0.75
* */ * */
contextErase: number; contextErase: number;
} }
/**
* Creates an async generator of tokens
* @param {InferenceModel} llmodel - The language model object.
* @param {PromptMessage[]} messages - The array of messages for the conversation.
* @param {CompletionOptions} options - The options for creating the completion.
* @param {TokenCallback} callback - optional callback to control token generation.
* @returns {AsyncGenerator<string>} The stream of generated tokens
*/
declare function generateTokens(
llmodel: InferenceModel,
messages: PromptMessage[],
options: CompletionOptions,
callback?: TokenCallback
): AsyncGenerator<string>;
/** /**
* From python api: * From python api:
* models will be stored in (homedir)/.cache/gpt4all/` * models will be stored in (homedir)/.cache/gpt4all/`
@@ -508,7 +749,7 @@ declare const DEFAULT_MODEL_LIST_URL: string;
* Initiates the download of a model file. * Initiates the download of a model file.
* By default this downloads without waiting. use the controller returned to alter this behavior. * By default this downloads without waiting. use the controller returned to alter this behavior.
* @param {string} modelName - The model to be downloaded. * @param {string} modelName - The model to be downloaded.
* @param {DownloadOptions} options - to pass into the downloader. Default is { location: (cwd), verbose: false }. * @param {DownloadModelOptions} options - to pass into the downloader. Default is { location: (cwd), verbose: false }.
* @returns {DownloadController} object that allows controlling the download process. * @returns {DownloadController} object that allows controlling the download process.
* *
* @throws {Error} If the model already exists in the specified location. * @throws {Error} If the model already exists in the specified location.
@@ -556,7 +797,9 @@ interface ListModelsOptions {
file?: string; file?: string;
} }
declare function listModels(options?: ListModelsOptions): Promise<ModelConfig[]>; declare function listModels(
options?: ListModelsOptions
): Promise<ModelConfig[]>;
interface RetrieveModelOptions { interface RetrieveModelOptions {
allowDownload?: boolean; allowDownload?: boolean;
@@ -581,30 +824,35 @@ interface DownloadController {
} }
export { export {
ModelType,
ModelFile,
ModelConfig,
InferenceModel,
EmbeddingModel,
LLModel, LLModel,
LLModelPromptContext, LLModelPromptContext,
PromptMessage, ModelConfig,
InferenceModel,
InferenceResult,
EmbeddingModel,
EmbeddingResult,
ChatSession,
ChatMessage,
CompletionInput,
CompletionProvider,
CompletionOptions, CompletionOptions,
CompletionResult,
LoadModelOptions, LoadModelOptions,
DownloadController,
RetrieveModelOptions,
DownloadModelOptions,
GpuDevice,
loadModel, loadModel,
downloadModel,
retrieveModel,
listModels,
createCompletion, createCompletion,
createCompletionStream,
createCompletionGenerator,
createEmbedding, createEmbedding,
generateTokens,
DEFAULT_DIRECTORY, DEFAULT_DIRECTORY,
DEFAULT_LIBRARIES_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY,
DEFAULT_MODEL_CONFIG, DEFAULT_MODEL_CONFIG,
DEFAULT_PROMPT_CONTEXT, DEFAULT_PROMPT_CONTEXT,
DEFAULT_MODEL_LIST_URL, DEFAULT_MODEL_LIST_URL,
downloadModel,
retrieveModel,
listModels,
DownloadController,
RetrieveModelOptions,
DownloadModelOptions,
GpuDevice
}; };

View File

@@ -2,8 +2,10 @@
/// This file implements the gpt4all.d.ts file endings. /// This file implements the gpt4all.d.ts file endings.
/// Written in commonjs to support both ESM and CJS projects. /// Written in commonjs to support both ESM and CJS projects.
const { existsSync } = require("fs"); const { existsSync } = require("node:fs");
const path = require("node:path"); const path = require("node:path");
const Stream = require("node:stream");
const assert = require("node:assert");
const { LLModel } = require("node-gyp-build")(path.resolve(__dirname, "..")); const { LLModel } = require("node-gyp-build")(path.resolve(__dirname, ".."));
const { const {
retrieveModel, retrieveModel,
@@ -18,15 +20,14 @@ const {
DEFAULT_MODEL_LIST_URL, DEFAULT_MODEL_LIST_URL,
} = require("./config.js"); } = require("./config.js");
const { InferenceModel, EmbeddingModel } = require("./models.js"); const { InferenceModel, EmbeddingModel } = require("./models.js");
const Stream = require('stream') const { ChatSession } = require("./chat-session.js");
const assert = require("assert");
/** /**
* Loads a machine learning model with the specified name. The defacto way to create a model. * Loads a machine learning model with the specified name. The defacto way to create a model.
* By default this will download a model from the official GPT4ALL website, if a model is not present at given path. * By default this will download a model from the official GPT4ALL website, if a model is not present at given path.
* *
* @param {string} modelName - The name of the model to load. * @param {string} modelName - The name of the model to load.
* @param {LoadModelOptions|undefined} [options] - (Optional) Additional options for loading the model. * @param {import('./gpt4all').LoadModelOptions|undefined} [options] - (Optional) Additional options for loading the model.
* @returns {Promise<InferenceModel | EmbeddingModel>} A promise that resolves to an instance of the loaded LLModel. * @returns {Promise<InferenceModel | EmbeddingModel>} A promise that resolves to an instance of the loaded LLModel.
*/ */
async function loadModel(modelName, options = {}) { async function loadModel(modelName, options = {}) {
@@ -35,10 +36,10 @@ async function loadModel(modelName, options = {}) {
librariesPath: DEFAULT_LIBRARIES_DIRECTORY, librariesPath: DEFAULT_LIBRARIES_DIRECTORY,
type: "inference", type: "inference",
allowDownload: true, allowDownload: true,
verbose: true, verbose: false,
device: 'cpu', device: "cpu",
nCtx: 2048, nCtx: 2048,
ngl : 100, ngl: 100,
...options, ...options,
}; };
@@ -49,12 +50,14 @@ async function loadModel(modelName, options = {}) {
verbose: loadOptions.verbose, verbose: loadOptions.verbose,
}); });
assert.ok(typeof loadOptions.librariesPath === 'string'); assert.ok(
typeof loadOptions.librariesPath === "string",
"Libraries path should be a string"
);
const existingPaths = loadOptions.librariesPath const existingPaths = loadOptions.librariesPath
.split(";") .split(";")
.filter(existsSync) .filter(existsSync)
.join(';'); .join(";");
console.log("Passing these paths into runtime library search:", existingPaths)
const llmOptions = { const llmOptions = {
model_name: appendBinSuffixIfMissing(modelName), model_name: appendBinSuffixIfMissing(modelName),
@@ -62,13 +65,15 @@ async function loadModel(modelName, options = {}) {
library_path: existingPaths, library_path: existingPaths,
device: loadOptions.device, device: loadOptions.device,
nCtx: loadOptions.nCtx, nCtx: loadOptions.nCtx,
ngl: loadOptions.ngl ngl: loadOptions.ngl,
}; };
if (loadOptions.verbose) { if (loadOptions.verbose) {
console.debug("Creating LLModel with options:", llmOptions); console.debug("Creating LLModel:", {
llmOptions,
modelConfig,
});
} }
console.log(modelConfig)
const llmodel = new LLModel(llmOptions); const llmodel = new LLModel(llmOptions);
if (loadOptions.type === "embedding") { if (loadOptions.type === "embedding") {
return new EmbeddingModel(llmodel, modelConfig); return new EmbeddingModel(llmodel, modelConfig);
@@ -79,75 +84,43 @@ async function loadModel(modelName, options = {}) {
} }
} }
/** function createEmbedding(model, text, options={}) {
* Formats a list of messages into a single prompt string. let {
*/ dimensionality = undefined,
function formatChatPrompt( longTextMode = "mean",
messages, atlas = false,
{ } = options;
systemPromptTemplate,
defaultSystemPrompt,
promptTemplate,
promptFooter,
promptHeader,
}
) {
const systemMessages = messages
.filter((message) => message.role === "system")
.map((message) => message.content);
let fullPrompt = ""; if (dimensionality === undefined) {
dimensionality = -1;
if (promptHeader) { } else {
fullPrompt += promptHeader + "\n\n"; if (dimensionality <= 0) {
} throw new Error(
`Dimensionality must be undefined or a positive integer, got ${dimensionality}`
if (systemPromptTemplate) {
// if user specified a template for the system prompt, put all system messages in the template
let systemPrompt = "";
if (systemMessages.length > 0) {
systemPrompt += systemMessages.join("\n");
}
if (systemPrompt) {
fullPrompt +=
systemPromptTemplate.replace("%1", systemPrompt) + "\n";
}
} else if (defaultSystemPrompt) {
// otherwise, use the system prompt from the model config and ignore system messages
fullPrompt += defaultSystemPrompt + "\n\n";
}
if (systemMessages.length > 0 && !systemPromptTemplate) {
console.warn(
"System messages were provided, but no systemPromptTemplate was specified. System messages will be ignored."
);
}
for (const message of messages) {
if (message.role === "user") {
const userMessage = promptTemplate.replace(
"%1",
message["content"]
); );
fullPrompt += userMessage;
} }
if (message["role"] == "assistant") { if (dimensionality < model.MIN_DIMENSIONALITY) {
const assistantMessage = message["content"] + "\n"; console.warn(
fullPrompt += assistantMessage; `Dimensionality ${dimensionality} is less than the suggested minimum of ${model.MIN_DIMENSIONALITY}. Performance may be degraded.`
);
} }
} }
if (promptFooter) { let doMean;
fullPrompt += "\n\n" + promptFooter; switch (longTextMode) {
case "mean":
doMean = true;
break;
case "truncate":
doMean = false;
break;
default:
throw new Error(
`Long text mode must be one of 'mean' or 'truncate', got ${longTextMode}`
);
} }
return fullPrompt; return model.embed(text, options?.prefix, dimensionality, doMean, atlas);
}
function createEmbedding(model, text) {
return model.embed(text);
} }
const defaultCompletionOptions = { const defaultCompletionOptions = {
@@ -155,162 +128,76 @@ const defaultCompletionOptions = {
...DEFAULT_PROMPT_CONTEXT, ...DEFAULT_PROMPT_CONTEXT,
}; };
function preparePromptAndContext(model,messages,options){ async function createCompletion(
if (options.hasDefaultHeader !== undefined) { provider,
console.warn( input,
"hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead" options = defaultCompletionOptions
); ) {
} const completionOptions = {
if (options.hasDefaultFooter !== undefined) {
console.warn(
"hasDefaultFooter (bool) is deprecated and has no effect, use promptFooter (string) instead"
);
}
const optionsWithDefaults = {
...defaultCompletionOptions, ...defaultCompletionOptions,
...options, ...options,
}; };
const { const result = await provider.generate(
verbose, input,
systemPromptTemplate, completionOptions,
promptTemplate, );
promptHeader,
promptFooter,
...promptContext
} = optionsWithDefaults;
const prompt = formatChatPrompt(messages, {
systemPromptTemplate,
defaultSystemPrompt: model.config.systemPrompt,
promptTemplate: promptTemplate || model.config.promptTemplate || "%1",
promptHeader: promptHeader || "",
promptFooter: promptFooter || "",
// These were the default header/footer prompts used for non-chat single turn completions.
// both seem to be working well still with some models, so keeping them here for reference.
// promptHeader: '### Instruction: The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.',
// promptFooter: '### Response:',
});
return { return {
prompt, promptContext, verbose model: provider.modelName,
}
}
async function createCompletion(
model,
messages,
options = defaultCompletionOptions
) {
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options);
if (verbose) {
console.debug("Sending Prompt:\n" + prompt);
}
let tokensGenerated = 0
const response = await model.generate(prompt, promptContext,() => {
tokensGenerated++;
return true;
});
if (verbose) {
console.debug("Received Response:\n" + response);
}
return {
llmodel: model.llm.name(),
usage: { usage: {
prompt_tokens: prompt.length, prompt_tokens: result.tokensIngested,
completion_tokens: tokensGenerated, total_tokens: result.tokensIngested + result.tokensGenerated,
total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt completion_tokens: result.tokensGenerated,
n_past_tokens: result.nPast,
}, },
choices: [ choices: [
{ {
message: { message: {
role: "assistant", role: "assistant",
content: response, content: result.text,
}, },
// TODO some completion APIs also provide logprobs and finish_reason, could look into adding those
}, },
], ],
}; };
} }
function _internal_createTokenStream(stream,model, function createCompletionStream(
messages, provider,
options = defaultCompletionOptions,callback = undefined) { input,
const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options); options = defaultCompletionOptions
) {
const completionStream = new Stream.PassThrough({
encoding: "utf-8",
});
const completionPromise = createCompletion(provider, input, {
if (verbose) { ...options,
console.debug("Sending Prompt:\n" + prompt); onResponseToken: (tokenId, token) => {
} completionStream.push(token);
if (options.onResponseToken) {
model.generate(prompt, promptContext,(tokenId, token, total) => { return options.onResponseToken(tokenId, token);
stream.push(token);
if(callback !== undefined){
return callback(tokenId,token,total);
}
return true;
}).then(() => {
stream.end()
})
return stream;
}
function _createTokenStream(model,
messages,
options = defaultCompletionOptions,callback = undefined) {
// Silent crash if we dont do this here
const stream = new Stream.PassThrough({
encoding: 'utf-8'
});
return _internal_createTokenStream(stream,model,messages,options,callback);
}
async function* generateTokens(model,
messages,
options = defaultCompletionOptions, callback = undefined) {
const stream = _createTokenStream(model,messages,options,callback)
let bHasFinished = false;
let activeDataCallback = undefined;
const finishCallback = () => {
bHasFinished = true;
if(activeDataCallback !== undefined){
activeDataCallback(undefined);
}
}
stream.on("finish",finishCallback)
while (!bHasFinished) {
const token = await new Promise((res) => {
activeDataCallback = (d) => {
stream.off("data",activeDataCallback)
activeDataCallback = undefined
res(d);
} }
stream.on('data', activeDataCallback) },
}) }).then((result) => {
completionStream.push(null);
completionStream.emit("end");
return result;
});
if (token == undefined) { return {
break; tokens: completionStream,
} result: completionPromise,
};
}
yield token; async function* createCompletionGenerator(provider, input, options) {
const completion = createCompletionStream(provider, input, options);
for await (const chunk of completion.tokens) {
yield chunk;
} }
return await completion.result;
stream.off("finish",finishCallback);
} }
module.exports = { module.exports = {
@@ -322,10 +209,12 @@ module.exports = {
LLModel, LLModel,
InferenceModel, InferenceModel,
EmbeddingModel, EmbeddingModel,
ChatSession,
createCompletion, createCompletion,
createCompletionStream,
createCompletionGenerator,
createEmbedding, createEmbedding,
downloadModel, downloadModel,
retrieveModel, retrieveModel,
loadModel, loadModel,
generateTokens
}; };

View File

@@ -1,18 +1,138 @@
const { normalizePromptContext, warnOnSnakeCaseKeys } = require('./util'); const { DEFAULT_PROMPT_CONTEXT } = require("./config");
const { ChatSession } = require("./chat-session");
const { prepareMessagesForIngest } = require("./util");
class InferenceModel { class InferenceModel {
llm; llm;
modelName;
config; config;
activeChatSession;
constructor(llmodel, config) { constructor(llmodel, config) {
this.llm = llmodel; this.llm = llmodel;
this.config = config; this.config = config;
this.modelName = this.llm.name();
} }
async generate(prompt, promptContext,callback) { async createChatSession(options) {
warnOnSnakeCaseKeys(promptContext); const chatSession = new ChatSession(this, options);
const normalizedPromptContext = normalizePromptContext(promptContext); await chatSession.initialize();
const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback); this.activeChatSession = chatSession;
return this.activeChatSession;
}
async generate(input, options = DEFAULT_PROMPT_CONTEXT) {
const { verbose, ...otherOptions } = options;
const promptContext = {
promptTemplate: this.config.promptTemplate,
temp:
otherOptions.temp ??
otherOptions.temperature ??
DEFAULT_PROMPT_CONTEXT.temp,
...otherOptions,
};
if (promptContext.nPast < 0) {
throw new Error("nPast must be a non-negative integer.");
}
if (verbose) {
console.debug("Generating completion", {
input,
promptContext,
});
}
let prompt = input;
let nPast = promptContext.nPast;
let tokensIngested = 0;
if (Array.isArray(input)) {
// assuming input is a messages array
// -> tailing user message will be used as the final prompt. its required.
// -> leading system message will be ingested as systemPrompt, further system messages will be ignored
// -> all other messages will be ingested with fakeReply
// -> model/context will only be kept for this completion; "stateless"
nPast = 0;
const messages = [...input];
const lastMessage = input[input.length - 1];
if (lastMessage.role !== "user") {
// this is most likely a user error
throw new Error("The final message must be of role 'user'.");
}
if (input[0].role === "system") {
// needs to be a pre-templated prompt ala '<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n'
const systemPrompt = input[0].content;
const systemRes = await this.llm.infer(systemPrompt, {
promptTemplate: "%1",
nPredict: 0,
special: true,
});
nPast = systemRes.nPast;
tokensIngested += systemRes.tokensIngested;
messages.shift();
}
prompt = lastMessage.content;
const messagesToIngest = messages.slice(0, input.length - 1);
const turns = prepareMessagesForIngest(messagesToIngest);
for (const turn of turns) {
const turnRes = await this.llm.infer(turn.user, {
...promptContext,
nPast,
fakeReply: turn.assistant,
});
tokensIngested += turnRes.tokensIngested;
nPast = turnRes.nPast;
}
}
let tokensGenerated = 0;
const result = await this.llm.infer(prompt, {
...promptContext,
nPast,
onPromptToken: (tokenId) => {
let continueIngestion = true;
tokensIngested++;
if (options.onPromptToken) {
// catch errors because if they go through cpp they will loose stacktraces
try {
// don't cancel ingestion unless user explicitly returns false
continueIngestion =
options.onPromptToken(tokenId) !== false;
} catch (e) {
console.error("Error in onPromptToken callback", e);
continueIngestion = false;
}
}
return continueIngestion;
},
onResponseToken: (tokenId, token) => {
let continueGeneration = true;
tokensGenerated++;
if (options.onResponseToken) {
try {
// don't cancel the generation unless user explicitly returns false
continueGeneration =
options.onResponseToken(tokenId, token) !== false;
} catch (err) {
console.error("Error in onResponseToken callback", err);
continueGeneration = false;
}
}
return continueGeneration;
},
});
result.tokensGenerated = tokensGenerated;
result.tokensIngested = tokensIngested;
if (verbose) {
console.debug("Finished completion:\n", result);
}
return result; return result;
} }
@@ -24,14 +144,14 @@ class InferenceModel {
class EmbeddingModel { class EmbeddingModel {
llm; llm;
config; config;
MIN_DIMENSIONALITY = 64;
constructor(llmodel, config) { constructor(llmodel, config) {
this.llm = llmodel; this.llm = llmodel;
this.config = config; this.config = config;
} }
embed(text) { embed(text, prefix, dimensionality, do_mean, atlas) {
return this.llm.embed(text) return this.llm.embed(text, prefix, dimensionality, do_mean, atlas);
} }
dispose() { dispose() {
@@ -39,7 +159,6 @@ class EmbeddingModel {
} }
} }
module.exports = { module.exports = {
InferenceModel, InferenceModel,
EmbeddingModel, EmbeddingModel,

View File

@@ -1,8 +1,7 @@
const { createWriteStream, existsSync, statSync } = require("node:fs"); const { createWriteStream, existsSync, statSync, mkdirSync } = require("node:fs");
const fsp = require("node:fs/promises"); const fsp = require("node:fs/promises");
const { performance } = require("node:perf_hooks"); const { performance } = require("node:perf_hooks");
const path = require("node:path"); const path = require("node:path");
const { mkdirp } = require("mkdirp");
const md5File = require("md5-file"); const md5File = require("md5-file");
const { const {
DEFAULT_DIRECTORY, DEFAULT_DIRECTORY,
@@ -50,6 +49,63 @@ function appendBinSuffixIfMissing(name) {
return name; return name;
} }
function prepareMessagesForIngest(messages) {
const systemMessages = messages.filter(
(message) => message.role === "system"
);
if (systemMessages.length > 0) {
console.warn(
"System messages are currently not supported and will be ignored. Use the systemPrompt option instead."
);
}
const userAssistantMessages = messages.filter(
(message) => message.role !== "system"
);
// make sure the first message is a user message
// if its not, the turns will be out of order
if (userAssistantMessages[0].role !== "user") {
userAssistantMessages.unshift({
role: "user",
content: "",
});
}
// create turns of user input + assistant reply
const turns = [];
let userMessage = null;
let assistantMessage = null;
for (const message of userAssistantMessages) {
// consecutive messages of the same role are concatenated into one message
if (message.role === "user") {
if (!userMessage) {
userMessage = message.content;
} else {
userMessage += "\n" + message.content;
}
} else if (message.role === "assistant") {
if (!assistantMessage) {
assistantMessage = message.content;
} else {
assistantMessage += "\n" + message.content;
}
}
if (userMessage && assistantMessage) {
turns.push({
user: userMessage,
assistant: assistantMessage,
});
userMessage = null;
assistantMessage = null;
}
}
return turns;
}
// readChunks() reads from the provided reader and yields the results into an async iterable // readChunks() reads from the provided reader and yields the results into an async iterable
// https://css-tricks.com/web-streams-everywhere-and-fetch-for-node-js/ // https://css-tricks.com/web-streams-everywhere-and-fetch-for-node-js/
function readChunks(reader) { function readChunks(reader) {
@@ -64,49 +120,13 @@ function readChunks(reader) {
}; };
} }
/**
* Prints a warning if any keys in the prompt context are snake_case.
*/
function warnOnSnakeCaseKeys(promptContext) {
const snakeCaseKeys = Object.keys(promptContext).filter((key) =>
key.includes("_")
);
if (snakeCaseKeys.length > 0) {
console.warn(
"Prompt context keys should be camelCase. Support for snake_case might be removed in the future. Found keys: " +
snakeCaseKeys.join(", ")
);
}
}
/**
* Converts all keys in the prompt context to snake_case
* For duplicate definitions, the value of the last occurrence will be used.
*/
function normalizePromptContext(promptContext) {
const normalizedPromptContext = {};
for (const key in promptContext) {
if (promptContext.hasOwnProperty(key)) {
const snakeKey = key.replace(
/[A-Z]/g,
(match) => `_${match.toLowerCase()}`
);
normalizedPromptContext[snakeKey] = promptContext[key];
}
}
return normalizedPromptContext;
}
function downloadModel(modelName, options = {}) { function downloadModel(modelName, options = {}) {
const downloadOptions = { const downloadOptions = {
modelPath: DEFAULT_DIRECTORY, modelPath: DEFAULT_DIRECTORY,
verbose: false, verbose: false,
...options, ...options,
}; };
const modelFileName = appendBinSuffixIfMissing(modelName); const modelFileName = appendBinSuffixIfMissing(modelName);
const partialModelPath = path.join( const partialModelPath = path.join(
downloadOptions.modelPath, downloadOptions.modelPath,
@@ -114,16 +134,17 @@ function downloadModel(modelName, options = {}) {
); );
const finalModelPath = path.join(downloadOptions.modelPath, modelFileName); const finalModelPath = path.join(downloadOptions.modelPath, modelFileName);
const modelUrl = const modelUrl =
downloadOptions.url ?? `https://gpt4all.io/models/gguf/${modelFileName}`; downloadOptions.url ??
`https://gpt4all.io/models/gguf/${modelFileName}`;
mkdirp.sync(downloadOptions.modelPath) mkdirSync(downloadOptions.modelPath, { recursive: true });
if (existsSync(finalModelPath)) { if (existsSync(finalModelPath)) {
throw Error(`Model already exists at ${finalModelPath}`); throw Error(`Model already exists at ${finalModelPath}`);
} }
if (downloadOptions.verbose) { if (downloadOptions.verbose) {
console.log(`Downloading ${modelName} from ${modelUrl}`); console.debug(`Downloading ${modelName} from ${modelUrl}`);
} }
const headers = { const headers = {
@@ -134,7 +155,9 @@ function downloadModel(modelName, options = {}) {
const writeStreamOpts = {}; const writeStreamOpts = {};
if (existsSync(partialModelPath)) { if (existsSync(partialModelPath)) {
console.log("Partial model exists, resuming download..."); if (downloadOptions.verbose) {
console.debug("Partial model exists, resuming download...");
}
const startRange = statSync(partialModelPath).size; const startRange = statSync(partialModelPath).size;
headers["Range"] = `bytes=${startRange}-`; headers["Range"] = `bytes=${startRange}-`;
writeStreamOpts.flags = "a"; writeStreamOpts.flags = "a";
@@ -144,15 +167,15 @@ function downloadModel(modelName, options = {}) {
const signal = abortController.signal; const signal = abortController.signal;
const finalizeDownload = async () => { const finalizeDownload = async () => {
if (options.md5sum) { if (downloadOptions.md5sum) {
const fileHash = await md5File(partialModelPath); const fileHash = await md5File(partialModelPath);
if (fileHash !== options.md5sum) { if (fileHash !== downloadOptions.md5sum) {
await fsp.unlink(partialModelPath); await fsp.unlink(partialModelPath);
const message = `Model "${modelName}" failed verification: Hashes mismatch. Expected ${options.md5sum}, got ${fileHash}`; const message = `Model "${modelName}" failed verification: Hashes mismatch. Expected ${downloadOptions.md5sum}, got ${fileHash}`;
throw Error(message); throw Error(message);
} }
if (options.verbose) { if (downloadOptions.verbose) {
console.log(`MD5 hash verified: ${fileHash}`); console.debug(`MD5 hash verified: ${fileHash}`);
} }
} }
@@ -163,8 +186,8 @@ function downloadModel(modelName, options = {}) {
const downloadPromise = new Promise((resolve, reject) => { const downloadPromise = new Promise((resolve, reject) => {
let timestampStart; let timestampStart;
if (options.verbose) { if (downloadOptions.verbose) {
console.log(`Downloading @ ${partialModelPath} ...`); console.debug(`Downloading @ ${partialModelPath} ...`);
timestampStart = performance.now(); timestampStart = performance.now();
} }
@@ -179,7 +202,7 @@ function downloadModel(modelName, options = {}) {
}); });
writeStream.on("finish", () => { writeStream.on("finish", () => {
if (options.verbose) { if (downloadOptions.verbose) {
const elapsed = performance.now() - timestampStart; const elapsed = performance.now() - timestampStart;
console.log(`Finished. Download took ${elapsed.toFixed(2)} ms`); console.log(`Finished. Download took ${elapsed.toFixed(2)} ms`);
} }
@@ -221,10 +244,10 @@ async function retrieveModel(modelName, options = {}) {
const retrieveOptions = { const retrieveOptions = {
modelPath: DEFAULT_DIRECTORY, modelPath: DEFAULT_DIRECTORY,
allowDownload: true, allowDownload: true,
verbose: true, verbose: false,
...options, ...options,
}; };
await mkdirp(retrieveOptions.modelPath); mkdirSync(retrieveOptions.modelPath, { recursive: true });
const modelFileName = appendBinSuffixIfMissing(modelName); const modelFileName = appendBinSuffixIfMissing(modelName);
const fullModelPath = path.join(retrieveOptions.modelPath, modelFileName); const fullModelPath = path.join(retrieveOptions.modelPath, modelFileName);
@@ -236,7 +259,7 @@ async function retrieveModel(modelName, options = {}) {
file: retrieveOptions.modelConfigFile, file: retrieveOptions.modelConfigFile,
url: url:
retrieveOptions.allowDownload && retrieveOptions.allowDownload &&
"https://gpt4all.io/models/models2.json", "https://gpt4all.io/models/models3.json",
}); });
const loadedModelConfig = availableModels.find( const loadedModelConfig = availableModels.find(
@@ -262,10 +285,9 @@ async function retrieveModel(modelName, options = {}) {
config.path = fullModelPath; config.path = fullModelPath;
if (retrieveOptions.verbose) { if (retrieveOptions.verbose) {
console.log(`Found ${modelName} at ${fullModelPath}`); console.debug(`Found ${modelName} at ${fullModelPath}`);
} }
} else if (retrieveOptions.allowDownload) { } else if (retrieveOptions.allowDownload) {
const downloadController = downloadModel(modelName, { const downloadController = downloadModel(modelName, {
modelPath: retrieveOptions.modelPath, modelPath: retrieveOptions.modelPath,
verbose: retrieveOptions.verbose, verbose: retrieveOptions.verbose,
@@ -278,7 +300,7 @@ async function retrieveModel(modelName, options = {}) {
config.path = downloadPath; config.path = downloadPath;
if (retrieveOptions.verbose) { if (retrieveOptions.verbose) {
console.log(`Model downloaded to ${downloadPath}`); console.debug(`Model downloaded to ${downloadPath}`);
} }
} else { } else {
throw Error("Failed to retrieve model."); throw Error("Failed to retrieve model.");
@@ -288,9 +310,8 @@ async function retrieveModel(modelName, options = {}) {
module.exports = { module.exports = {
appendBinSuffixIfMissing, appendBinSuffixIfMissing,
prepareMessagesForIngest,
downloadModel, downloadModel,
retrieveModel, retrieveModel,
listModels, listModels,
normalizePromptContext,
warnOnSnakeCaseKeys,
}; };

View File

@@ -7,7 +7,6 @@ const {
listModels, listModels,
downloadModel, downloadModel,
appendBinSuffixIfMissing, appendBinSuffixIfMissing,
normalizePromptContext,
} = require("../src/util.js"); } = require("../src/util.js");
const { const {
DEFAULT_DIRECTORY, DEFAULT_DIRECTORY,
@@ -19,8 +18,6 @@ const {
createPrompt, createPrompt,
createCompletion, createCompletion,
} = require("../src/gpt4all.js"); } = require("../src/gpt4all.js");
const { mock } = require("node:test");
const { mkdirp } = require("mkdirp");
describe("config", () => { describe("config", () => {
test("default paths constants are available and correct", () => { test("default paths constants are available and correct", () => {
@@ -87,7 +84,7 @@ describe("listModels", () => {
expect(fetch).toHaveBeenCalledTimes(0); expect(fetch).toHaveBeenCalledTimes(0);
expect(models[0]).toEqual(fakeModel); expect(models[0]).toEqual(fakeModel);
}); });
it("should throw an error if neither url nor file is specified", async () => { it("should throw an error if neither url nor file is specified", async () => {
await expect(listModels(null)).rejects.toThrow( await expect(listModels(null)).rejects.toThrow(
"No model list source specified. Please specify either a url or a file." "No model list source specified. Please specify either a url or a file."
@@ -141,10 +138,10 @@ describe("downloadModel", () => {
mockAbortController.mockReset(); mockAbortController.mockReset();
mockFetch.mockClear(); mockFetch.mockClear();
global.fetch.mockRestore(); global.fetch.mockRestore();
const rootDefaultPath = path.resolve(DEFAULT_DIRECTORY), const rootDefaultPath = path.resolve(DEFAULT_DIRECTORY),
partialPath = path.resolve(rootDefaultPath, fakeModelName+'.part'), partialPath = path.resolve(rootDefaultPath, fakeModelName+'.part'),
fullPath = path.resolve(rootDefaultPath, fakeModelName+'.bin') fullPath = path.resolve(rootDefaultPath, fakeModelName+'.bin')
//if tests fail, remove the created files //if tests fail, remove the created files
// acts as cleanup if tests fail // acts as cleanup if tests fail
@@ -206,46 +203,3 @@ describe("downloadModel", () => {
// test("should be able to cancel and resume a download", async () => { // test("should be able to cancel and resume a download", async () => {
// }); // });
}); });
describe("normalizePromptContext", () => {
it("should convert a dict with camelCased keys to snake_case", () => {
const camelCased = {
topK: 20,
repeatLastN: 10,
};
const expectedSnakeCased = {
top_k: 20,
repeat_last_n: 10,
};
const result = normalizePromptContext(camelCased);
expect(result).toEqual(expectedSnakeCased);
});
it("should convert a mixed case dict to snake_case, last value taking precedence", () => {
const mixedCased = {
topK: 20,
top_k: 10,
repeatLastN: 10,
};
const expectedSnakeCased = {
top_k: 10,
repeat_last_n: 10,
};
const result = normalizePromptContext(mixedCased);
expect(result).toEqual(expectedSnakeCased);
});
it("should not modify already snake cased dict", () => {
const snakeCased = {
top_k: 10,
repeast_last_n: 10,
};
const result = normalizePromptContext(snakeCased);
expect(result).toEqual(snakeCased);
});
});

View File

@@ -2300,7 +2300,6 @@ __metadata:
documentation: ^14.0.2 documentation: ^14.0.2
jest: ^29.5.0 jest: ^29.5.0
md5-file: ^5.0.0 md5-file: ^5.0.0
mkdirp: ^3.0.1
node-addon-api: ^6.1.0 node-addon-api: ^6.1.0
node-gyp: 9.x.x node-gyp: 9.x.x
node-gyp-build: ^4.6.0 node-gyp-build: ^4.6.0
@@ -4258,15 +4257,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"mkdirp@npm:^3.0.1":
version: 3.0.1
resolution: "mkdirp@npm:3.0.1"
bin:
mkdirp: dist/cjs/src/bin.js
checksum: 972deb188e8fb55547f1e58d66bd6b4a3623bf0c7137802582602d73e6480c1c2268dcbafbfb1be466e00cc7e56ac514d7fd9334b7cf33e3e2ab547c16f83a8d
languageName: node
linkType: hard
"mri@npm:^1.1.0": "mri@npm:^1.1.0":
version: 1.2.0 version: 1.2.0
resolution: "mri@npm:1.2.0" resolution: "mri@npm:1.2.0"