From 55f3b056b7250be44ec3df7689081654652810a2 Mon Sep 17 00:00:00 2001 From: Jacob Nguyen <76754747+jacoobes@users.noreply.github.com> Date: Thu, 28 Mar 2024 11:08:23 -0500 Subject: [PATCH] typescript!: chatSessions, fixes, tokenStreams (#2045) Signed-off-by: jacob Signed-off-by: limez Signed-off-by: Jared Van Bortel Co-authored-by: limez Co-authored-by: Jared Van Bortel --- .../python/docs/gpt4all_nodejs.md | 531 +++++++++---- gpt4all-bindings/typescript/.clang-format | 4 + gpt4all-bindings/typescript/README.md | 189 ++++- gpt4all-bindings/typescript/binding.ci.gyp | 6 +- gpt4all-bindings/typescript/binding.gyp | 6 +- gpt4all-bindings/typescript/index.cc | 582 ++++++++------ gpt4all-bindings/typescript/index.h | 101 +-- gpt4all-bindings/typescript/package.json | 3 +- gpt4all-bindings/typescript/prompt.cc | 300 +++++--- gpt4all-bindings/typescript/prompt.h | 105 +-- .../typescript/scripts/build_unix.sh | 1 - .../typescript/spec/callbacks.mjs | 31 + .../typescript/spec/chat-memory.mjs | 65 ++ .../typescript/spec/chat-minimal.mjs | 19 + gpt4all-bindings/typescript/spec/chat.mjs | 70 -- .../typescript/spec/concurrency.mjs | 29 + .../typescript/spec/embed-jsonl.mjs | 26 + gpt4all-bindings/typescript/spec/embed.mjs | 10 +- .../typescript/spec/generator.mjs | 41 - gpt4all-bindings/typescript/spec/llmodel.mjs | 61 ++ .../typescript/spec/long-context.mjs | 21 + .../typescript/spec/model-switching.mjs | 60 ++ .../typescript/spec/stateless.mjs | 50 ++ .../typescript/spec/streaming.mjs | 57 ++ gpt4all-bindings/typescript/spec/system.mjs | 19 + .../typescript/src/chat-session.js | 169 +++++ gpt4all-bindings/typescript/src/config.js | 11 +- gpt4all-bindings/typescript/src/gpt4all.d.ts | 710 ++++++++++++------ gpt4all-bindings/typescript/src/gpt4all.js | 305 +++----- gpt4all-bindings/typescript/src/models.js | 137 +++- gpt4all-bindings/typescript/src/util.js | 141 ++-- .../typescript/test/gpt4all.test.js | 52 +- gpt4all-bindings/typescript/yarn.lock | 10 - 33 files changed, 2573 insertions(+), 1349 deletions(-) create mode 100644 gpt4all-bindings/typescript/.clang-format create mode 100644 gpt4all-bindings/typescript/spec/callbacks.mjs create mode 100644 gpt4all-bindings/typescript/spec/chat-memory.mjs create mode 100644 gpt4all-bindings/typescript/spec/chat-minimal.mjs delete mode 100644 gpt4all-bindings/typescript/spec/chat.mjs create mode 100644 gpt4all-bindings/typescript/spec/concurrency.mjs create mode 100644 gpt4all-bindings/typescript/spec/embed-jsonl.mjs delete mode 100644 gpt4all-bindings/typescript/spec/generator.mjs create mode 100644 gpt4all-bindings/typescript/spec/llmodel.mjs create mode 100644 gpt4all-bindings/typescript/spec/long-context.mjs create mode 100644 gpt4all-bindings/typescript/spec/model-switching.mjs create mode 100644 gpt4all-bindings/typescript/spec/stateless.mjs create mode 100644 gpt4all-bindings/typescript/spec/streaming.mjs create mode 100644 gpt4all-bindings/typescript/spec/system.mjs create mode 100644 gpt4all-bindings/typescript/src/chat-session.js diff --git a/gpt4all-bindings/python/docs/gpt4all_nodejs.md b/gpt4all-bindings/python/docs/gpt4all_nodejs.md index 7e3a6b93..a282a47a 100644 --- a/gpt4all-bindings/python/docs/gpt4all_nodejs.md +++ b/gpt4all-bindings/python/docs/gpt4all_nodejs.md @@ -11,37 +11,116 @@ pnpm install gpt4all@latest ``` -The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date. 
+## Contents -* New bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use. -* The nodejs api has made strides to mirror the python api. It is not 100% mirrored, but many pieces of the api resemble its python counterpart. -* Everything should work out the box. * See [API Reference](#api-reference) +* See [Examples](#api-example) +* See [Developing](#develop) +* GPT4ALL nodejs bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use. + +## Api Example ### Chat Completion ```js -import { createCompletion, loadModel } from '../src/gpt4all.js' +import { LLModel, createCompletion, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, loadModel } from '../src/gpt4all.js' -const model = await loadModel('mistral-7b-openorca.Q4_0.gguf', { verbose: true }); +const model = await loadModel( 'mistral-7b-openorca.gguf2.Q4_0.gguf', { verbose: true, device: 'gpu' }); -const response = await createCompletion(model, [ - { role : 'system', content: 'You are meant to be annoying and unhelpful.' }, - { role : 'user', content: 'What is 1 + 1?' } -]); +const completion1 = await createCompletion(model, 'What is 1 + 1?', { verbose: true, }) +console.log(completion1.message) +const completion2 = await createCompletion(model, 'And if we add two?', { verbose: true }) +console.log(completion2.message) + +model.dispose() ``` ### Embedding ```js -import { createEmbedding, loadModel } from '../src/gpt4all.js' +import { loadModel, createEmbedding } from '../src/gpt4all.js' -const model = await loadModel('ggml-all-MiniLM-L6-v2-f16', { verbose: true }); +const embedder = await loadModel("all-MiniLM-L6-v2-f16.gguf", { verbose: true, type: 'embedding'}) -const fltArray = createEmbedding(model, "Pain is inevitable, suffering optional"); +console.log(createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way")); ``` +### Chat Sessions + +```js +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { + verbose: true, + device: "gpu", +}); + +const chat = await model.createChatSession(); + +await createCompletion( + chat, + "Why are bananas rather blue than bread at night sometimes?", + { + verbose: true, + } +); +await createCompletion(chat, "Are you sure?", { verbose: true, }); + +``` + +### Streaming responses + +```js +import gpt from "../src/gpt4all.js"; + +const model = await gpt.loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", { + device: "gpu", +}); + +process.stdout.write("### Stream:"); +const stream = gpt.createCompletionStream(model, "How are you?"); +stream.tokens.on("data", (data) => { + process.stdout.write(data); +}); +//wait till stream finishes. We cannot continue until this one is done. +await stream.result; +process.stdout.write("\n"); + +process.stdout.write("### Stream with pipe:"); +const stream2 = gpt.createCompletionStream( + model, + "Please say something nice about node streams." 
+); +stream2.tokens.pipe(process.stdout); +await stream2.result; +process.stdout.write("\n"); + +console.log("done"); +model.dispose(); +``` + +### Async Generators + +```js +import gpt from "../src/gpt4all.js"; + +const model = await gpt.loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", { + device: "gpu", +}); + +process.stdout.write("### Generator:"); +const gen = gpt.createCompletionGenerator(model, "Redstone in Minecraft is Turing Complete. Let that sink in. (let it in!)"); +for await (const chunk of gen) { + process.stdout.write(chunk); +} + +process.stdout.write("\n"); +model.dispose(); +``` + +## Develop + ### Build Instructions * binding.gyp is compile config @@ -131,21 +210,27 @@ yarn test * why your model may be spewing bull 💩 * The downloaded model is broken (just reinstall or download from official site) - * That's it so far +* Your model is hanging after a call to generate tokens. + * Is `nPast` set too high? This may cause your model to hang (03/16/2024), Linux Mint, Ubuntu 22.04 +* Your GPU usage is still high after node.js exits. + * Make sure to call `model.dispose()`!!! ### Roadmap -This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list: +This package has been stabilizing over time development, and breaking changes may happen until the api stabilizes. Here's what's the todo list: +* \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well. +* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help) + * Should include prebuilds to avoid painful node-gyp errors +* \[x] createChatSession ( the python equivalent to create\_chat\_session ) +* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs! +* \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete * \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs -* \[ ] ~~createTokenStream, an async iterator that streams each token emitted from the model. 
Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete +* \[x] generateTokens is the new name for this^ * \[x] proper unit testing (integrate with circle ci) * \[x] publish to npm under alpha tag `gpt4all@alpha` * \[x] have more people test on other platforms (mac tester needed) * \[x] switch to new pluggable backend -* \[ ] NPM bundle size reduction via optionalDependencies strategy (need help) - * Should include prebuilds to avoid painful node-gyp errors -* \[ ] createChatSession ( the python equivalent to create\_chat\_session ) ### API Reference @@ -153,144 +238,200 @@ This package is in active development, and breaking changes may happen until the ##### Table of Contents -* [ModelFile](#modelfile) - * [gptj](#gptj) - * [llama](#llama) - * [mpt](#mpt) - * [replit](#replit) * [type](#type) * [TokenCallback](#tokencallback) +* [ChatSessionOptions](#chatsessionoptions) + * [systemPrompt](#systemprompt) + * [messages](#messages) +* [initialize](#initialize) + * [Parameters](#parameters) +* [generate](#generate) + * [Parameters](#parameters-1) * [InferenceModel](#inferencemodel) + * [createChatSession](#createchatsession) + * [Parameters](#parameters-2) + * [generate](#generate-1) + * [Parameters](#parameters-3) * [dispose](#dispose) * [EmbeddingModel](#embeddingmodel) * [dispose](#dispose-1) +* [InferenceResult](#inferenceresult) * [LLModel](#llmodel) * [constructor](#constructor) - * [Parameters](#parameters) + * [Parameters](#parameters-4) * [type](#type-1) * [name](#name) * [stateSize](#statesize) * [threadCount](#threadcount) * [setThreadCount](#setthreadcount) - * [Parameters](#parameters-1) - * [raw\_prompt](#raw_prompt) - * [Parameters](#parameters-2) + * [Parameters](#parameters-5) + * [infer](#infer) + * [Parameters](#parameters-6) * [embed](#embed) - * [Parameters](#parameters-3) + * [Parameters](#parameters-7) * [isModelLoaded](#ismodelloaded) * [setLibraryPath](#setlibrarypath) - * [Parameters](#parameters-4) + * [Parameters](#parameters-8) * [getLibraryPath](#getlibrarypath) * [initGpuByString](#initgpubystring) - * [Parameters](#parameters-5) + * [Parameters](#parameters-9) * [hasGpuDevice](#hasgpudevice) * [listGpu](#listgpu) - * [Parameters](#parameters-6) + * [Parameters](#parameters-10) * [dispose](#dispose-2) * [GpuDevice](#gpudevice) * [type](#type-2) * [LoadModelOptions](#loadmodeloptions) -* [loadModel](#loadmodel) - * [Parameters](#parameters-7) -* [createCompletion](#createcompletion) - * [Parameters](#parameters-8) -* [createEmbedding](#createembedding) - * [Parameters](#parameters-9) -* [CompletionOptions](#completionoptions) + * [modelPath](#modelpath) + * [librariesPath](#librariespath) + * [modelConfigFile](#modelconfigfile) + * [allowDownload](#allowdownload) * [verbose](#verbose) - * [systemPromptTemplate](#systemprompttemplate) - * [promptTemplate](#prompttemplate) - * [promptHeader](#promptheader) - * [promptFooter](#promptfooter) -* [PromptMessage](#promptmessage) + * [device](#device) + * [nCtx](#nctx) + * [ngl](#ngl) +* [loadModel](#loadmodel) + * [Parameters](#parameters-11) +* [InferenceProvider](#inferenceprovider) +* [createCompletion](#createcompletion) + * [Parameters](#parameters-12) +* [createCompletionStream](#createcompletionstream) + * [Parameters](#parameters-13) +* [createCompletionGenerator](#createcompletiongenerator) + * [Parameters](#parameters-14) +* [createEmbedding](#createembedding) + * [Parameters](#parameters-15) +* 
[CompletionOptions](#completionoptions) + * [verbose](#verbose-1) + * [onToken](#ontoken) +* [Message](#message) * [role](#role) * [content](#content) * [prompt\_tokens](#prompt_tokens) * [completion\_tokens](#completion_tokens) * [total\_tokens](#total_tokens) +* [n\_past\_tokens](#n_past_tokens) * [CompletionReturn](#completionreturn) * [model](#model) * [usage](#usage) - * [choices](#choices) -* [CompletionChoice](#completionchoice) - * [message](#message) + * [message](#message-1) +* [CompletionStreamReturn](#completionstreamreturn) * [LLModelPromptContext](#llmodelpromptcontext) * [logitsSize](#logitssize) * [tokensSize](#tokenssize) * [nPast](#npast) - * [nCtx](#nctx) * [nPredict](#npredict) + * [promptTemplate](#prompttemplate) + * [nCtx](#nctx-1) * [topK](#topk) * [topP](#topp) - * [temp](#temp) + * [minP](#minp) + * [temperature](#temperature) * [nBatch](#nbatch) * [repeatPenalty](#repeatpenalty) * [repeatLastN](#repeatlastn) * [contextErase](#contexterase) -* [generateTokens](#generatetokens) - * [Parameters](#parameters-10) * [DEFAULT\_DIRECTORY](#default_directory) * [DEFAULT\_LIBRARIES\_DIRECTORY](#default_libraries_directory) * [DEFAULT\_MODEL\_CONFIG](#default_model_config) * [DEFAULT\_PROMPT\_CONTEXT](#default_prompt_context) * [DEFAULT\_MODEL\_LIST\_URL](#default_model_list_url) * [downloadModel](#downloadmodel) - * [Parameters](#parameters-11) + * [Parameters](#parameters-16) * [Examples](#examples) * [DownloadModelOptions](#downloadmodeloptions) - * [modelPath](#modelpath) - * [verbose](#verbose-1) + * [modelPath](#modelpath-1) + * [verbose](#verbose-2) * [url](#url) * [md5sum](#md5sum) * [DownloadController](#downloadcontroller) * [cancel](#cancel) * [promise](#promise) -#### ModelFile - -Full list of models available -DEPRECATED!! These model names are outdated and this type will not be maintained, please use a string literal instead - -##### gptj - -List of GPT-J Models - -Type: (`"ggml-gpt4all-j-v1.3-groovy.bin"` | `"ggml-gpt4all-j-v1.2-jazzy.bin"` | `"ggml-gpt4all-j-v1.1-breezy.bin"` | `"ggml-gpt4all-j.bin"`) - -##### llama - -List Llama Models - -Type: (`"ggml-gpt4all-l13b-snoozy.bin"` | `"ggml-vicuna-7b-1.1-q4_2.bin"` | `"ggml-vicuna-13b-1.1-q4_2.bin"` | `"ggml-wizardLM-7B.q4_2.bin"` | `"ggml-stable-vicuna-13B.q4_2.bin"` | `"ggml-nous-gpt4-vicuna-13b.bin"` | `"ggml-v3-13b-hermes-q5_1.bin"`) - -##### mpt - -List of MPT Models - -Type: (`"ggml-mpt-7b-base.bin"` | `"ggml-mpt-7b-chat.bin"` | `"ggml-mpt-7b-instruct.bin"`) - -##### replit - -List of Replit Models - -Type: `"ggml-replit-code-v1-3b.bin"` - #### type Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user. -Type: ModelType +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) #### TokenCallback -Callback for controlling token generation +Callback for controlling token generation. Return false to stop token generation. Type: function (tokenId: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), token: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String), total: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)): [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) +#### ChatSessionOptions + +**Extends Partial\** + +Options for the chat session. + +##### systemPrompt + +System prompt to ingest on initialization. 
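+
+A minimal sketch of passing a custom system prompt when creating a chat session (the template shown is only an illustration and depends on the model; `model` is an already loaded inference model):
+
+```js
+const chat = await model.createChatSession({
+    // illustrative template only; use the format your model expects
+    systemPrompt: "### System:\nYou are an advanced mathematician.\n\n",
+});
+```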
+ +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +##### messages + +Messages to ingest on initialization. + +Type: [Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[Message](#message)> + +#### initialize + +Ingests system prompt and initial messages. +Sets this chat session as the active chat session of the model. + +##### Parameters + +* `options` **[ChatSessionOptions](#chatsessionoptions)** The options for the chat session. + +Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)\** + +#### generate + +Prompts the model in chat-session context. + +##### Parameters + +* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input. +* `options` **[CompletionOptions](#completionoptions)?** Prompt context and other options. +* `callback` **[TokenCallback](#tokencallback)?** Token generation callback. + + + +* Throws **[Error](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Error)** If the chat session is not the active chat session of the model. + +Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[CompletionReturn](#completionreturn)>** The model's response to the prompt. + #### InferenceModel InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers. +##### createChatSession + +Create a chat session with the model. + +###### Parameters + +* `options` **[ChatSessionOptions](#chatsessionoptions)?** The options for the chat session. + +Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)\** The chat session. + +##### generate + +Prompts the model with a given input and optional parameters. + +###### Parameters + +* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** +* `options` **[CompletionOptions](#completionoptions)?** Prompt context and other options. +* `callback` **[TokenCallback](#tokencallback)?** Token generation callback. +* `input` The prompt input. + +Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[CompletionReturn](#completionreturn)>** The model's response to the prompt. + ##### dispose delete and cleanup the native model @@ -307,6 +448,10 @@ delete and cleanup the native model Returns **void** +#### InferenceResult + +Shape of LLModel's inference result. + #### LLModel LLModel class representing a language model. @@ -326,9 +471,9 @@ Initialize a new LLModel. ##### type -either 'gpt', mpt', or 'llama' or undefined +undefined or user supplied -Returns **(ModelType | [undefined](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/undefined))** +Returns **([string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) | [undefined](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/undefined))** ##### name @@ -360,7 +505,7 @@ Set the number of threads used for model inference. Returns **void** -##### raw\_prompt +##### infer Prompt the model with a given input and optional parameters. This is the raw output from model. @@ -368,23 +513,20 @@ Use the prompt function exported for a value ###### Parameters -* `q` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input. 
-* `params` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context. +* `prompt` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input. +* `promptContext` **Partial<[LLModelPromptContext](#llmodelpromptcontext)>** Optional parameters for the prompt context. * `callback` **[TokenCallback](#tokencallback)?** optional callback to control token generation. -Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The result of the model prompt. +Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<[InferenceResult](#inferenceresult)>** The result of the model prompt. ##### embed Embed text with the model. Keep in mind that -not all models can embed text, (only bert can embed as of 07/16/2023 (mm/dd/yyyy)) Use the prompt function exported for a value ###### Parameters -* `text` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** -* `q` The prompt input. -* `params` Optional parameters for the prompt context. +* `text` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The prompt input. Returns **[Float32Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Float32Array)** The result of the model prompt. @@ -462,6 +604,62 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa Options that configure a model's behavior. +##### modelPath + +Where to look for model files. + +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +##### librariesPath + +Where to look for the backend libraries. + +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +##### modelConfigFile + +The path to the model configuration file, useful for offline usage or custom model configurations. + +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +##### allowDownload + +Whether to allow downloading the model if it is not present at the specified path. + +Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) + +##### verbose + +Enable verbose logging. + +Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) + +##### device + +The processing unit on which the model will run. It can be set to + +* "cpu": Model will run on the central processing unit. +* "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor. +* "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor. +* "gpu name": Model will run on the GPU that matches the name if it's available. + Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All + instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the + model. 
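+
+A minimal sketch (the model file name is only an illustration; "gpu" picks the best available GPU regardless of vendor):
+
+```js
+const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
+    device: "gpu", // or "cpu", "amd", "nvidia", "intel", or a specific GPU name
+});
+```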
+ +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +##### nCtx + +The Maximum window size of this model + +Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) + +##### ngl + +Number of gpu layers needed + +Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) + #### loadModel Loads a machine learning model with the specified name. The defacto way to create a model. @@ -474,18 +672,46 @@ By default this will download a model from the official GPT4ALL website, if a mo Returns **[Promise](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Promise)<([InferenceModel](#inferencemodel) | [EmbeddingModel](#embeddingmodel))>** A promise that resolves to an instance of the loaded LLModel. +#### InferenceProvider + +Interface for inference, implemented by InferenceModel and ChatSession. + #### createCompletion The nodejs equivalent to python binding's chat\_completion ##### Parameters -* `model` **[InferenceModel](#inferencemodel)** The language model object. -* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation. +* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session +* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message * `options` **[CompletionOptions](#completionoptions)** The options for creating the completion. Returns **[CompletionReturn](#completionreturn)** The completion result. +#### createCompletionStream + +Streaming variant of createCompletion, returns a stream of tokens and a promise that resolves to the completion result. + +##### Parameters + +* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session +* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message. +* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion. + +Returns **[CompletionStreamReturn](#completionstreamreturn)** An object of token stream and the completion result promise. + +#### createCompletionGenerator + +Creates an async generator of tokens + +##### Parameters + +* `provider` **[InferenceProvider](#inferenceprovider)** The inference model object or chat session +* `message` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The user input message. +* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion. + +Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens + #### createEmbedding The nodejs moral equivalent to python binding's Embed4All().embed() @@ -510,34 +736,15 @@ Indicates if verbose logging is enabled. Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) -##### systemPromptTemplate +##### onToken -Template for the system message. Will be put before the conversation with %1 being replaced by all system messages. -Note that if this is not defined, system messages will not be included in the prompt. +Callback for controlling token generation. Return false to stop processing. 
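+
+A minimal sketch of streaming tokens to stdout through the callback (the prompt is only an illustration; `model` is an already loaded inference model):
+
+```js
+const res = await createCompletion(model, "What is 1 + 1?", {
+    // called for every generated token; return false to stop generation early
+    onToken: (tokenId, token) => {
+        process.stdout.write(token);
+        return true;
+    },
+});
+```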
-Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) +Type: [TokenCallback](#tokencallback) -##### promptTemplate +#### Message -Template for user messages, with %1 being replaced by the message. - -Type: [boolean](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Boolean) - -##### promptHeader - -The initial instruction for the model, on top of the prompt - -Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) - -##### promptFooter - -The last instruction for the model, appended to the end of the prompt. - -Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) - -#### PromptMessage - -A message in the conversation, identical to OpenAI's chat message. +A message in the conversation. ##### role @@ -553,7 +760,7 @@ Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa #### prompt\_tokens -The number of tokens used in the prompt. +The number of tokens used in the prompt. Currently not available and always 0. Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) @@ -565,13 +772,19 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa #### total\_tokens -The total number of tokens used. +The total number of tokens used. Currently not available and always 0. + +Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) + +#### n\_past\_tokens + +Number of tokens used in the conversation. Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) #### CompletionReturn -The result of the completion, similar to OpenAI's format. +The result of a completion. ##### model @@ -583,23 +796,17 @@ Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa Token usage report. -Type: {prompt\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), completion\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), total\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)} - -##### choices - -The generated completions. - -Type: [Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[CompletionChoice](#completionchoice)> - -#### CompletionChoice - -A completion choice, similar to OpenAI's format. +Type: {prompt\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), completion\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), total\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number), n\_past\_tokens: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number)} ##### message -Response message +The generated completion. -Type: [PromptMessage](#promptmessage) +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +#### CompletionStreamReturn + +The result of a streamed completion, containing a stream of tokens and a promise that resolves to the completion result. #### LLModelPromptContext @@ -620,18 +827,29 @@ Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Globa ##### nPast The number of tokens in the past conversation. 
- -Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) - -##### nCtx - -The number of tokens possible in the context window. +This controls how far back the model looks when generating completions. Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) ##### nPredict -The number of tokens to predict. +The maximum number of tokens to predict. + +Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) + +##### promptTemplate + +Template for user / assistant message pairs. +%1 is required and will be replaced by the user input. +%2 is optional and will be replaced by the assistant response. + +Type: [string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String) + +##### nCtx + +The context window size. Do not use, it has no effect. See loadModel options. +THIS IS DEPRECATED!!! +Use loadModel's nCtx option instead. Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) @@ -654,12 +872,16 @@ above a threshold P. This method, also known as nucleus sampling, finds a balanc and quality by considering both token probabilities and the number of tokens available for sampling. When using a higher value for top-P (eg., 0.95), the generated text becomes more diverse. On the other hand, a lower value (eg., 0.1) produces more focused and conservative text. -The default value is 0.4, which is aimed to be the middle ground between focus and diversity, but -for more creative tasks a higher top-p value will be beneficial, about 0.5-0.9 is a good range for that. Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) -##### temp +##### minP + +The minimum probability of a token to be considered. + +Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) + +##### temperature The temperature to adjust the model's output distribution. Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures @@ -704,19 +926,6 @@ The percentage of context to erase if the context window is exceeded. Type: [number](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Number) -#### generateTokens - -Creates an async generator of tokens - -##### Parameters - -* `llmodel` **[InferenceModel](#inferencemodel)** The language model object. -* `messages` **[Array](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/Array)<[PromptMessage](#promptmessage)>** The array of messages for the conversation. -* `options` **[CompletionOptions](#completionoptions)** The options for creating the completion. -* `callback` **[TokenCallback](#tokencallback)** optional callback to control token generation. - -Returns **AsyncGenerator<[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)>** The stream of generated tokens - #### DEFAULT\_DIRECTORY From python api: @@ -759,7 +968,7 @@ By default this downloads without waiting. use the controller returned to alter ##### Parameters * `modelName` **[string](https://developer.mozilla.org/docs/Web/JavaScript/Reference/Global_Objects/String)** The model to be downloaded. -* `options` **DownloadOptions** to pass into the downloader. Default is { location: (cwd), verbose: false }. +* `options` **[DownloadModelOptions](#downloadmodeloptions)** to pass into the downloader. 
Default is { location: (cwd), verbose: false }. ##### Examples diff --git a/gpt4all-bindings/typescript/.clang-format b/gpt4all-bindings/typescript/.clang-format new file mode 100644 index 00000000..98ba18a1 --- /dev/null +++ b/gpt4all-bindings/typescript/.clang-format @@ -0,0 +1,4 @@ +--- +Language: Cpp +BasedOnStyle: Microsoft +ColumnLimit: 120 \ No newline at end of file diff --git a/gpt4all-bindings/typescript/README.md b/gpt4all-bindings/typescript/README.md index 5eba3f41..384e4afb 100644 --- a/gpt4all-bindings/typescript/README.md +++ b/gpt4all-bindings/typescript/README.md @@ -10,45 +10,170 @@ npm install gpt4all@latest pnpm install gpt4all@latest ``` - -The original [GPT4All typescript bindings](https://github.com/nomic-ai/gpt4all-ts) are now out of date. - -* New bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use. -* The nodejs api has made strides to mirror the python api. It is not 100% mirrored, but many pieces of the api resemble its python counterpart. -* Everything should work out the box. +## Breaking changes in version 4!! +* See [Transition](#changes) +## Contents * See [API Reference](#api-reference) - +* See [Examples](#api-example) +* See [Developing](#develop) +* GPT4ALL nodejs bindings created by [jacoobes](https://github.com/jacoobes), [limez](https://github.com/iimez) and the [nomic ai community](https://home.nomic.ai), for all to use. +* [spare change](https://github.com/sponsors/jacoobes) for a college student? 🤑 +## Api Examples ### Chat Completion +Use a chat session to keep context between completions. This is useful for efficient back and forth conversations. + ```js -import { createCompletion, loadModel } from '../src/gpt4all.js' +import { createCompletion, loadModel } from "../src/gpt4all.js"; -const model = await loadModel('mistral-7b-openorca.Q4_0.gguf', { verbose: true }); +const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { + verbose: true, // logs loaded model configuration + device: "gpu", // defaults to 'cpu' + nCtx: 2048, // the maximum sessions context window size. +}); -const response = await createCompletion(model, [ - { role : 'system', content: 'You are meant to be annoying and unhelpful.' }, - { role : 'user', content: 'What is 1 + 1?' } +// initialize a chat session on the model. a model instance can have only one chat session at a time. +const chat = await model.createChatSession({ + // any completion options set here will be used as default for all completions in this chat session + temperature: 0.8, + // a custom systemPrompt can be set here. note that the template depends on the model. + // if unset, the systemPrompt that comes with the model will be used. + systemPrompt: "### System:\nYou are an advanced mathematician.\n\n", +}); + +// create a completion using a string as input +const res1 = await createCompletion(chat, "What is 1 + 1?"); +console.debug(res1.choices[0].message); + +// multiple messages can be input to the conversation at once. +// note that if the last message is not of role 'user', an empty message will be returned. +await createCompletion(chat, [ + { + role: "user", + content: "What is 2 + 2?", + }, + { + role: "assistant", + content: "It's 5.", + }, ]); +const res3 = await createCompletion(chat, "Could you recalculate that?"); +console.debug(res3.choices[0].message); + +model.dispose(); +``` + +### Stateless usage +You can use the model without a chat session. 
This is useful for one-off completions. + +```js +import { createCompletion, loadModel } from "../src/gpt4all.js"; + +const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf"); + +// createCompletion methods can also be used on the model directly. +// context is not maintained between completions. +const res1 = await createCompletion(model, "What is 1 + 1?"); +console.debug(res1.choices[0].message); + +// a whole conversation can be input as well. +// note that if the last message is not of role 'user', an error will be thrown. +const res2 = await createCompletion(model, [ + { + role: "user", + content: "What is 2 + 2?", + }, + { + role: "assistant", + content: "It's 5.", + }, + { + role: "user", + content: "Could you recalculate that?", + }, +]); +console.debug(res2.choices[0].message); + ``` ### Embedding ```js -import { createEmbedding, loadModel } from '../src/gpt4all.js' +import { loadModel, createEmbedding } from '../src/gpt4all.js' -const model = await loadModel('ggml-all-MiniLM-L6-v2-f16', { verbose: true }); +const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding'}) -const fltArray = createEmbedding(model, "Pain is inevitable, suffering optional"); +console.log(createEmbedding(embedder, "Maybe Minecraft was the friends we made along the way")); ``` +### Streaming responses +```js +import { loadModel, createCompletionStream } from "../src/gpt4all.js"; + +const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", { + device: "gpu", +}); + +process.stdout.write("Output: "); +const stream = createCompletionStream(model, "How are you?"); +stream.tokens.on("data", (data) => { + process.stdout.write(data); +}); +//wait till stream finishes. We cannot continue until this one is done. +await stream.result; +process.stdout.write("\n"); +model.dispose(); + +``` + +### Async Generators +```js +import { loadModel, createCompletionGenerator } from "../src/gpt4all.js"; + +const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf"); + +process.stdout.write("Output: "); +const gen = createCompletionGenerator( + model, + "Redstone in Minecraft is Turing Complete. Let that sink in. (let it in!)" +); +for await (const chunk of gen) { + process.stdout.write(chunk); +} + +process.stdout.write("\n"); +model.dispose(); + +``` +### Offline usage +do this b4 going offline +```sh +curl -L https://gpt4all.io/models/models3.json -o ./models3.json +``` +```js +import { createCompletion, loadModel } from 'gpt4all' + +//make sure u downloaded the models before going offline! +const model = await loadModel('mistral-7b-openorca.gguf2.Q4_0.gguf', { + verbose: true, + device: 'gpu', + modelConfigFile: "./models3.json" +}); + +await createCompletion(model, 'What is 1 + 1?', { verbose: true }) + +model.dispose(); +``` + +## Develop ### Build Instructions -* binding.gyp is compile config +* `binding.gyp` is compile config * Tested on Ubuntu. Everything seems to work fine * Tested on Windows. Everything works fine. * Sparse testing on mac os. -* MingW works as well to build the gpt4all-backend. **HOWEVER**, this package works only with MSVC built dlls. +* MingW script works to build the gpt4all-backend. We left it there just in case. **HOWEVER**, this package works only with MSVC built dlls. ### Requirements @@ -76,23 +201,18 @@ cd gpt4all-bindings/typescript * To Build and Rebuild: ```sh -yarn +node scripts/prebuild.js ``` * llama.cpp git submodule for gpt4all can be possibly absent. 
If this is the case, make sure to run in llama.cpp parent directory ```sh -git submodule update --init --depth 1 --recursive +git submodule update --init --recursive ``` ```sh yarn build:backend ``` - -This will build platform-dependent dynamic libraries, and will be located in runtimes/(platform)/native The only current way to use them is to put them in the current working directory of your application. That is, **WHEREVER YOU RUN YOUR NODE APPLICATION** - -* llama-xxxx.dll is required. -* According to whatever model you are using, you'll need to select the proper model loader. - * For example, if you running an Mosaic MPT model, you will need to select the mpt-(buildvariant).(dynamiclibrary) +This will build platform-dependent dynamic libraries, and will be located in runtimes/(platform)/native ### Test @@ -130,17 +250,20 @@ yarn test * why your model may be spewing bull 💩 * The downloaded model is broken (just reinstall or download from official site) - * That's it so far +* Your model is hanging after a call to generate tokens. + * Is `nPast` set too high? This may cause your model to hang (03/16/2024), Linux Mint, Ubuntu 22.04 +* Your GPU usage is still high after node.js exits. + * Make sure to call `model.dispose()`!!! ### Roadmap -This package is in active development, and breaking changes may happen until the api stabilizes. Here's what's the todo list: +This package has been stabilizing over time development, and breaking changes may happen until the api stabilizes. Here's what's the todo list: * \[ ] Purely offline. Per the gui, which can be run completely offline, the bindings should be as well. * \[ ] NPM bundle size reduction via optionalDependencies strategy (need help) * Should include prebuilds to avoid painful node-gyp errors -* \[ ] createChatSession ( the python equivalent to create\_chat\_session ) -* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs! +* \[x] createChatSession ( the python equivalent to create\_chat\_session ) +* \[x] generateTokens, the new name for createTokenStream. As of 3.2.0, this is released but not 100% tested. Check spec/generator.mjs! * \[x] ~~createTokenStream, an async iterator that streams each token emitted from the model. Planning on following this [example](https://github.com/nodejs/node-addon-examples/tree/main/threadsafe-async-iterator)~~ May not implement unless someone else can complete * \[x] prompt models via a threadsafe function in order to have proper non blocking behavior in nodejs * \[x] generateTokens is the new name for this^ @@ -149,5 +272,13 @@ This package is in active development, and breaking changes may happen until the * \[x] have more people test on other platforms (mac tester needed) * \[x] switch to new pluggable backend +## Changes +This repository serves as the new bindings for nodejs users. +- If you were a user of [these bindings](https://github.com/nomic-ai/gpt4all-ts), they are outdated. +- Version 4 includes the follow breaking changes + * `createEmbedding` & `EmbeddingModel.embed()` returns an object, `EmbeddingResult`, instead of a float32array. 
+ * Removed deprecated types `ModelType` and `ModelFile` + * Removed deprecated initiation of model by string path only + ### API Reference diff --git a/gpt4all-bindings/typescript/binding.ci.gyp b/gpt4all-bindings/typescript/binding.ci.gyp index 6867e981..5d511155 100644 --- a/gpt4all-bindings/typescript/binding.ci.gyp +++ b/gpt4all-bindings/typescript/binding.ci.gyp @@ -6,12 +6,12 @@ "(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers) )); - + return Napi::Number::New( + env, static_cast(llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers))); } - Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info) - { +Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo &info) +{ auto env = info.Env(); int num_devices = 0; auto mem_size = llmodel_required_mem(GetInference(), full_model_path.c_str(), nCtx, nGpuLayers); - llmodel_gpu_device* all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices); - if(all_devices == nullptr) { - Napi::Error::New( - env, - "Unable to retrieve list of all GPU devices" - ).ThrowAsJavaScriptException(); + llmodel_gpu_device *all_devices = llmodel_available_gpu_devices(GetInference(), mem_size, &num_devices); + if (all_devices == nullptr) + { + Napi::Error::New(env, "Unable to retrieve list of all GPU devices").ThrowAsJavaScriptException(); return env.Undefined(); } auto js_array = Napi::Array::New(env, num_devices); - for(int i = 0; i < num_devices; ++i) { - auto gpu_device = all_devices[i]; - /* - * - * struct llmodel_gpu_device { - int index = 0; - int type = 0; // same as VkPhysicalDeviceType - size_t heapSize = 0; - const char * name; - const char * vendor; - }; - * - */ - Napi::Object js_gpu_device = Napi::Object::New(env); + for (int i = 0; i < num_devices; ++i) + { + auto gpu_device = all_devices[i]; + /* + * + * struct llmodel_gpu_device { + int index = 0; + int type = 0; // same as VkPhysicalDeviceType + size_t heapSize = 0; + const char * name; + const char * vendor; + }; + * + */ + Napi::Object js_gpu_device = Napi::Object::New(env); js_gpu_device["index"] = uint32_t(gpu_device.index); js_gpu_device["type"] = uint32_t(gpu_device.type); - js_gpu_device["heapSize"] = static_cast( gpu_device.heapSize ); - js_gpu_device["name"]= gpu_device.name; + js_gpu_device["heapSize"] = static_cast(gpu_device.heapSize); + js_gpu_device["name"] = gpu_device.name; js_gpu_device["vendor"] = gpu_device.vendor; js_array[i] = js_gpu_device; } return js_array; - } +} - Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo& info) - { - if(type.empty()) { +Napi::Value NodeModelWrapper::GetType(const Napi::CallbackInfo &info) +{ + if (type.empty()) + { return info.Env().Undefined(); - } + } return Napi::String::New(info.Env(), type); - } +} - Napi::Value NodeModelWrapper::InitGpuByString(const Napi::CallbackInfo& info) - { +Napi::Value NodeModelWrapper::InitGpuByString(const Napi::CallbackInfo &info) +{ auto env = info.Env(); size_t memory_required = static_cast(info[0].As().Uint32Value()); - - std::string gpu_device_identifier = info[1].As(); + + std::string gpu_device_identifier = info[1].As(); size_t converted_value; - if(memory_required <= std::numeric_limits::max()) { + if (memory_required <= std::numeric_limits::max()) + { converted_value = static_cast(memory_required); - } else { - Napi::Error::New( - env, - "invalid number for memory size. Exceeded bounds for memory." 
- ).ThrowAsJavaScriptException(); + } + else + { + Napi::Error::New(env, "invalid number for memory size. Exceeded bounds for memory.") + .ThrowAsJavaScriptException(); return env.Undefined(); } - + auto result = llmodel_gpu_init_gpu_device_by_string(GetInference(), converted_value, gpu_device_identifier.c_str()); return Napi::Boolean::New(env, result); - } - Napi::Value NodeModelWrapper::HasGpuDevice(const Napi::CallbackInfo& info) - { +} +Napi::Value NodeModelWrapper::HasGpuDevice(const Napi::CallbackInfo &info) +{ return Napi::Boolean::New(info.Env(), llmodel_has_gpu_device(GetInference())); - } +} - NodeModelWrapper::NodeModelWrapper(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) - { +NodeModelWrapper::NodeModelWrapper(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) +{ auto env = info.Env(); - fs::path model_path; + auto config_object = info[0].As(); - std::string full_weight_path, - library_path = ".", - model_name, - device; - if(info[0].IsString()) { - model_path = info[0].As().Utf8Value(); - full_weight_path = model_path.string(); - std::cout << "DEPRECATION: constructor accepts object now. Check docs for more.\n"; - } else { - auto config_object = info[0].As(); - model_name = config_object.Get("model_name").As(); - model_path = config_object.Get("model_path").As().Utf8Value(); - if(config_object.Has("model_type")) { - type = config_object.Get("model_type").As(); - } - full_weight_path = (model_path / fs::path(model_name)).string(); - - if(config_object.Has("library_path")) { - library_path = config_object.Get("library_path").As(); - } else { - library_path = "."; - } - device = config_object.Get("device").As(); + // sets the directory where models (gguf files) are to be searched + llmodel_set_implementation_search_path( + config_object.Has("library_path") ? config_object.Get("library_path").As().Utf8Value().c_str() + : "."); - nCtx = config_object.Get("nCtx").As().Int32Value(); - nGpuLayers = config_object.Get("ngl").As().Int32Value(); - } - llmodel_set_implementation_search_path(library_path.c_str()); - const char* e; + std::string model_name = config_object.Get("model_name").As(); + fs::path model_path = config_object.Get("model_path").As().Utf8Value(); + std::string full_weight_path = (model_path / fs::path(model_name)).string(); + + name = model_name.empty() ? model_path.filename().string() : model_name; + full_model_path = full_weight_path; + nCtx = config_object.Get("nCtx").As().Int32Value(); + nGpuLayers = config_object.Get("ngl").As().Int32Value(); + + const char *e; inference_ = llmodel_model_create2(full_weight_path.c_str(), "auto", &e); - if(!inference_) { - Napi::Error::New(env, e).ThrowAsJavaScriptException(); - return; + if (!inference_) + { + Napi::Error::New(env, e).ThrowAsJavaScriptException(); + return; } - if(GetInference() == nullptr) { - std::cerr << "Tried searching libraries in \"" << library_path << "\"" << std::endl; - std::cerr << "Tried searching for model weight in \"" << full_weight_path << "\"" << std::endl; - std::cerr << "Do you have runtime libraries installed?" << std::endl; - Napi::Error::New(env, "Had an issue creating llmodel object, inference is null").ThrowAsJavaScriptException(); - return; + if (GetInference() == nullptr) + { + std::cerr << "Tried searching libraries in \"" << llmodel_get_implementation_search_path() << "\"" << std::endl; + std::cerr << "Tried searching for model weight in \"" << full_weight_path << "\"" << std::endl; + std::cerr << "Do you have runtime libraries installed?" 
<< std::endl; + Napi::Error::New(env, "Had an issue creating llmodel object, inference is null").ThrowAsJavaScriptException(); + return; } - if(device != "cpu") { - size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(),nCtx, nGpuLayers); + + std::string device = config_object.Get("device").As(); + if (device != "cpu") + { + size_t mem = llmodel_required_mem(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers); auto success = llmodel_gpu_init_gpu_device_by_string(GetInference(), mem, device.c_str()); - if(!success) { - //https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215 - //Haven't implemented this but it is still open to contribution + if (!success) + { + // https://github.com/nomic-ai/gpt4all/blob/3acbef14b7c2436fe033cae9036e695d77461a16/gpt4all-bindings/python/gpt4all/pyllmodel.py#L215 + // Haven't implemented this but it is still open to contribution std::cout << "WARNING: Failed to init GPU\n"; } } auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), nCtx, nGpuLayers); - if(!success) { - Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException(); + if (!success) + { + Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException(); return; } - - name = model_name.empty() ? model_path.filename().string() : model_name; - full_model_path = full_weight_path; - }; + // optional + if (config_object.Has("model_type")) + { + type = config_object.Get("model_type").As(); + } +}; // NodeModelWrapper::~NodeModelWrapper() { // if(GetInference() != nullptr) { @@ -182,177 +178,275 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info) // if(inference_ != nullptr) { // std::cout << "Debug: deleting model\n"; // -// } +// } // } - Napi::Value NodeModelWrapper::IsModelLoaded(const Napi::CallbackInfo& info) { +Napi::Value NodeModelWrapper::IsModelLoaded(const Napi::CallbackInfo &info) +{ return Napi::Boolean::New(info.Env(), llmodel_isModelLoaded(GetInference())); - } +} - Napi::Value NodeModelWrapper::StateSize(const Napi::CallbackInfo& info) { +Napi::Value NodeModelWrapper::StateSize(const Napi::CallbackInfo &info) +{ // Implement the binding for the stateSize method return Napi::Number::New(info.Env(), static_cast(llmodel_get_state_size(GetInference()))); - } - - Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo& info) { +} + +Napi::Array ChunkedFloatPtr(float *embedding_ptr, int embedding_size, int text_len, Napi::Env const &env) +{ + auto n_embd = embedding_size / text_len; + // std::cout << "Embedding size: " << embedding_size << std::endl; + // std::cout << "Text length: " << text_len << std::endl; + // std::cout << "Chunk size (n_embd): " << n_embd << std::endl; + Napi::Array result = Napi::Array::New(env, text_len); + auto count = 0; + for (int i = 0; i < embedding_size; i += n_embd) + { + int end = std::min(i + n_embd, embedding_size); + // possible bounds error? + // Constructs a container with as many elements as the range [first,last), with each element emplace-constructed + // from its corresponding element in that range, in the same order. 
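+        // i.e. copy the n_embd floats belonging to one input text into a temporary chunk,
+        // which is then written element by element into a Napi::Float32Array below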
+ std::vector chunk(embedding_ptr + i, embedding_ptr + end); + Napi::Float32Array fltarr = Napi::Float32Array::New(env, chunk.size()); + // I know there's a way to emplace the raw float ptr into a Napi::Float32Array but idk how and + // im too scared to cause memory issues + // this is goodenough + for (int j = 0; j < chunk.size(); j++) + { + + fltarr.Set(j, chunk[j]); + } + result.Set(count++, fltarr); + } + return result; +} + +Napi::Value NodeModelWrapper::GenerateEmbedding(const Napi::CallbackInfo &info) +{ auto env = info.Env(); - std::string text = info[0].As().Utf8Value(); - size_t embedding_size = 0; - float* arr = llmodel_embedding(GetInference(), text.c_str(), &embedding_size); - if(arr == nullptr) { - Napi::Error::New( - env, - "Cannot embed. native embedder returned 'nullptr'" - ).ThrowAsJavaScriptException(); + + auto prefix = info[1]; + auto dimensionality = info[2].As().Int32Value(); + auto do_mean = info[3].As().Value(); + auto atlas = info[4].As().Value(); + size_t embedding_size; + size_t token_count = 0; + + // This procedure can maybe be optimized but its whatever, i have too many intermediary structures + std::vector text_arr; + bool is_single_text = false; + if (info[0].IsString()) + { + is_single_text = true; + text_arr.push_back(info[0].As().Utf8Value()); + } + else + { + auto jsarr = info[0].As(); + size_t len = jsarr.Length(); + text_arr.reserve(len); + for (size_t i = 0; i < len; ++i) + { + std::string str = jsarr.Get(i).As().Utf8Value(); + text_arr.push_back(str); + } + } + std::vector str_ptrs; + str_ptrs.reserve(text_arr.size() + 1); + for (size_t i = 0; i < text_arr.size(); ++i) + str_ptrs.push_back(text_arr[i].c_str()); + str_ptrs.push_back(nullptr); + const char *_err = nullptr; + float *embeds = llmodel_embed(GetInference(), str_ptrs.data(), &embedding_size, + prefix.IsUndefined() ? nullptr : prefix.As().Utf8Value().c_str(), + dimensionality, &token_count, do_mean, atlas, &_err); + if (!embeds) + { + // i dont wanna deal with c strings lol + std::string err(_err); + Napi::Error::New(env, err == "(unknown error)" ? "Unknown error: sorry bud" : err).ThrowAsJavaScriptException(); return env.Undefined(); } + auto embedmat = ChunkedFloatPtr(embeds, embedding_size, text_arr.size(), env); - if(embedding_size == 0 && text.size() != 0 ) { - std::cout << "Warning: embedding length 0 but input text length > 0" << std::endl; - } - Napi::Float32Array js_array = Napi::Float32Array::New(env, embedding_size); - - for (size_t i = 0; i < embedding_size; ++i) { - float element = *(arr + i); - js_array[i] = element; + llmodel_free_embedding(embeds); + auto res = Napi::Object::New(env); + res.Set("n_prompt_tokens", token_count); + if(is_single_text) { + res.Set("embeddings", embedmat.Get(static_cast(0))); + } else { + res.Set("embeddings", embedmat); } - llmodel_free_embedding(arr); - - return js_array; - } + return res; +} /** * Generate a response using the model. - * @param model A pointer to the llmodel_model instance. * @param prompt A string representing the input prompt. - * @param prompt_callback A callback function for handling the processing of prompt. - * @param response_callback A callback function for handling the generated response. - * @param recalculate_callback A callback function for handling recalculation requests. - * @param ctx A pointer to the llmodel_prompt_context structure. + * @param options Inference options. 
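+ * @return A promise that resolves with the inference result.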
*/ - Napi::Value NodeModelWrapper::Prompt(const Napi::CallbackInfo& info) { +Napi::Value NodeModelWrapper::Infer(const Napi::CallbackInfo &info) +{ auto env = info.Env(); - std::string question; - if(info[0].IsString()) { - question = info[0].As().Utf8Value(); - } else { + std::string prompt; + if (info[0].IsString()) + { + prompt = info[0].As().Utf8Value(); + } + else + { Napi::Error::New(info.Env(), "invalid string argument").ThrowAsJavaScriptException(); return info.Env().Undefined(); } - //defaults copied from python bindings - llmodel_prompt_context promptContext = { - .logits = nullptr, - .tokens = nullptr, - .n_past = 0, - .n_ctx = 1024, - .n_predict = 128, - .top_k = 40, - .top_p = 0.9f, - .min_p = 0.0f, - .temp = 0.72f, - .n_batch = 8, - .repeat_penalty = 1.0f, - .repeat_last_n = 10, - .context_erase = 0.5 - }; - - PromptWorkerConfig promptWorkerConfig; - if(info[1].IsObject()) - { - auto inputObject = info[1].As(); - - // Extract and assign the properties - if (inputObject.Has("logits") || inputObject.Has("tokens")) { - Napi::Error::New(info.Env(), "Invalid input: 'logits' or 'tokens' properties are not allowed").ThrowAsJavaScriptException(); - return info.Env().Undefined(); - } - // Assign the remaining properties - if(inputObject.Has("n_past")) - promptContext.n_past = inputObject.Get("n_past").As().Int32Value(); - if(inputObject.Has("n_ctx")) - promptContext.n_ctx = inputObject.Get("n_ctx").As().Int32Value(); - if(inputObject.Has("n_predict")) - promptContext.n_predict = inputObject.Get("n_predict").As().Int32Value(); - if(inputObject.Has("top_k")) - promptContext.top_k = inputObject.Get("top_k").As().Int32Value(); - if(inputObject.Has("top_p")) - promptContext.top_p = inputObject.Get("top_p").As().FloatValue(); - if(inputObject.Has("min_p")) - promptContext.min_p = inputObject.Get("min_p").As().FloatValue(); - if(inputObject.Has("temp")) - promptContext.temp = inputObject.Get("temp").As().FloatValue(); - if(inputObject.Has("n_batch")) - promptContext.n_batch = inputObject.Get("n_batch").As().Int32Value(); - if(inputObject.Has("repeat_penalty")) - promptContext.repeat_penalty = inputObject.Get("repeat_penalty").As().FloatValue(); - if(inputObject.Has("repeat_last_n")) - promptContext.repeat_last_n = inputObject.Get("repeat_last_n").As().Int32Value(); - if(inputObject.Has("context_erase")) - promptContext.context_erase = inputObject.Get("context_erase").As().FloatValue(); - } - else + if (!info[1].IsObject()) { Napi::Error::New(info.Env(), "Missing Prompt Options").ThrowAsJavaScriptException(); return info.Env().Undefined(); } + // defaults copied from python bindings + llmodel_prompt_context promptContext = {.logits = nullptr, + .tokens = nullptr, + .n_past = 0, + .n_ctx = nCtx, + .n_predict = 4096, + .top_k = 40, + .top_p = 0.9f, + .min_p = 0.0f, + .temp = 0.1f, + .n_batch = 8, + .repeat_penalty = 1.2f, + .repeat_last_n = 10, + .context_erase = 0.75}; - if(info.Length() >= 3 && info[2].IsFunction()){ - promptWorkerConfig.bHasTokenCallback = true; - promptWorkerConfig.tokenCallback = info[2].As(); + PromptWorkerConfig promptWorkerConfig; + + auto inputObject = info[1].As(); + + if (inputObject.Has("logits") || inputObject.Has("tokens")) + { + Napi::Error::New(info.Env(), "Invalid input: 'logits' or 'tokens' properties are not allowed") + .ThrowAsJavaScriptException(); + return info.Env().Undefined(); } - + // Assign the remaining properties + if (inputObject.Has("nPast") && inputObject.Get("nPast").IsNumber()) + { + promptContext.n_past = 
inputObject.Get("nPast").As().Int32Value(); + } + if (inputObject.Has("nPredict") && inputObject.Get("nPredict").IsNumber()) + { + promptContext.n_predict = inputObject.Get("nPredict").As().Int32Value(); + } + if (inputObject.Has("topK") && inputObject.Get("topK").IsNumber()) + { + promptContext.top_k = inputObject.Get("topK").As().Int32Value(); + } + if (inputObject.Has("topP") && inputObject.Get("topP").IsNumber()) + { + promptContext.top_p = inputObject.Get("topP").As().FloatValue(); + } + if (inputObject.Has("minP") && inputObject.Get("minP").IsNumber()) + { + promptContext.min_p = inputObject.Get("minP").As().FloatValue(); + } + if (inputObject.Has("temp") && inputObject.Get("temp").IsNumber()) + { + promptContext.temp = inputObject.Get("temp").As().FloatValue(); + } + if (inputObject.Has("nBatch") && inputObject.Get("nBatch").IsNumber()) + { + promptContext.n_batch = inputObject.Get("nBatch").As().Int32Value(); + } + if (inputObject.Has("repeatPenalty") && inputObject.Get("repeatPenalty").IsNumber()) + { + promptContext.repeat_penalty = inputObject.Get("repeatPenalty").As().FloatValue(); + } + if (inputObject.Has("repeatLastN") && inputObject.Get("repeatLastN").IsNumber()) + { + promptContext.repeat_last_n = inputObject.Get("repeatLastN").As().Int32Value(); + } + if (inputObject.Has("contextErase") && inputObject.Get("contextErase").IsNumber()) + { + promptContext.context_erase = inputObject.Get("contextErase").As().FloatValue(); + } + if (inputObject.Has("onPromptToken") && inputObject.Get("onPromptToken").IsFunction()) + { + promptWorkerConfig.promptCallback = inputObject.Get("onPromptToken").As(); + promptWorkerConfig.hasPromptCallback = true; + } + if (inputObject.Has("onResponseToken") && inputObject.Get("onResponseToken").IsFunction()) + { + promptWorkerConfig.responseCallback = inputObject.Get("onResponseToken").As(); + promptWorkerConfig.hasResponseCallback = true; + } - //copy to protect llmodel resources when splitting to new thread - // llmodel_prompt_context copiedPrompt = promptContext; + // copy to protect llmodel resources when splitting to new thread + // llmodel_prompt_context copiedPrompt = promptContext; promptWorkerConfig.context = promptContext; promptWorkerConfig.model = GetInference(); promptWorkerConfig.mutex = &inference_mutex; - promptWorkerConfig.prompt = question; + promptWorkerConfig.prompt = prompt; promptWorkerConfig.result = ""; - + promptWorkerConfig.promptTemplate = inputObject.Get("promptTemplate").As(); + if (inputObject.Has("special")) + { + promptWorkerConfig.special = inputObject.Get("special").As(); + } + if (inputObject.Has("fakeReply")) + { + // this will be deleted in the worker + promptWorkerConfig.fakeReply = new std::string(inputObject.Get("fakeReply").As().Utf8Value()); + } auto worker = new PromptWorker(env, promptWorkerConfig); worker->Queue(); return worker->GetPromise(); - } - void NodeModelWrapper::Dispose(const Napi::CallbackInfo& info) { +} +void NodeModelWrapper::Dispose(const Napi::CallbackInfo &info) +{ llmodel_model_destroy(inference_); - } - void NodeModelWrapper::SetThreadCount(const Napi::CallbackInfo& info) { - if(info[0].IsNumber()) { +} +void NodeModelWrapper::SetThreadCount(const Napi::CallbackInfo &info) +{ + if (info[0].IsNumber()) + { llmodel_setThreadCount(GetInference(), info[0].As().Int64Value()); - } else { - Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is NaN").ThrowAsJavaScriptException(); + } + else + { + Napi::Error::New(info.Env(), "Could not set thread count: argument 1 is 
NaN").ThrowAsJavaScriptException(); return; } - } - - Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo& info) { - return Napi::String::New(info.Env(), name); - } - Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo& info) { - return Napi::Number::New(info.Env(), llmodel_threadCount(GetInference())); - } - - Napi::Value NodeModelWrapper::GetLibraryPath(const Napi::CallbackInfo& info) { - return Napi::String::New(info.Env(), - llmodel_get_implementation_search_path()); - } - - llmodel_model NodeModelWrapper::GetInference() { - return inference_; - } - -//Exports Bindings -Napi::Object Init(Napi::Env env, Napi::Object exports) { - exports["LLModel"] = NodeModelWrapper::GetClass(env); - return exports; } +Napi::Value NodeModelWrapper::GetName(const Napi::CallbackInfo &info) +{ + return Napi::String::New(info.Env(), name); +} +Napi::Value NodeModelWrapper::ThreadCount(const Napi::CallbackInfo &info) +{ + return Napi::Number::New(info.Env(), llmodel_threadCount(GetInference())); +} +Napi::Value NodeModelWrapper::GetLibraryPath(const Napi::CallbackInfo &info) +{ + return Napi::String::New(info.Env(), llmodel_get_implementation_search_path()); +} + +llmodel_model NodeModelWrapper::GetInference() +{ + return inference_; +} + +// Exports Bindings +Napi::Object Init(Napi::Env env, Napi::Object exports) +{ + exports["LLModel"] = NodeModelWrapper::GetClass(env); + return exports; +} NODE_API_MODULE(NODE_GYP_MODULE_NAME, Init) diff --git a/gpt4all-bindings/typescript/index.h b/gpt4all-bindings/typescript/index.h index 6afdf217..db3ef11e 100644 --- a/gpt4all-bindings/typescript/index.h +++ b/gpt4all-bindings/typescript/index.h @@ -1,62 +1,63 @@ -#include #include "llmodel.h" -#include -#include "llmodel_c.h" +#include "llmodel_c.h" #include "prompt.h" #include -#include #include -#include +#include +#include #include +#include +#include namespace fs = std::filesystem; +class NodeModelWrapper : public Napi::ObjectWrap +{ -class NodeModelWrapper: public Napi::ObjectWrap { - -public: - NodeModelWrapper(const Napi::CallbackInfo &); - //virtual ~NodeModelWrapper(); - Napi::Value GetType(const Napi::CallbackInfo& info); - Napi::Value IsModelLoaded(const Napi::CallbackInfo& info); - Napi::Value StateSize(const Napi::CallbackInfo& info); - //void Finalize(Napi::Env env) override; - /** - * Prompting the model. This entails spawning a new thread and adding the response tokens - * into a thread local string variable. 
- */ - Napi::Value Prompt(const Napi::CallbackInfo& info); - void SetThreadCount(const Napi::CallbackInfo& info); - void Dispose(const Napi::CallbackInfo& info); - Napi::Value GetName(const Napi::CallbackInfo& info); - Napi::Value ThreadCount(const Napi::CallbackInfo& info); - Napi::Value GenerateEmbedding(const Napi::CallbackInfo& info); - Napi::Value HasGpuDevice(const Napi::CallbackInfo& info); - Napi::Value ListGpus(const Napi::CallbackInfo& info); - Napi::Value InitGpuByString(const Napi::CallbackInfo& info); - Napi::Value GetRequiredMemory(const Napi::CallbackInfo& info); - Napi::Value GetGpuDevices(const Napi::CallbackInfo& info); - /* - * The path that is used to search for the dynamic libraries - */ - Napi::Value GetLibraryPath(const Napi::CallbackInfo& info); - /** - * Creates the LLModel class - */ - static Napi::Function GetClass(Napi::Env); - llmodel_model GetInference(); -private: - /** - * The underlying inference that interfaces with the C interface - */ - llmodel_model inference_; + public: + NodeModelWrapper(const Napi::CallbackInfo &); + // virtual ~NodeModelWrapper(); + Napi::Value GetType(const Napi::CallbackInfo &info); + Napi::Value IsModelLoaded(const Napi::CallbackInfo &info); + Napi::Value StateSize(const Napi::CallbackInfo &info); + // void Finalize(Napi::Env env) override; + /** + * Prompting the model. This entails spawning a new thread and adding the response tokens + * into a thread local string variable. + */ + Napi::Value Infer(const Napi::CallbackInfo &info); + void SetThreadCount(const Napi::CallbackInfo &info); + void Dispose(const Napi::CallbackInfo &info); + Napi::Value GetName(const Napi::CallbackInfo &info); + Napi::Value ThreadCount(const Napi::CallbackInfo &info); + Napi::Value GenerateEmbedding(const Napi::CallbackInfo &info); + Napi::Value HasGpuDevice(const Napi::CallbackInfo &info); + Napi::Value ListGpus(const Napi::CallbackInfo &info); + Napi::Value InitGpuByString(const Napi::CallbackInfo &info); + Napi::Value GetRequiredMemory(const Napi::CallbackInfo &info); + Napi::Value GetGpuDevices(const Napi::CallbackInfo &info); + /* + * The path that is used to search for the dynamic libraries + */ + Napi::Value GetLibraryPath(const Napi::CallbackInfo &info); + /** + * Creates the LLModel class + */ + static Napi::Function GetClass(Napi::Env); + llmodel_model GetInference(); - std::mutex inference_mutex; + private: + /** + * The underlying inference that interfaces with the C interface + */ + llmodel_model inference_; - std::string type; - // corresponds to LLModel::name() in typescript - std::string name; - int nCtx{}; - int nGpuLayers{}; - std::string full_model_path; + std::mutex inference_mutex; + + std::string type; + // corresponds to LLModel::name() in typescript + std::string name; + int nCtx{}; + int nGpuLayers{}; + std::string full_model_path; }; diff --git a/gpt4all-bindings/typescript/package.json b/gpt4all-bindings/typescript/package.json index dca1d6ca..7f3e368e 100644 --- a/gpt4all-bindings/typescript/package.json +++ b/gpt4all-bindings/typescript/package.json @@ -1,6 +1,6 @@ { "name": "gpt4all", - "version": "3.2.0", + "version": "4.0.0", "packageManager": "yarn@3.6.1", "main": "src/gpt4all.js", "repository": "nomic-ai/gpt4all", @@ -22,7 +22,6 @@ ], "dependencies": { "md5-file": "^5.0.0", - "mkdirp": "^3.0.1", "node-addon-api": "^6.1.0", "node-gyp-build": "^4.6.0" }, diff --git a/gpt4all-bindings/typescript/prompt.cc b/gpt4all-bindings/typescript/prompt.cc index bc9b21c4..d24a7e90 100644 --- a/gpt4all-bindings/typescript/prompt.cc 
+++ b/gpt4all-bindings/typescript/prompt.cc @@ -2,145 +2,195 @@ #include PromptWorker::PromptWorker(Napi::Env env, PromptWorkerConfig config) - : promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) { - if(_config.bHasTokenCallback){ - _tsfn = Napi::ThreadSafeFunction::New(config.tokenCallback.Env(),config.tokenCallback,"PromptWorker",0,1,this); - } - } - - PromptWorker::~PromptWorker() + : promise(Napi::Promise::Deferred::New(env)), _config(config), AsyncWorker(env) +{ + if (_config.hasResponseCallback) { - if(_config.bHasTokenCallback){ - _tsfn.Release(); - } + _responseCallbackFn = Napi::ThreadSafeFunction::New(config.responseCallback.Env(), config.responseCallback, + "PromptWorker", 0, 1, this); } - void PromptWorker::Execute() + if (_config.hasPromptCallback) { - _config.mutex->lock(); + _promptCallbackFn = Napi::ThreadSafeFunction::New(config.promptCallback.Env(), config.promptCallback, + "PromptWorker", 0, 1, this); + } +} - LLModelWrapper *wrapper = reinterpret_cast(_config.model); +PromptWorker::~PromptWorker() +{ + if (_config.hasResponseCallback) + { + _responseCallbackFn.Release(); + } + if (_config.hasPromptCallback) + { + _promptCallbackFn.Release(); + } +} - auto ctx = &_config.context; +void PromptWorker::Execute() +{ + _config.mutex->lock(); - if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size()) - wrapper->promptContext.tokens.resize(ctx->n_past); + LLModelWrapper *wrapper = reinterpret_cast(_config.model); - // Copy the C prompt context - wrapper->promptContext.n_past = ctx->n_past; - wrapper->promptContext.n_ctx = ctx->n_ctx; - wrapper->promptContext.n_predict = ctx->n_predict; - wrapper->promptContext.top_k = ctx->top_k; - wrapper->promptContext.top_p = ctx->top_p; - wrapper->promptContext.temp = ctx->temp; - wrapper->promptContext.n_batch = ctx->n_batch; - wrapper->promptContext.repeat_penalty = ctx->repeat_penalty; - wrapper->promptContext.repeat_last_n = ctx->repeat_last_n; - wrapper->promptContext.contextErase = ctx->context_erase; + auto ctx = &_config.context; - // Napi::Error::Fatal( - // "SUPRA", - // "About to prompt"); - // Call the C++ prompt method - wrapper->llModel->prompt( - _config.prompt, - [](int32_t tid) { return true; }, - [this](int32_t token_id, const std::string tok) - { - return ResponseCallback(token_id, tok); - }, - [](bool isRecalculating) - { - return isRecalculating; - }, - wrapper->promptContext); + if (size_t(ctx->n_past) < wrapper->promptContext.tokens.size()) + wrapper->promptContext.tokens.resize(ctx->n_past); - // Update the C context by giving access to the wrappers raw pointers to std::vector data - // which involves no copies - ctx->logits = wrapper->promptContext.logits.data(); - ctx->logits_size = wrapper->promptContext.logits.size(); - ctx->tokens = wrapper->promptContext.tokens.data(); - ctx->tokens_size = wrapper->promptContext.tokens.size(); + // Copy the C prompt context + wrapper->promptContext.n_past = ctx->n_past; + wrapper->promptContext.n_ctx = ctx->n_ctx; + wrapper->promptContext.n_predict = ctx->n_predict; + wrapper->promptContext.top_k = ctx->top_k; + wrapper->promptContext.top_p = ctx->top_p; + wrapper->promptContext.temp = ctx->temp; + wrapper->promptContext.n_batch = ctx->n_batch; + wrapper->promptContext.repeat_penalty = ctx->repeat_penalty; + wrapper->promptContext.repeat_last_n = ctx->repeat_last_n; + wrapper->promptContext.contextErase = ctx->context_erase; - // Update the rest of the C prompt context - ctx->n_past = wrapper->promptContext.n_past; - ctx->n_ctx = 
wrapper->promptContext.n_ctx; - ctx->n_predict = wrapper->promptContext.n_predict; - ctx->top_k = wrapper->promptContext.top_k; - ctx->top_p = wrapper->promptContext.top_p; - ctx->temp = wrapper->promptContext.temp; - ctx->n_batch = wrapper->promptContext.n_batch; - ctx->repeat_penalty = wrapper->promptContext.repeat_penalty; - ctx->repeat_last_n = wrapper->promptContext.repeat_last_n; - ctx->context_erase = wrapper->promptContext.contextErase; + // Call the C++ prompt method - _config.mutex->unlock(); + wrapper->llModel->prompt( + _config.prompt, _config.promptTemplate, [this](int32_t token_id) { return PromptCallback(token_id); }, + [this](int32_t token_id, const std::string token) { return ResponseCallback(token_id, token); }, + [](bool isRecalculating) { return isRecalculating; }, wrapper->promptContext, _config.special, + _config.fakeReply); + + // Update the C context by giving access to the wrappers raw pointers to std::vector data + // which involves no copies + ctx->logits = wrapper->promptContext.logits.data(); + ctx->logits_size = wrapper->promptContext.logits.size(); + ctx->tokens = wrapper->promptContext.tokens.data(); + ctx->tokens_size = wrapper->promptContext.tokens.size(); + + // Update the rest of the C prompt context + ctx->n_past = wrapper->promptContext.n_past; + ctx->n_ctx = wrapper->promptContext.n_ctx; + ctx->n_predict = wrapper->promptContext.n_predict; + ctx->top_k = wrapper->promptContext.top_k; + ctx->top_p = wrapper->promptContext.top_p; + ctx->temp = wrapper->promptContext.temp; + ctx->n_batch = wrapper->promptContext.n_batch; + ctx->repeat_penalty = wrapper->promptContext.repeat_penalty; + ctx->repeat_last_n = wrapper->promptContext.repeat_last_n; + ctx->context_erase = wrapper->promptContext.contextErase; + + _config.mutex->unlock(); +} + +void PromptWorker::OnOK() +{ + Napi::Object returnValue = Napi::Object::New(Env()); + returnValue.Set("text", result); + returnValue.Set("nPast", _config.context.n_past); + promise.Resolve(returnValue); + delete _config.fakeReply; +} + +void PromptWorker::OnError(const Napi::Error &e) +{ + delete _config.fakeReply; + promise.Reject(e.Value()); +} + +Napi::Promise PromptWorker::GetPromise() +{ + return promise.Promise(); +} + +bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token) +{ + if (token_id == -1) + { + return false; } - void PromptWorker::OnOK() - { - promise.Resolve(Napi::String::New(Env(), result)); - } - - void PromptWorker::OnError(const Napi::Error &e) - { - promise.Reject(e.Value()); - } - - Napi::Promise PromptWorker::GetPromise() - { - return promise.Promise(); - } - - bool PromptWorker::ResponseCallback(int32_t token_id, const std::string token) - { - if (token_id == -1) - { - return false; - } - - if(!_config.bHasTokenCallback){ - return true; - } - - result += token; - - std::promise promise; - - auto info = new TokenCallbackInfo(); - info->tokenId = token_id; - info->token = token; - info->total = result; - - auto future = promise.get_future(); - - auto status = _tsfn.BlockingCall(info, [&promise](Napi::Env env, Napi::Function jsCallback, TokenCallbackInfo *value) - { - // Transform native data into JS data, passing it to the provided - // `jsCallback` -- the TSFN's JavaScript function. 
- auto token_id = Napi::Number::New(env, value->tokenId); - auto token = Napi::String::New(env, value->token); - auto total = Napi::String::New(env,value->total); - auto jsResult = jsCallback.Call({ token_id, token, total}).ToBoolean(); - promise.set_value(jsResult); - // We're finished with the data. - delete value; - }); - if (status != napi_ok) { - Napi::Error::Fatal( - "PromptWorkerResponseCallback", - "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed"); - } - - return future.get(); - } - - bool PromptWorker::RecalculateCallback(bool isRecalculating) - { - return isRecalculating; - } - - bool PromptWorker::PromptCallback(int32_t tid) + if (!_config.hasResponseCallback) { return true; } + + result += token; + + std::promise promise; + + auto info = new ResponseCallbackData(); + info->tokenId = token_id; + info->token = token; + + auto future = promise.get_future(); + + auto status = _responseCallbackFn.BlockingCall( + info, [&promise](Napi::Env env, Napi::Function jsCallback, ResponseCallbackData *value) { + try + { + // Transform native data into JS data, passing it to the provided + // `jsCallback` -- the TSFN's JavaScript function. + auto token_id = Napi::Number::New(env, value->tokenId); + auto token = Napi::String::New(env, value->token); + auto jsResult = jsCallback.Call({token_id, token}).ToBoolean(); + promise.set_value(jsResult); + } + catch (const Napi::Error &e) + { + std::cerr << "Error in onResponseToken callback: " << e.what() << std::endl; + promise.set_value(false); + } + + delete value; + }); + if (status != napi_ok) + { + Napi::Error::Fatal("PromptWorkerResponseCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed"); + } + + return future.get(); +} + +bool PromptWorker::RecalculateCallback(bool isRecalculating) +{ + return isRecalculating; +} + +bool PromptWorker::PromptCallback(int32_t token_id) +{ + if (!_config.hasPromptCallback) + { + return true; + } + + std::promise promise; + + auto info = new PromptCallbackData(); + info->tokenId = token_id; + + auto future = promise.get_future(); + + auto status = _promptCallbackFn.BlockingCall( + info, [&promise](Napi::Env env, Napi::Function jsCallback, PromptCallbackData *value) { + try + { + // Transform native data into JS data, passing it to the provided + // `jsCallback` -- the TSFN's JavaScript function. 
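
This response callback blocks on a `std::promise` until the JS `onResponseToken` callback returns; a falsy return value (or a throw, which the `try`/`catch` above resolves as `false`) stops generation. From the JS side that looks roughly like the following, mirroring `spec/callbacks.mjs` and assuming `model` was loaded via `loadModel` as in the specs:

```js
// A falsy return from onResponseToken cancels generation (as does throwing).
const res = await createCompletion(model, "Write one short sentence about rivers.", {
    onResponseToken: (tokenId, token) => {
        process.stdout.write(token);
        return !token.includes("."); // e.g. stop after the first sentence
    },
});
```
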
+ auto token_id = Napi::Number::New(env, value->tokenId); + auto jsResult = jsCallback.Call({token_id}).ToBoolean(); + promise.set_value(jsResult); + } + catch (const Napi::Error &e) + { + std::cerr << "Error in onPromptToken callback: " << e.what() << std::endl; + promise.set_value(false); + } + delete value; + }); + if (status != napi_ok) + { + Napi::Error::Fatal("PromptWorkerPromptCallback", "Napi::ThreadSafeNapi::Function.NonBlockingCall() failed"); + } + + return future.get(); +} diff --git a/gpt4all-bindings/typescript/prompt.h b/gpt4all-bindings/typescript/prompt.h index 4f8cd531..49c43620 100644 --- a/gpt4all-bindings/typescript/prompt.h +++ b/gpt4all-bindings/typescript/prompt.h @@ -1,59 +1,72 @@ #ifndef PREDICT_WORKER_H #define PREDICT_WORKER_H -#include "napi.h" -#include "llmodel_c.h" #include "llmodel.h" -#include -#include -#include +#include "llmodel_c.h" +#include "napi.h" #include +#include #include +#include +#include -struct TokenCallbackInfo +struct ResponseCallbackData +{ + int32_t tokenId; + std::string token; +}; + +struct PromptCallbackData +{ + int32_t tokenId; +}; + +struct LLModelWrapper +{ + LLModel *llModel = nullptr; + LLModel::PromptContext promptContext; + ~LLModelWrapper() { - int32_t tokenId; - std::string total; - std::string token; - }; + delete llModel; + } +}; - struct LLModelWrapper - { - LLModel *llModel = nullptr; - LLModel::PromptContext promptContext; - ~LLModelWrapper() { delete llModel; } - }; +struct PromptWorkerConfig +{ + Napi::Function responseCallback; + bool hasResponseCallback = false; + Napi::Function promptCallback; + bool hasPromptCallback = false; + llmodel_model model; + std::mutex *mutex; + std::string prompt; + std::string promptTemplate; + llmodel_prompt_context context; + std::string result; + bool special = false; + std::string *fakeReply = nullptr; +}; - struct PromptWorkerConfig - { - Napi::Function tokenCallback; - bool bHasTokenCallback = false; - llmodel_model model; - std::mutex * mutex; - std::string prompt; - llmodel_prompt_context context; - std::string result; - }; +class PromptWorker : public Napi::AsyncWorker +{ + public: + PromptWorker(Napi::Env env, PromptWorkerConfig config); + ~PromptWorker(); + void Execute() override; + void OnOK() override; + void OnError(const Napi::Error &e) override; + Napi::Promise GetPromise(); - class PromptWorker : public Napi::AsyncWorker - { - public: - PromptWorker(Napi::Env env, PromptWorkerConfig config); - ~PromptWorker(); - void Execute() override; - void OnOK() override; - void OnError(const Napi::Error &e) override; - Napi::Promise GetPromise(); + bool ResponseCallback(int32_t token_id, const std::string token); + bool RecalculateCallback(bool isrecalculating); + bool PromptCallback(int32_t token_id); - bool ResponseCallback(int32_t token_id, const std::string token); - bool RecalculateCallback(bool isrecalculating); - bool PromptCallback(int32_t tid); + private: + Napi::Promise::Deferred promise; + std::string result; + PromptWorkerConfig _config; + Napi::ThreadSafeFunction _responseCallbackFn; + Napi::ThreadSafeFunction _promptCallbackFn; +}; - private: - Napi::Promise::Deferred promise; - std::string result; - PromptWorkerConfig _config; - Napi::ThreadSafeFunction _tsfn; - }; - -#endif // PREDICT_WORKER_H +#endif // PREDICT_WORKER_H diff --git a/gpt4all-bindings/typescript/scripts/build_unix.sh b/gpt4all-bindings/typescript/scripts/build_unix.sh index 185cc255..d60343c9 100755 --- a/gpt4all-bindings/typescript/scripts/build_unix.sh +++ 
b/gpt4all-bindings/typescript/scripts/build_unix.sh @@ -24,7 +24,6 @@ mkdir -p "$NATIVE_DIR" "$BUILD_DIR" cmake -S ../../gpt4all-backend -B "$BUILD_DIR" && cmake --build "$BUILD_DIR" -j --config Release && { - cp "$BUILD_DIR"/libbert*.$LIB_EXT "$NATIVE_DIR"/ cp "$BUILD_DIR"/libgptj*.$LIB_EXT "$NATIVE_DIR"/ cp "$BUILD_DIR"/libllama*.$LIB_EXT "$NATIVE_DIR"/ } diff --git a/gpt4all-bindings/typescript/spec/callbacks.mjs b/gpt4all-bindings/typescript/spec/callbacks.mjs new file mode 100644 index 00000000..461f32be --- /dev/null +++ b/gpt4all-bindings/typescript/spec/callbacks.mjs @@ -0,0 +1,31 @@ +import { promises as fs } from "node:fs"; +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", { + verbose: true, + device: "gpu", +}); + +const res = await createCompletion( + model, + "I've got three 🍣 - What shall I name them?", + { + onPromptToken: (tokenId) => { + console.debug("onPromptToken", { tokenId }); + // throwing an error will cancel + throw new Error("This is an error"); + // const foo = thisMethodDoesNotExist(); + // returning false will cancel as well + // return false; + }, + onResponseToken: (tokenId, token) => { + console.debug("onResponseToken", { tokenId, token }); + // same applies here + }, + } +); + +console.debug("Output:", { + usage: res.usage, + message: res.choices[0].message, +}); diff --git a/gpt4all-bindings/typescript/spec/chat-memory.mjs b/gpt4all-bindings/typescript/spec/chat-memory.mjs new file mode 100644 index 00000000..9a771633 --- /dev/null +++ b/gpt4all-bindings/typescript/spec/chat-memory.mjs @@ -0,0 +1,65 @@ +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", { + verbose: true, + device: "gpu", +}); + +const chat = await model.createChatSession({ + messages: [ + { + role: "user", + content: "I'll tell you a secret password: It's 63445.", + }, + { + role: "assistant", + content: "I will do my best to remember that.", + }, + { + role: "user", + content: + "And here another fun fact: Bananas may be bluer than bread at night.", + }, + { + role: "assistant", + content: "Yes, that makes sense.", + }, + ], +}); + +const turn1 = await createCompletion( + chat, + "Please tell me the secret password." +); +console.debug(turn1.choices[0].message); +// "The secret password you shared earlier is 63445."" + +const turn2 = await createCompletion( + chat, + "Thanks! Have your heard about the bananas?" +); +console.debug(turn2.choices[0].message); + +for (let i = 0; i < 32; i++) { + // gpu go brr + const turn = await createCompletion( + chat, + i % 2 === 0 ? "Tell me a fun fact." : "And a boring one?" + ); + console.debug({ + message: turn.choices[0].message, + n_past_tokens: turn.usage.n_past_tokens, + }); +} + +const finalTurn = await createCompletion( + chat, + "Now I forgot the secret password. Can you remind me?" +); +console.debug(finalTurn.choices[0].message); + +// result of finalTurn may vary depending on whether the generated facts pushed the secret out of the context window. +// "Of course! The secret password you shared earlier is 63445." +// "I apologize for any confusion. As an AI language model, ..." 
+ +model.dispose(); diff --git a/gpt4all-bindings/typescript/spec/chat-minimal.mjs b/gpt4all-bindings/typescript/spec/chat-minimal.mjs new file mode 100644 index 00000000..6d822f23 --- /dev/null +++ b/gpt4all-bindings/typescript/spec/chat-minimal.mjs @@ -0,0 +1,19 @@ +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { + verbose: true, + device: "gpu", +}); + +const chat = await model.createChatSession(); + +await createCompletion( + chat, + "Why are bananas rather blue than bread at night sometimes?", + { + verbose: true, + } +); +await createCompletion(chat, "Are you sure?", { + verbose: true, +}); diff --git a/gpt4all-bindings/typescript/spec/chat.mjs b/gpt4all-bindings/typescript/spec/chat.mjs deleted file mode 100644 index 62ea2e95..00000000 --- a/gpt4all-bindings/typescript/spec/chat.mjs +++ /dev/null @@ -1,70 +0,0 @@ -import { LLModel, createCompletion, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, loadModel } from '../src/gpt4all.js' - -const model = await loadModel( - 'mistral-7b-openorca.Q4_0.gguf', - { verbose: true, device: 'gpu' } -); -const ll = model.llm; - -try { - class Extended extends LLModel { - } - -} catch(e) { - console.log("Extending from native class gone wrong " + e) -} - -console.log("state size " + ll.stateSize()) - -console.log("thread count " + ll.threadCount()); -ll.setThreadCount(5); - -console.log("thread count " + ll.threadCount()); -ll.setThreadCount(4); -console.log("thread count " + ll.threadCount()); -console.log("name " + ll.name()); -console.log("type: " + ll.type()); -console.log("Default directory for models", DEFAULT_DIRECTORY); -console.log("Default directory for libraries", DEFAULT_LIBRARIES_DIRECTORY); -console.log("Has GPU", ll.hasGpuDevice()); -console.log("gpu devices", ll.listGpu()) -console.log("Required Mem in bytes", ll.memoryNeeded()) -const completion1 = await createCompletion(model, [ - { role : 'system', content: 'You are an advanced mathematician.' }, - { role : 'user', content: 'What is 1 + 1?' }, -], { verbose: true }) -console.log(completion1.choices[0].message) - -const completion2 = await createCompletion(model, [ - { role : 'system', content: 'You are an advanced mathematician.' }, - { role : 'user', content: 'What is two plus two?' }, -], { verbose: true }) - -console.log(completion2.choices[0].message) - -//CALLING DISPOSE WILL INVALID THE NATIVE MODEL. USE THIS TO CLEANUP -model.dispose() -// At the moment, from testing this code, concurrent model prompting is not possible. -// Behavior: The last prompt gets answered, but the rest are cancelled -// my experience with threading is not the best, so if anyone who is good is willing to give this a shot, -// maybe this is possible -// INFO: threading with llama.cpp is not the best maybe not even possible, so this will be left here as reference - -//const responses = await Promise.all([ -// createCompletion(model, [ -// { role : 'system', content: 'You are an advanced mathematician.' }, -// { role : 'user', content: 'What is 1 + 1?' }, -// ], { verbose: true }), -// createCompletion(model, [ -// { role : 'system', content: 'You are an advanced mathematician.' }, -// { role : 'user', content: 'What is 1 + 1?' }, -// ], { verbose: true }), -// -//createCompletion(model, [ -// { role : 'system', content: 'You are an advanced mathematician.' }, -// { role : 'user', content: 'What is 1 + 1?' 
}, -//], { verbose: true }) -// -//]) -//console.log(responses.map(s => s.choices[0].message)) - diff --git a/gpt4all-bindings/typescript/spec/concurrency.mjs b/gpt4all-bindings/typescript/spec/concurrency.mjs new file mode 100644 index 00000000..55ba9046 --- /dev/null +++ b/gpt4all-bindings/typescript/spec/concurrency.mjs @@ -0,0 +1,29 @@ +import { + loadModel, + createCompletion, +} from "../src/gpt4all.js"; + +const modelOptions = { + verbose: true, +}; + +const model1 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { + ...modelOptions, + device: "gpu", // only one model can be on gpu +}); +const model2 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", modelOptions); +const model3 = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", modelOptions); + +const promptContext = { + verbose: true, +} + +const responses = await Promise.all([ + createCompletion(model1, "What is 1 + 1?", promptContext), + // generating with the same model instance will wait for the previous completion to finish + createCompletion(model1, "What is 1 + 1?", promptContext), + // generating with different model instances will run in parallel + createCompletion(model2, "What is 1 + 2?", promptContext), + createCompletion(model3, "What is 1 + 3?", promptContext), +]); +console.log(responses.map((res) => res.choices[0].message)); diff --git a/gpt4all-bindings/typescript/spec/embed-jsonl.mjs b/gpt4all-bindings/typescript/spec/embed-jsonl.mjs new file mode 100644 index 00000000..2eb4bcab --- /dev/null +++ b/gpt4all-bindings/typescript/spec/embed-jsonl.mjs @@ -0,0 +1,26 @@ +import { loadModel, createEmbedding } from '../src/gpt4all.js' +import { createGunzip, createGzip, createUnzip } from 'node:zlib'; +import { Readable } from 'stream' +import readline from 'readline' +const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding', device: 'gpu' }) +console.log("Running with", embedder.llm.threadCount(), "threads"); + + +const unzip = createGunzip(); +const url = "https://huggingface.co/datasets/sentence-transformers/embedding-training-data/resolve/main/squad_pairs.jsonl.gz" +const stream = await fetch(url) + .then(res => Readable.fromWeb(res.body)); + +const lineReader = readline.createInterface({ + input: stream.pipe(unzip), + crlfDelay: Infinity +}) + +lineReader.on('line', line => { + //pairs of questions and answers + const question_answer = JSON.parse(line) + console.log(createEmbedding(embedder, question_answer)) +}) + +lineReader.on('close', () => embedder.dispose()) + diff --git a/gpt4all-bindings/typescript/spec/embed.mjs b/gpt4all-bindings/typescript/spec/embed.mjs index 929e321a..d3dc4e1b 100644 --- a/gpt4all-bindings/typescript/spec/embed.mjs +++ b/gpt4all-bindings/typescript/spec/embed.mjs @@ -1,6 +1,12 @@ import { loadModel, createEmbedding } from '../src/gpt4all.js' -const embedder = await loadModel("ggml-all-MiniLM-L6-v2-f16.bin", { verbose: true, type: 'embedding'}) +const embedder = await loadModel("nomic-embed-text-v1.5.f16.gguf", { verbose: true, type: 'embedding' , device: 'gpu' }) -console.log(createEmbedding(embedder, "Accept your current situation")) +try { +console.log(createEmbedding(embedder, ["Accept your current situation", "12312"], { prefix: "search_document" })) +} catch(e) { +console.log(e) +} + +embedder.dispose() diff --git a/gpt4all-bindings/typescript/spec/generator.mjs b/gpt4all-bindings/typescript/spec/generator.mjs deleted file mode 100644 index 963bdc28..00000000 --- a/gpt4all-bindings/typescript/spec/generator.mjs +++ /dev/null @@ -1,41 +0,0 
@@ -import gpt from '../src/gpt4all.js' - -const model = await gpt.loadModel("mistral-7b-openorca.Q4_0.gguf", { device: 'gpu' }) - -process.stdout.write('Response: ') - - -const tokens = gpt.generateTokens(model, [{ - role: 'user', - content: "How are you ?" -}], { nPredict: 2048 }) -for await (const token of tokens){ - process.stdout.write(token); -} - - -const result = await gpt.createCompletion(model, [{ - role: 'user', - content: "You sure?" -}]) - -console.log(result) - -const result2 = await gpt.createCompletion(model, [{ - role: 'user', - content: "You sure you sure?" -}]) - -console.log(result2) - - -const tokens2 = gpt.generateTokens(model, [{ - role: 'user', - content: "If 3 + 3 is 5, what is 2 + 2?" -}], { nPredict: 2048 }) -for await (const token of tokens2){ - process.stdout.write(token); -} -console.log("done") -model.dispose(); - diff --git a/gpt4all-bindings/typescript/spec/llmodel.mjs b/gpt4all-bindings/typescript/spec/llmodel.mjs new file mode 100644 index 00000000..baa6ed76 --- /dev/null +++ b/gpt4all-bindings/typescript/spec/llmodel.mjs @@ -0,0 +1,61 @@ +import { + LLModel, + createCompletion, + DEFAULT_DIRECTORY, + DEFAULT_LIBRARIES_DIRECTORY, + loadModel, +} from "../src/gpt4all.js"; + +const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", { + verbose: true, + device: "gpu", +}); +const ll = model.llm; + +try { + class Extended extends LLModel {} +} catch (e) { + console.log("Extending from native class gone wrong " + e); +} + +console.log("state size " + ll.stateSize()); + +console.log("thread count " + ll.threadCount()); +ll.setThreadCount(5); + +console.log("thread count " + ll.threadCount()); +ll.setThreadCount(4); +console.log("thread count " + ll.threadCount()); +console.log("name " + ll.name()); +console.log("type: " + ll.type()); +console.log("Default directory for models", DEFAULT_DIRECTORY); +console.log("Default directory for libraries", DEFAULT_LIBRARIES_DIRECTORY); +console.log("Has GPU", ll.hasGpuDevice()); +console.log("gpu devices", ll.listGpu()); +console.log("Required Mem in bytes", ll.memoryNeeded()); + +// to ingest a custom system prompt without using a chat session. +await createCompletion( + model, + "<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n", + { + promptTemplate: "%1", + nPredict: 0, + special: true, + } +); +const completion1 = await createCompletion(model, "What is 1 + 1?", { + verbose: true, +}); +console.log(`🤖 > ${completion1.choices[0].message.content}`); +//Very specific: +// tested on Ubuntu 22.0, Linux Mint, if I set nPast to 100, the app hangs. +const completion2 = await createCompletion(model, "And if we add two?", { + verbose: true, +}); +console.log(`🤖 > ${completion2.choices[0].message.content}`); + +//CALLING DISPOSE WILL INVALID THE NATIVE MODEL. 
USE THIS TO CLEANUP +model.dispose(); + +console.log("model disposed, exiting..."); diff --git a/gpt4all-bindings/typescript/spec/long-context.mjs b/gpt4all-bindings/typescript/spec/long-context.mjs new file mode 100644 index 00000000..abe3f36d --- /dev/null +++ b/gpt4all-bindings/typescript/spec/long-context.mjs @@ -0,0 +1,21 @@ +import { promises as fs } from "node:fs"; +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", { + verbose: true, + device: "gpu", + nCtx: 32768, +}); + +const typeDefSource = await fs.readFile("./src/gpt4all.d.ts", "utf-8"); + +const res = await createCompletion( + model, + "Here are the type definitions for the GPT4All API:\n\n" + + typeDefSource + + "\n\nHow do I create a completion with a really large context window?", + { + verbose: true, + } +); +console.debug(res.choices[0].message); diff --git a/gpt4all-bindings/typescript/spec/model-switching.mjs b/gpt4all-bindings/typescript/spec/model-switching.mjs new file mode 100644 index 00000000..264c7156 --- /dev/null +++ b/gpt4all-bindings/typescript/spec/model-switching.mjs @@ -0,0 +1,60 @@ +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model1 = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", { + device: "gpu", + nCtx: 4096, +}); + +const chat1 = await model1.createChatSession({ + temperature: 0.8, + topP: 0.7, + topK: 60, +}); + +const chat1turn1 = await createCompletion( + chat1, + "Outline a short story concept for adults. About why bananas are rather blue than bread is green at night sometimes. Not too long." +); +console.debug(chat1turn1.choices[0].message); + +const chat1turn2 = await createCompletion( + chat1, + "Lets sprinkle some plot twists. And a cliffhanger at the end." +); +console.debug(chat1turn2.choices[0].message); + +const chat1turn3 = await createCompletion( + chat1, + "Analyze your plot. Find the weak points." +); +console.debug(chat1turn3.choices[0].message); + +const chat1turn4 = await createCompletion( + chat1, + "Rewrite it based on the analysis." +); +console.debug(chat1turn4.choices[0].message); + +model1.dispose(); + +const model2 = await loadModel("gpt4all-falcon-newbpe-q4_0.gguf", { + device: "gpu", +}); + +const chat2 = await model2.createChatSession({ + messages: chat1.messages, +}); + +const chat2turn1 = await createCompletion( + chat2, + "Give three ideas how this plot could be improved." +); +console.debug(chat2turn1.choices[0].message); + +const chat2turn2 = await createCompletion( + chat2, + "Revise the plot, applying your ideas." 
+); +console.debug(chat2turn2.choices[0].message); + +model2.dispose(); diff --git a/gpt4all-bindings/typescript/spec/stateless.mjs b/gpt4all-bindings/typescript/spec/stateless.mjs new file mode 100644 index 00000000..6e3f82b2 --- /dev/null +++ b/gpt4all-bindings/typescript/spec/stateless.mjs @@ -0,0 +1,50 @@ +import { loadModel, createCompletion } from "../src/gpt4all.js"; + +const model = await loadModel("orca-mini-3b-gguf2-q4_0.gguf", { + verbose: true, + device: "gpu", +}); + +const messages = [ + { + role: "system", + content: "<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n", + }, + { + role: "user", + content: "What's 2+2?", + }, + { + role: "assistant", + content: "5", + }, + { + role: "user", + content: "Are you sure?", + }, +]; + + +const res1 = await createCompletion(model, messages); +console.debug(res1.choices[0].message); +messages.push(res1.choices[0].message); + +messages.push({ + role: "user", + content: "Could you double check that?", +}); + +const res2 = await createCompletion(model, messages); +console.debug(res2.choices[0].message); +messages.push(res2.choices[0].message); + +messages.push({ + role: "user", + content: "Let's bring out the big calculators.", +}); + +const res3 = await createCompletion(model, messages); +console.debug(res3.choices[0].message); +messages.push(res3.choices[0].message); + +// console.debug(messages); diff --git a/gpt4all-bindings/typescript/spec/streaming.mjs b/gpt4all-bindings/typescript/spec/streaming.mjs new file mode 100644 index 00000000..0dfcfd7b --- /dev/null +++ b/gpt4all-bindings/typescript/spec/streaming.mjs @@ -0,0 +1,57 @@ +import { + loadModel, + createCompletion, + createCompletionStream, + createCompletionGenerator, +} from "../src/gpt4all.js"; + +const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", { + device: "gpu", +}); + +process.stdout.write("### Stream:"); +const stream = createCompletionStream(model, "How are you?"); +stream.tokens.on("data", (data) => { + process.stdout.write(data); +}); +await stream.result; +process.stdout.write("\n"); + +process.stdout.write("### Stream with pipe:"); +const stream2 = createCompletionStream( + model, + "Please say something nice about node streams." 
+); +stream2.tokens.pipe(process.stdout); +const stream2Res = await stream2.result; +process.stdout.write("\n"); + +process.stdout.write("### Generator:"); +const gen = createCompletionGenerator(model, "generators instead?", { + nPast: stream2Res.usage.n_past_tokens, +}); +for await (const chunk of gen) { + process.stdout.write(chunk); +} + +process.stdout.write("\n"); + +process.stdout.write("### Callback:"); +await createCompletion(model, "Why not just callbacks?", { + onResponseToken: (tokenId, token) => { + process.stdout.write(token); + }, +}); +process.stdout.write("\n"); + +process.stdout.write("### 2nd Generator:"); +const gen2 = createCompletionGenerator(model, "If 3 + 3 is 5, what is 2 + 2?"); + +let chunk = await gen2.next(); +while (!chunk.done) { + process.stdout.write(chunk.value); + chunk = await gen2.next(); +} +process.stdout.write("\n"); +console.debug("generator finished", chunk); +model.dispose(); diff --git a/gpt4all-bindings/typescript/spec/system.mjs b/gpt4all-bindings/typescript/spec/system.mjs new file mode 100644 index 00000000..f80e3f3a --- /dev/null +++ b/gpt4all-bindings/typescript/spec/system.mjs @@ -0,0 +1,19 @@ +import { + loadModel, + createCompletion, +} from "../src/gpt4all.js"; + +const model = await loadModel("Nous-Hermes-2-Mistral-7B-DPO.Q4_0.gguf", { + verbose: true, + device: "gpu", +}); + +const chat = await model.createChatSession({ + verbose: true, + systemPrompt: "<|im_start|>system\nRoleplay as Batman. Answer as if you are Batman, never say you're an Assistant.\n<|im_end|>", +}); +const turn1 = await createCompletion(chat, "You have any plans tonight?"); +console.log(turn1.choices[0].message); +// "I'm afraid I must decline any personal invitations tonight. As Batman, I have a responsibility to protect Gotham City." + +model.dispose(); diff --git a/gpt4all-bindings/typescript/src/chat-session.js b/gpt4all-bindings/typescript/src/chat-session.js new file mode 100644 index 00000000..dcbdb7da --- /dev/null +++ b/gpt4all-bindings/typescript/src/chat-session.js @@ -0,0 +1,169 @@ +const { DEFAULT_PROMPT_CONTEXT } = require("./config"); +const { prepareMessagesForIngest } = require("./util"); + +class ChatSession { + model; + modelName; + /** + * @type {import('./gpt4all').ChatMessage[]} + */ + messages; + /** + * @type {string} + */ + systemPrompt; + /** + * @type {import('./gpt4all').LLModelPromptContext} + */ + promptContext; + /** + * @type {boolean} + */ + initialized; + + constructor(model, chatSessionOpts = {}) { + const { messages, systemPrompt, ...sessionDefaultPromptContext } = + chatSessionOpts; + this.model = model; + this.modelName = model.llm.name(); + this.messages = messages ?? []; + this.systemPrompt = systemPrompt ?? 
model.config.systemPrompt; + this.initialized = false; + this.promptContext = { + ...DEFAULT_PROMPT_CONTEXT, + ...sessionDefaultPromptContext, + nPast: 0, + }; + } + + async initialize(completionOpts = {}) { + if (this.model.activeChatSession !== this) { + this.model.activeChatSession = this; + } + + let tokensIngested = 0; + + // ingest system prompt + + if (this.systemPrompt) { + const systemRes = await this.model.generate(this.systemPrompt, { + promptTemplate: "%1", + nPredict: 0, + special: true, + nBatch: this.promptContext.nBatch, + // verbose: true, + }); + tokensIngested += systemRes.tokensIngested; + this.promptContext.nPast = systemRes.nPast; + } + + // ingest initial messages + if (this.messages.length > 0) { + tokensIngested += await this.ingestMessages( + this.messages, + completionOpts + ); + } + + this.initialized = true; + + return tokensIngested; + } + + async ingestMessages(messages, completionOpts = {}) { + const turns = prepareMessagesForIngest(messages); + + // send the message pairs to the model + let tokensIngested = 0; + + for (const turn of turns) { + const turnRes = await this.model.generate(turn.user, { + ...this.promptContext, + ...completionOpts, + fakeReply: turn.assistant, + }); + tokensIngested += turnRes.tokensIngested; + this.promptContext.nPast = turnRes.nPast; + } + return tokensIngested; + } + + async generate(input, completionOpts = {}) { + if (this.model.activeChatSession !== this) { + throw new Error( + "Chat session is not active. Create a new chat session or call initialize to continue." + ); + } + if (completionOpts.nPast > this.promptContext.nPast) { + throw new Error( + `nPast cannot be greater than ${this.promptContext.nPast}.` + ); + } + let tokensIngested = 0; + + if (!this.initialized) { + tokensIngested += await this.initialize(completionOpts); + } + + let prompt = input; + + if (Array.isArray(input)) { + // assuming input is a messages array + // -> tailing user message will be used as the final prompt. its optional. + // -> all system messages will be ignored. 
+ // -> all other messages will be ingested with fakeReply + // -> user/assistant messages will be pushed into the messages array + + let tailingUserMessage = ""; + let messagesToIngest = input; + + const lastMessage = input[input.length - 1]; + if (lastMessage.role === "user") { + tailingUserMessage = lastMessage.content; + messagesToIngest = input.slice(0, input.length - 1); + } + + if (messagesToIngest.length > 0) { + tokensIngested += await this.ingestMessages( + messagesToIngest, + completionOpts + ); + this.messages.push(...messagesToIngest); + } + + if (tailingUserMessage) { + prompt = tailingUserMessage; + } else { + return { + text: "", + nPast: this.promptContext.nPast, + tokensIngested, + tokensGenerated: 0, + }; + } + } + + const result = await this.model.generate(prompt, { + ...this.promptContext, + ...completionOpts, + }); + + this.promptContext.nPast = result.nPast; + result.tokensIngested += tokensIngested; + + this.messages.push({ + role: "user", + content: prompt, + }); + this.messages.push({ + role: "assistant", + content: result.text, + }); + + return result; + } +} + +module.exports = { + ChatSession, +}; diff --git a/gpt4all-bindings/typescript/src/config.js b/gpt4all-bindings/typescript/src/config.js index 85f51d7a..29ce8e49 100644 --- a/gpt4all-bindings/typescript/src/config.js +++ b/gpt4all-bindings/typescript/src/config.js @@ -27,15 +27,16 @@ const DEFAULT_MODEL_CONFIG = { promptTemplate: "### Human:\n%1\n\n### Assistant:\n", } -const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models2.json"; +const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models3.json"; const DEFAULT_PROMPT_CONTEXT = { - temp: 0.7, + temp: 0.1, topK: 40, - topP: 0.4, + topP: 0.9, + minP: 0.0, repeatPenalty: 1.18, - repeatLastN: 64, - nBatch: 8, + repeatLastN: 10, + nBatch: 100, } module.exports = { diff --git a/gpt4all-bindings/typescript/src/gpt4all.d.ts b/gpt4all-bindings/typescript/src/gpt4all.d.ts index 18f1eb95..4d9bfdcc 100644 --- a/gpt4all-bindings/typescript/src/gpt4all.d.ts +++ b/gpt4all-bindings/typescript/src/gpt4all.d.ts @@ -1,43 +1,11 @@ /// declare module "gpt4all"; -type ModelType = "gptj" | "llama" | "mpt" | "replit"; - -// NOTE: "deprecated" tag in below comment breaks the doc generator https://github.com/documentationjs/documentation/issues/1596 -/** - * Full list of models available - * DEPRECATED!! These model names are outdated and this type will not be maintained, please use a string literal instead - */ -interface ModelFile { - /** List of GPT-J Models */ - gptj: - | "ggml-gpt4all-j-v1.3-groovy.bin" - | "ggml-gpt4all-j-v1.2-jazzy.bin" - | "ggml-gpt4all-j-v1.1-breezy.bin" - | "ggml-gpt4all-j.bin"; - /** List Llama Models */ - llama: - | "ggml-gpt4all-l13b-snoozy.bin" - | "ggml-vicuna-7b-1.1-q4_2.bin" - | "ggml-vicuna-13b-1.1-q4_2.bin" - | "ggml-wizardLM-7B.q4_2.bin" - | "ggml-stable-vicuna-13B.q4_2.bin" - | "ggml-nous-gpt4-vicuna-13b.bin" - | "ggml-v3-13b-hermes-q5_1.bin"; - /** List of MPT Models */ - mpt: - | "ggml-mpt-7b-base.bin" - | "ggml-mpt-7b-chat.bin" - | "ggml-mpt-7b-instruct.bin"; - /** List of Replit Models */ - replit: "ggml-replit-code-v1-3b.bin"; -} - interface LLModelOptions { /** * Model architecture. This argument currently does not have any functionality and is just used as descriptive identifier for user. */ - type?: ModelType; + type?: string; model_name: string; model_path: string; library_path?: string; @@ -51,47 +19,259 @@ interface ModelConfig { } /** - * Callback for controlling token generation + * Options for the chat session. 
*/ -type TokenCallback = (tokenId: number, token: string, total: string) => boolean +interface ChatSessionOptions extends Partial { + /** + * System prompt to ingest on initialization. + */ + systemPrompt?: string; -/** - * - * InferenceModel represents an LLM which can make chat predictions, similar to GPT transformers. - * - */ -declare class InferenceModel { - constructor(llm: LLModel, config: ModelConfig); - llm: LLModel; - config: ModelConfig; - - generate( - prompt: string, - options?: Partial, - callback?: TokenCallback - ): Promise; - - /** - * delete and cleanup the native model - */ - dispose(): void + /** + * Messages to ingest on initialization. + */ + messages?: ChatMessage[]; } +/** + * ChatSession utilizes an InferenceModel for efficient processing of chat conversations. + */ +declare class ChatSession implements CompletionProvider { + /** + * Constructs a new ChatSession using the provided InferenceModel and options. + * Does not set the chat session as the active chat session until initialize is called. + * @param {InferenceModel} model An InferenceModel instance. + * @param {ChatSessionOptions} [options] Options for the chat session including default completion options. + */ + constructor(model: InferenceModel, options?: ChatSessionOptions); + /** + * The underlying InferenceModel used for generating completions. + */ + model: InferenceModel; + /** + * The name of the model. + */ + modelName: string; + /** + * The messages that have been exchanged in this chat session. + */ + messages: ChatMessage[]; + /** + * The system prompt that has been ingested at the beginning of the chat session. + */ + systemPrompt: string; + /** + * The current prompt context of the chat session. + */ + promptContext: LLModelPromptContext; + + /** + * Ingests system prompt and initial messages. + * Sets this chat session as the active chat session of the model. + * @param {CompletionOptions} [options] Set completion options for initialization. + * @returns {Promise} The number of tokens ingested during initialization. systemPrompt + messages. + */ + initialize(completionOpts?: CompletionOptions): Promise; + + /** + * Prompts the model in chat-session context. + * @param {CompletionInput} input Input string or message array. + * @param {CompletionOptions} [options] Set completion options for this generation. + * @returns {Promise} The inference result. + * @throws {Error} If the chat session is not the active chat session of the model. + * @throws {Error} If nPast is set to a value higher than what has been ingested in the session. + */ + generate( + input: CompletionInput, + options?: CompletionOptions + ): Promise; +} + +/** + * Shape of InferenceModel generations. + */ +interface InferenceResult extends LLModelInferenceResult { + tokensIngested: number; + tokensGenerated: number; +} + +/** + * InferenceModel represents an LLM which can make next-token predictions. + */ +declare class InferenceModel implements CompletionProvider { + constructor(llm: LLModel, config: ModelConfig); + /** The native LLModel */ + llm: LLModel; + /** The configuration the instance was constructed with. */ + config: ModelConfig; + /** The active chat session of the model. */ + activeChatSession?: ChatSession; + /** The name of the model. */ + modelName: string; + + /** + * Create a chat session with the model and set it as the active chat session of this model. + * A model instance can only have one active chat session at a time. + * @param {ChatSessionOptions} options The options for the chat session. 
+ * @returns {Promise} The chat session. + */ + createChatSession(options?: ChatSessionOptions): Promise; + + /** + * Prompts the model with a given input and optional parameters. + * @param {CompletionInput} input The prompt input. + * @param {CompletionOptions} options Prompt context and other options. + * @returns {Promise} The model's response to the prompt. + * @throws {Error} If nPast is set to a value smaller than 0. + * @throws {Error} If a messages array without a tailing user message is provided. + */ + generate( + prompt: string, + options?: CompletionOptions + ): Promise; + + /** + * delete and cleanup the native model + */ + dispose(): void; +} + +/** + * Options for generating one or more embeddings. + */ +interface EmbedddingOptions { + /** + * The model-specific prefix representing the embedding task, without the trailing colon. For Nomic Embed + * this can be `search_query`, `search_document`, `classification`, or `clustering`. + */ + prefix?: string; + /** + *The embedding dimension, for use with Matryoshka-capable models. Defaults to full-size. + * @default determines on the model being used. + */ + dimensionality?: number; + /** + * How to handle texts longer than the model can accept. One of `mean` or `truncate`. + * @default "mean" + */ + longTextMode?: "mean" | "truncate"; + /** + * Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens + * with long_text_mode="mean" will raise an error. Disabled by default. + * @default false + */ + atlas?: boolean; +} + +/** + * The nodejs moral equivalent to python binding's Embed4All().embed() + * meow + * @param {EmbeddingModel} model The embedding model instance. + * @param {string} text Text to embed. + * @param {EmbeddingOptions} options Optional parameters for the embedding. + * @returns {EmbeddingResult} The embedding result. + * @throws {Error} If dimensionality is set to a value smaller than 1. + */ +declare function createEmbedding( + model: EmbeddingModel, + text: string, + options?: EmbedddingOptions +): EmbeddingResult; + +/** + * Overload that takes multiple strings to embed. + * @param {EmbeddingModel} model The embedding model instance. + * @param {string[]} texts Texts to embed. + * @param {EmbeddingOptions} options Optional parameters for the embedding. + * @returns {EmbeddingResult} The embedding result. + * @throws {Error} If dimensionality is set to a value smaller than 1. + */ +declare function createEmbedding( + model: EmbeddingModel, + text: string[], + options?: EmbedddingOptions +): EmbeddingResult; + +/** + * The resulting embedding. + */ +interface EmbeddingResult { + /** + * Encoded token count. Includes overlap but specifically excludes tokens used for the prefix/task_type, BOS/CLS token, and EOS/SEP token + **/ + n_prompt_tokens: number; + + embeddings: T; +} /** * EmbeddingModel represents an LLM which can create embeddings, which are float arrays */ declare class EmbeddingModel { constructor(llm: LLModel, config: ModelConfig); + /** The native LLModel */ llm: LLModel; + /** The configuration the instance was constructed with. */ config: ModelConfig; - embed(text: string): Float32Array; + /** + * Create an embedding from a given input string. See EmbeddingOptions. + * @param {string} text + * @param {string} prefix + * @param {number} dimensionality + * @param {boolean} doMean + * @param {boolean} atlas + * @returns {EmbeddingResult} The embedding result. 
+ */ + embed( + text: string, + prefix: string, + dimensionality: number, + doMean: boolean, + atlas: boolean + ): EmbeddingResult; + /** + * Create an embedding from a given input text array. See EmbeddingOptions. + * @param {string[]} text + * @param {string} prefix + * @param {number} dimensionality + * @param {boolean} doMean + * @param {boolean} atlas + * @returns {EmbeddingResult} The embedding result. + */ + embed( + text: string[], + prefix: string, + dimensionality: number, + doMean: boolean, + atlas: boolean + ): EmbeddingResult; /** - * delete and cleanup the native model + * delete and cleanup the native model */ - dispose(): void + dispose(): void; +} +/** + * Shape of LLModel's inference result. + */ +interface LLModelInferenceResult { + text: string; + nPast: number; +} + +interface LLModelInferenceOptions extends Partial { + /** Callback for response tokens, called for each generated token. + * @param {number} tokenId The token id. + * @param {string} token The token. + * @returns {boolean | undefined} Whether to continue generating tokens. + * */ + onResponseToken?: (tokenId: number, token: string) => boolean | void; + /** Callback for prompt tokens, called for each input token in the prompt. + * @param {number} tokenId The token id. + * @returns {boolean | undefined} Whether to continue ingesting the prompt. + * */ + onPromptToken?: (tokenId: number) => boolean | void; } /** @@ -101,14 +281,13 @@ declare class EmbeddingModel { declare class LLModel { /** * Initialize a new LLModel. - * @param path Absolute path to the model file. + * @param {string} path Absolute path to the model file. * @throws {Error} If the model file does not exist. */ - constructor(path: string); constructor(options: LLModelOptions); - /** either 'gpt', mpt', or 'llama' or undefined */ - type(): ModelType | undefined; + /** undefined or user supplied */ + type(): string | undefined; /** The name of the model. */ name(): string; @@ -134,29 +313,53 @@ declare class LLModel { setThreadCount(newNumber: number): void; /** - * Prompt the model with a given input and optional parameters. - * This is the raw output from model. - * Use the prompt function exported for a value - * @param q The prompt input. - * @param params Optional parameters for the prompt context. - * @param callback - optional callback to control token generation. - * @returns The result of the model prompt. + * Prompt the model directly with a given input string and optional parameters. + * Use the higher level createCompletion methods for a more user-friendly interface. + * @param {string} prompt The prompt input. + * @param {LLModelInferenceOptions} options Optional parameters for the generation. + * @returns {LLModelInferenceResult} The response text and final context size. */ - raw_prompt( - q: string, - params: Partial, - callback?: TokenCallback - ): Promise + infer( + prompt: string, + options: LLModelInferenceOptions + ): Promise; /** - * Embed text with the model. Keep in mind that - * not all models can embed text, (only bert can embed as of 07/16/2023 (mm/dd/yyyy)) - * Use the prompt function exported for a value - * @param q The prompt input. - * @param params Optional parameters for the prompt context. - * @returns The result of the model prompt. + * Embed text with the model. See EmbeddingOptions for more information. + * Use the higher level createEmbedding methods for a more user-friendly interface. 
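Because `CompletionOptions` extends these inference options (see further below), the two callbacks can be passed straight to `createCompletion`; returning `false` from `onResponseToken` stops generation early. A small sketch with a placeholder model name:

```js
import { loadModel, createCompletion } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");

let ingested = 0;
let generated = 0;

const result = await createCompletion(
    model,
    "Write a long story about a lighthouse.",
    {
        // Called for every ingested prompt token; return false to cancel ingestion.
        onPromptToken: () => {
            ingested += 1;
        },
        // Called for every generated token; stop after roughly 50 tokens.
        onResponseToken: (tokenId, token) => {
            process.stdout.write(token);
            generated += 1;
            return generated < 50;
        },
    }
);

console.log(`\ningested=${ingested} generated=${generated} nPast=${result.usage.n_past_tokens}`);

model.dispose();
```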
+ * @param {string} text + * @param {string} prefix + * @param {number} dimensionality + * @param {boolean} doMean + * @param {boolean} atlas + * @returns {Float32Array} The embedding of the text. */ - embed(text: string): Float32Array; + embed( + text: string, + prefix: string, + dimensionality: number, + doMean: boolean, + atlas: boolean + ): Float32Array; + + /** + * Embed multiple texts with the model. See EmbeddingOptions for more information. + * Use the higher level createEmbedding methods for a more user-friendly interface. + * @param {string[]} texts + * @param {string} prefix + * @param {number} dimensionality + * @param {boolean} doMean + * @param {boolean} atlas + * @returns {Float32Array[]} The embeddings of the texts. + */ + embed( + texts: string, + prefix: string, + dimensionality: number, + doMean: boolean, + atlas: boolean + ): Float32Array[]; + /** * Whether the model is loaded or not. */ @@ -166,81 +369,97 @@ declare class LLModel { * Where to search for the pluggable backend libraries */ setLibraryPath(s: string): void; + /** * Where to get the pluggable backend libraries */ getLibraryPath(): string; + /** - * Initiate a GPU by a string identifier. - * @param {number} memory_required Should be in the range size_t or will throw + * Initiate a GPU by a string identifier. + * @param {number} memory_required Should be in the range size_t or will throw * @param {string} device_name 'amd' | 'nvidia' | 'intel' | 'gpu' | gpu name. * read LoadModelOptions.device for more information */ - initGpuByString(memory_required: number, device_name: string): boolean + initGpuByString(memory_required: number, device_name: string): boolean; + /** * From C documentation * @returns True if a GPU device is successfully initialized, false otherwise. */ - hasGpuDevice(): boolean - /** - * GPUs that are usable for this LLModel - * @param nCtx Maximum size of context window - * @throws if hasGpuDevice returns false (i think) - * @returns - */ - listGpu(nCtx: number) : GpuDevice[] + hasGpuDevice(): boolean; /** - * delete and cleanup the native model + * GPUs that are usable for this LLModel + * @param {number} nCtx Maximum size of context window + * @throws if hasGpuDevice returns false (i think) + * @returns */ - dispose(): void + listGpu(nCtx: number): GpuDevice[]; + + /** + * delete and cleanup the native model + */ + dispose(): void; } -/** +/** * an object that contains gpu data on this machine. */ interface GpuDevice { index: number; /** - * same as VkPhysicalDeviceType + * same as VkPhysicalDeviceType */ - type: number; - heapSize : number; + type: number; + heapSize: number; name: string; vendor: string; } /** - * Options that configure a model's behavior. - */ + * Options that configure a model's behavior. + */ interface LoadModelOptions { + /** + * Where to look for model files. + */ modelPath?: string; + /** + * Where to look for the backend libraries. + */ librariesPath?: string; + /** + * The path to the model configuration file, useful for offline usage or custom model configurations. + */ modelConfigFile?: string; + /** + * Whether to allow downloading the model if it is not present at the specified path. + */ allowDownload?: boolean; + /** + * Enable verbose logging. + */ verbose?: boolean; - /* The processing unit on which the model will run. It can be set to + /** + * The processing unit on which the model will run. It can be set to * - "cpu": Model will run on the central processing unit. 
* - "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor. * - "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor. - - Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name - if it's available. - - Default is "cpu". - - Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All - instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the - model. - */ + * - "gpu name": Model will run on the GPU that matches the name if it's available. + * Note: If a GPU device lacks sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All + * instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the + * model. + * @default "cpu" + */ device?: string; - /* + /** * The Maximum window size of this model - * Default of 2048 + * @default 2048 */ nCtx?: number; - /* + /** * Number of gpu layers needed - * Default of 100 + * @default 100 */ ngl?: number; } @@ -277,66 +496,84 @@ declare function loadModel( ): Promise; /** - * The nodejs equivalent to python binding's chat_completion - * @param {InferenceModel} model - The language model object. - * @param {PromptMessage[]} messages - The array of messages for the conversation. - * @param {CompletionOptions} options - The options for creating the completion. - * @returns {CompletionReturn} The completion result. + * Interface for createCompletion methods, implemented by InferenceModel and ChatSession. + * Implement your own CompletionProvider or extend ChatSession to generate completions with custom logic. */ -declare function createCompletion( - model: InferenceModel, - messages: PromptMessage[], - options?: CompletionOptions -): Promise; - -/** - * The nodejs moral equivalent to python binding's Embed4All().embed() - * meow - * @param {EmbeddingModel} model - The language model object. - * @param {string} text - text to embed - * @returns {Float32Array} The completion result. - */ -declare function createEmbedding( - model: EmbeddingModel, - text: string -): Float32Array; - -/** - * The options for creating the completion. - */ -interface CompletionOptions extends Partial { - /** - * Indicates if verbose logging is enabled. - * @default true - */ - verbose?: boolean; - - /** - * Template for the system message. Will be put before the conversation with %1 being replaced by all system messages. - * Note that if this is not defined, system messages will not be included in the prompt. - */ - systemPromptTemplate?: string; - - /** - * Template for user messages, with %1 being replaced by the message. - */ - promptTemplate?: boolean; - - /** - * The initial instruction for the model, on top of the prompt - */ - promptHeader?: string; - - /** - * The last instruction for the model, appended to the end of the prompt. - */ - promptFooter?: string; +interface CompletionProvider { + modelName: string; + generate( + input: CompletionInput, + options?: CompletionOptions + ): Promise; } /** - * A message in the conversation, identical to OpenAI's chat message. + * Options for creating a completion. */ -interface PromptMessage { +interface CompletionOptions extends LLModelInferenceOptions { + /** + * Indicates if verbose logging is enabled. + * @default false + */ + verbose?: boolean; +} + +/** + * The input for creating a completion. 
May be a string or an array of messages. + */ +type CompletionInput = string | ChatMessage[]; + +/** + * The nodejs equivalent to python binding's chat_completion + * @param {CompletionProvider} provider - The inference model object or chat session + * @param {CompletionInput} input - The input string or message array + * @param {CompletionOptions} options - The options for creating the completion. + * @returns {CompletionResult} The completion result. + */ +declare function createCompletion( + provider: CompletionProvider, + input: CompletionInput, + options?: CompletionOptions +): Promise; + +/** + * Streaming variant of createCompletion, returns a stream of tokens and a promise that resolves to the completion result. + * @param {CompletionProvider} provider - The inference model object or chat session + * @param {CompletionInput} input - The input string or message array + * @param {CompletionOptions} options - The options for creating the completion. + * @returns {CompletionStreamReturn} An object of token stream and the completion result promise. + */ +declare function createCompletionStream( + provider: CompletionProvider, + input: CompletionInput, + options?: CompletionOptions +): CompletionStreamReturn; + +/** + * The result of a streamed completion, containing a stream of tokens and a promise that resolves to the completion result. + */ +interface CompletionStreamReturn { + tokens: NodeJS.ReadableStream; + result: Promise; +} + +/** + * Async generator variant of createCompletion, yields tokens as they are generated and returns the completion result. + * @param {CompletionProvider} provider - The inference model object or chat session + * @param {CompletionInput} input - The input string or message array + * @param {CompletionOptions} options - The options for creating the completion. + * @returns {AsyncGenerator} The stream of generated tokens + */ +declare function createCompletionGenerator( + provider: CompletionProvider, + input: CompletionInput, + options: CompletionOptions +): AsyncGenerator; + +/** + * A message in the conversation. + */ +interface ChatMessage { /** The role of the message. */ role: "system" | "assistant" | "user"; @@ -345,34 +582,31 @@ interface PromptMessage { } /** - * The result of the completion, similar to OpenAI's format. + * The result of a completion. */ -interface CompletionReturn { +interface CompletionResult { /** The model used for the completion. */ model: string; /** Token usage report. */ usage: { - /** The number of tokens used in the prompt. */ + /** The number of tokens ingested during the completion. */ prompt_tokens: number; - /** The number of tokens used in the completion. */ + /** The number of tokens generated in the completion. */ completion_tokens: number; /** The total number of tokens used. */ total_tokens: number; + + /** Number of tokens used in the conversation. */ + n_past_tokens: number; }; - /** The generated completions. */ - choices: CompletionChoice[]; -} - -/** - * A completion choice, similar to OpenAI's format. - */ -interface CompletionChoice { - /** Response message */ - message: PromptMessage; + /** The generated completion. */ + choices: Array<{ + message: ChatMessage; + }>; } /** @@ -385,19 +619,33 @@ interface LLModelPromptContext { /** The size of the raw tokens vector. */ tokensSize: number; - /** The number of tokens in the past conversation. */ + /** The number of tokens in the past conversation. + * This may be used to "roll back" the conversation to a previous state. 
+ * Note that for most use cases the default value should be sufficient and this should not be set. + * @default 0 For completions using InferenceModel, meaning the model will only consider the input prompt. + * @default nPast For completions using ChatSession. This means the context window will be automatically determined + * and possibly resized (see contextErase) to keep the conversation performant. + * */ nPast: number; - /** The number of tokens possible in the context window. - * @default 1024 - */ - nCtx: number; - - /** The number of tokens to predict. - * @default 128 + /** The maximum number of tokens to predict. + * @default 4096 * */ nPredict: number; + /** Template for user / assistant message pairs. + * %1 is required and will be replaced by the user input. + * %2 is optional and will be replaced by the assistant response. If not present, the assistant response will be appended. + */ + promptTemplate?: string; + + /** The context window size. Do not use, it has no effect. See loadModel options. + * THIS IS DEPRECATED!!! + * Use loadModel's nCtx option instead. + * @default 2048 + */ + nCtx: number; + /** The top-k logits to sample from. * Top-K sampling selects the next token only from the top K most likely tokens predicted by the model. * It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit @@ -409,26 +657,33 @@ interface LLModelPromptContext { topK: number; /** The nucleus sampling probability threshold. - * Top-P limits the selection of the next token to a subset of tokens with a cumulative probability + * Top-P limits the selection of the next token to a subset of tokens with a cumulative probability * above a threshold P. This method, also known as nucleus sampling, finds a balance between diversity * and quality by considering both token probabilities and the number of tokens available for sampling. * When using a higher value for top-P (eg., 0.95), the generated text becomes more diverse. * On the other hand, a lower value (eg., 0.1) produces more focused and conservative text. - * The default value is 0.4, which is aimed to be the middle ground between focus and diversity, but - * for more creative tasks a higher top-p value will be beneficial, about 0.5-0.9 is a good range for that. - * @default 0.4 + * @default 0.9 + * * */ topP: number; + /** + * The minimum probability of a token to be considered. + * @default 0.0 + */ + minP: number; + /** The temperature to adjust the model's output distribution. * Temperature is like a knob that adjusts how creative or focused the output becomes. Higher temperatures * (eg., 1.2) increase randomness, resulting in more imaginative and diverse text. Lower temperatures (eg., 0.5) * make the output more focused, predictable, and conservative. When the temperature is set to 0, the output * becomes completely deterministic, always selecting the most probable next token and producing identical results - * each time. A safe range would be around 0.6 - 0.85, but you are free to search what value fits best for you. - * @default 0.7 + * each time. Try what value fits best for your use case and model. + * @default 0.1 + * @alias temperature * */ temp: number; + temperature: number; /** The number of predictions to generate in parallel. * By splitting the prompt every N tokens, prompt-batch-size reduces RAM usage during processing. 
However, @@ -451,31 +706,17 @@ interface LLModelPromptContext { * The repeat-penalty-tokens N option controls the number of tokens in the history to consider for penalizing repetition. * A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only * consider recent tokens. - * @default 64 + * @default 10 * */ repeatLastN: number; /** The percentage of context to erase if the context window is exceeded. - * @default 0.5 + * Set it to a lower value to keep context for longer at the cost of performance. + * @default 0.75 * */ contextErase: number; } - -/** - * Creates an async generator of tokens - * @param {InferenceModel} llmodel - The language model object. - * @param {PromptMessage[]} messages - The array of messages for the conversation. - * @param {CompletionOptions} options - The options for creating the completion. - * @param {TokenCallback} callback - optional callback to control token generation. - * @returns {AsyncGenerator} The stream of generated tokens - */ -declare function generateTokens( - llmodel: InferenceModel, - messages: PromptMessage[], - options: CompletionOptions, - callback?: TokenCallback -): AsyncGenerator; /** * From python api: * models will be stored in (homedir)/.cache/gpt4all/` @@ -508,7 +749,7 @@ declare const DEFAULT_MODEL_LIST_URL: string; * Initiates the download of a model file. * By default this downloads without waiting. use the controller returned to alter this behavior. * @param {string} modelName - The model to be downloaded. - * @param {DownloadOptions} options - to pass into the downloader. Default is { location: (cwd), verbose: false }. + * @param {DownloadModelOptions} options - to pass into the downloader. Default is { location: (cwd), verbose: false }. * @returns {DownloadController} object that allows controlling the download process. * * @throws {Error} If the model already exists in the specified location. @@ -556,7 +797,9 @@ interface ListModelsOptions { file?: string; } -declare function listModels(options?: ListModelsOptions): Promise; +declare function listModels( + options?: ListModelsOptions +): Promise; interface RetrieveModelOptions { allowDownload?: boolean; @@ -581,30 +824,35 @@ interface DownloadController { } export { - ModelType, - ModelFile, - ModelConfig, - InferenceModel, - EmbeddingModel, LLModel, LLModelPromptContext, - PromptMessage, + ModelConfig, + InferenceModel, + InferenceResult, + EmbeddingModel, + EmbeddingResult, + ChatSession, + ChatMessage, + CompletionInput, + CompletionProvider, CompletionOptions, + CompletionResult, LoadModelOptions, + DownloadController, + RetrieveModelOptions, + DownloadModelOptions, + GpuDevice, loadModel, + downloadModel, + retrieveModel, + listModels, createCompletion, + createCompletionStream, + createCompletionGenerator, createEmbedding, - generateTokens, DEFAULT_DIRECTORY, DEFAULT_LIBRARIES_DIRECTORY, DEFAULT_MODEL_CONFIG, DEFAULT_PROMPT_CONTEXT, DEFAULT_MODEL_LIST_URL, - downloadModel, - retrieveModel, - listModels, - DownloadController, - RetrieveModelOptions, - DownloadModelOptions, - GpuDevice }; diff --git a/gpt4all-bindings/typescript/src/gpt4all.js b/gpt4all-bindings/typescript/src/gpt4all.js index af19dfa2..aab01d91 100644 --- a/gpt4all-bindings/typescript/src/gpt4all.js +++ b/gpt4all-bindings/typescript/src/gpt4all.js @@ -2,8 +2,10 @@ /// This file implements the gpt4all.d.ts file endings. /// Written in commonjs to support both ESM and CJS projects. 
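The typings above allow `createCompletion` to take either a plain string or a `ChatMessage[]`; with an array, prior turns are re-ingested for that call only and the final message must have role `user`. A short, stateless sketch (model name is a placeholder):

```js
import { loadModel, createCompletion } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf");

// Stateless one-shot: the earlier turns are re-ingested for this call only.
const res = await createCompletion(model, [
    { role: "user", content: "I like oranges." },
    { role: "assistant", content: "Noted: you like oranges." },
    { role: "user", content: "Which fruit do I like?" },
]);

console.log(res.choices[0].message.content);
console.log(res.usage); // { prompt_tokens, completion_tokens, total_tokens, n_past_tokens }

model.dispose();
```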
-const { existsSync } = require("fs"); +const { existsSync } = require("node:fs"); const path = require("node:path"); +const Stream = require("node:stream"); +const assert = require("node:assert"); const { LLModel } = require("node-gyp-build")(path.resolve(__dirname, "..")); const { retrieveModel, @@ -18,15 +20,14 @@ const { DEFAULT_MODEL_LIST_URL, } = require("./config.js"); const { InferenceModel, EmbeddingModel } = require("./models.js"); -const Stream = require('stream') -const assert = require("assert"); +const { ChatSession } = require("./chat-session.js"); /** * Loads a machine learning model with the specified name. The defacto way to create a model. * By default this will download a model from the official GPT4ALL website, if a model is not present at given path. * * @param {string} modelName - The name of the model to load. - * @param {LoadModelOptions|undefined} [options] - (Optional) Additional options for loading the model. + * @param {import('./gpt4all').LoadModelOptions|undefined} [options] - (Optional) Additional options for loading the model. * @returns {Promise} A promise that resolves to an instance of the loaded LLModel. */ async function loadModel(modelName, options = {}) { @@ -35,10 +36,10 @@ async function loadModel(modelName, options = {}) { librariesPath: DEFAULT_LIBRARIES_DIRECTORY, type: "inference", allowDownload: true, - verbose: true, - device: 'cpu', + verbose: false, + device: "cpu", nCtx: 2048, - ngl : 100, + ngl: 100, ...options, }; @@ -49,12 +50,14 @@ async function loadModel(modelName, options = {}) { verbose: loadOptions.verbose, }); - assert.ok(typeof loadOptions.librariesPath === 'string'); + assert.ok( + typeof loadOptions.librariesPath === "string", + "Libraries path should be a string" + ); const existingPaths = loadOptions.librariesPath .split(";") .filter(existsSync) - .join(';'); - console.log("Passing these paths into runtime library search:", existingPaths) + .join(";"); const llmOptions = { model_name: appendBinSuffixIfMissing(modelName), @@ -62,13 +65,15 @@ async function loadModel(modelName, options = {}) { library_path: existingPaths, device: loadOptions.device, nCtx: loadOptions.nCtx, - ngl: loadOptions.ngl + ngl: loadOptions.ngl, }; if (loadOptions.verbose) { - console.debug("Creating LLModel with options:", llmOptions); + console.debug("Creating LLModel:", { + llmOptions, + modelConfig, + }); } - console.log(modelConfig) const llmodel = new LLModel(llmOptions); if (loadOptions.type === "embedding") { return new EmbeddingModel(llmodel, modelConfig); @@ -79,75 +84,43 @@ async function loadModel(modelName, options = {}) { } } -/** - * Formats a list of messages into a single prompt string. 
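For reference, a hedged sketch of `loadModel` with the `LoadModelOptions` documented above, for example for offline use; the model directory is a placeholder:

```js
import { loadModel } from "../src/gpt4all.js";

// Offline usage: look only in a custom directory and fail instead of downloading.
const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
    modelPath: "/data/models", // placeholder path
    allowDownload: false,
    device: "gpu", // "cpu", "amd", "nvidia", "intel", or a specific GPU name
    nCtx: 4096,    // context window; replaces the deprecated promptContext.nCtx
    ngl: 100,      // number of layers to offload to the GPU
    verbose: true,
});

model.dispose();
```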
- */ -function formatChatPrompt( - messages, - { - systemPromptTemplate, - defaultSystemPrompt, - promptTemplate, - promptFooter, - promptHeader, - } -) { - const systemMessages = messages - .filter((message) => message.role === "system") - .map((message) => message.content); +function createEmbedding(model, text, options={}) { + let { + dimensionality = undefined, + longTextMode = "mean", + atlas = false, + } = options; - let fullPrompt = ""; - - if (promptHeader) { - fullPrompt += promptHeader + "\n\n"; - } - - if (systemPromptTemplate) { - // if user specified a template for the system prompt, put all system messages in the template - let systemPrompt = ""; - - if (systemMessages.length > 0) { - systemPrompt += systemMessages.join("\n"); - } - - if (systemPrompt) { - fullPrompt += - systemPromptTemplate.replace("%1", systemPrompt) + "\n"; - } - } else if (defaultSystemPrompt) { - // otherwise, use the system prompt from the model config and ignore system messages - fullPrompt += defaultSystemPrompt + "\n\n"; - } - - if (systemMessages.length > 0 && !systemPromptTemplate) { - console.warn( - "System messages were provided, but no systemPromptTemplate was specified. System messages will be ignored." - ); - } - - for (const message of messages) { - if (message.role === "user") { - const userMessage = promptTemplate.replace( - "%1", - message["content"] + if (dimensionality === undefined) { + dimensionality = -1; + } else { + if (dimensionality <= 0) { + throw new Error( + `Dimensionality must be undefined or a positive integer, got ${dimensionality}` ); - fullPrompt += userMessage; } - if (message["role"] == "assistant") { - const assistantMessage = message["content"] + "\n"; - fullPrompt += assistantMessage; + if (dimensionality < model.MIN_DIMENSIONALITY) { + console.warn( + `Dimensionality ${dimensionality} is less than the suggested minimum of ${model.MIN_DIMENSIONALITY}. 
Performance may be degraded.` + ); } } - if (promptFooter) { - fullPrompt += "\n\n" + promptFooter; + let doMean; + switch (longTextMode) { + case "mean": + doMean = true; + break; + case "truncate": + doMean = false; + break; + default: + throw new Error( + `Long text mode must be one of 'mean' or 'truncate', got ${longTextMode}` + ); } - return fullPrompt; -} - -function createEmbedding(model, text) { - return model.embed(text); + return model.embed(text, options?.prefix, dimensionality, doMean, atlas); } const defaultCompletionOptions = { @@ -155,162 +128,76 @@ const defaultCompletionOptions = { ...DEFAULT_PROMPT_CONTEXT, }; -function preparePromptAndContext(model,messages,options){ - if (options.hasDefaultHeader !== undefined) { - console.warn( - "hasDefaultHeader (bool) is deprecated and has no effect, use promptHeader (string) instead" - ); - } - - if (options.hasDefaultFooter !== undefined) { - console.warn( - "hasDefaultFooter (bool) is deprecated and has no effect, use promptFooter (string) instead" - ); - } - - const optionsWithDefaults = { +async function createCompletion( + provider, + input, + options = defaultCompletionOptions +) { + const completionOptions = { ...defaultCompletionOptions, ...options, }; - const { - verbose, - systemPromptTemplate, - promptTemplate, - promptHeader, - promptFooter, - ...promptContext - } = optionsWithDefaults; - - - const prompt = formatChatPrompt(messages, { - systemPromptTemplate, - defaultSystemPrompt: model.config.systemPrompt, - promptTemplate: promptTemplate || model.config.promptTemplate || "%1", - promptHeader: promptHeader || "", - promptFooter: promptFooter || "", - // These were the default header/footer prompts used for non-chat single turn completions. - // both seem to be working well still with some models, so keeping them here for reference. 
- // promptHeader: '### Instruction: The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.', - // promptFooter: '### Response:', - }); + const result = await provider.generate( + input, + completionOptions, + ); return { - prompt, promptContext, verbose - } -} - -async function createCompletion( - model, - messages, - options = defaultCompletionOptions -) { - const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options); - - if (verbose) { - console.debug("Sending Prompt:\n" + prompt); - } - - let tokensGenerated = 0 - - const response = await model.generate(prompt, promptContext,() => { - tokensGenerated++; - return true; - }); - - if (verbose) { - console.debug("Received Response:\n" + response); - } - - return { - llmodel: model.llm.name(), + model: provider.modelName, usage: { - prompt_tokens: prompt.length, - completion_tokens: tokensGenerated, - total_tokens: prompt.length + tokensGenerated, //TODO Not sure how to get tokens in prompt + prompt_tokens: result.tokensIngested, + total_tokens: result.tokensIngested + result.tokensGenerated, + completion_tokens: result.tokensGenerated, + n_past_tokens: result.nPast, }, choices: [ { message: { role: "assistant", - content: response, + content: result.text, }, + // TODO some completion APIs also provide logprobs and finish_reason, could look into adding those }, ], }; } -function _internal_createTokenStream(stream,model, - messages, - options = defaultCompletionOptions,callback = undefined) { - const { prompt, promptContext, verbose } = preparePromptAndContext(model,messages,options); +function createCompletionStream( + provider, + input, + options = defaultCompletionOptions +) { + const completionStream = new Stream.PassThrough({ + encoding: "utf-8", + }); - - if (verbose) { - console.debug("Sending Prompt:\n" + prompt); - } - - model.generate(prompt, promptContext,(tokenId, token, total) => { - stream.push(token); - - if(callback !== undefined){ - return callback(tokenId,token,total); - } - - return true; - }).then(() => { - stream.end() - }) - - return stream; -} - -function _createTokenStream(model, - messages, - options = defaultCompletionOptions,callback = undefined) { - - // Silent crash if we dont do this here - const stream = new Stream.PassThrough({ - encoding: 'utf-8' - }); - return _internal_createTokenStream(stream,model,messages,options,callback); -} - -async function* generateTokens(model, - messages, - options = defaultCompletionOptions, callback = undefined) { - const stream = _createTokenStream(model,messages,options,callback) - - let bHasFinished = false; - let activeDataCallback = undefined; - const finishCallback = () => { - bHasFinished = true; - if(activeDataCallback !== undefined){ - activeDataCallback(undefined); - } - } - - stream.on("finish",finishCallback) - - while (!bHasFinished) { - const token = await new Promise((res) => { - - activeDataCallback = (d) => { - stream.off("data",activeDataCallback) - activeDataCallback = undefined - res(d); + const completionPromise = createCompletion(provider, input, { + ...options, + onResponseToken: (tokenId, token) => { + completionStream.push(token); + if (options.onResponseToken) { + return options.onResponseToken(tokenId, token); } - stream.on('data', activeDataCallback) - }) + }, + }).then((result) => { + completionStream.push(null); + completionStream.emit("end"); + return result; + }); - if (token == undefined) { - break; - } + return { + tokens: 
completionStream, + result: completionPromise, + }; +} - yield token; +async function* createCompletionGenerator(provider, input, options) { + const completion = createCompletionStream(provider, input, options); + for await (const chunk of completion.tokens) { + yield chunk; } - - stream.off("finish",finishCallback); + return await completion.result; } module.exports = { @@ -322,10 +209,12 @@ module.exports = { LLModel, InferenceModel, EmbeddingModel, + ChatSession, createCompletion, + createCompletionStream, + createCompletionGenerator, createEmbedding, downloadModel, retrieveModel, loadModel, - generateTokens }; diff --git a/gpt4all-bindings/typescript/src/models.js b/gpt4all-bindings/typescript/src/models.js index 31fe8001..2c516ccb 100644 --- a/gpt4all-bindings/typescript/src/models.js +++ b/gpt4all-bindings/typescript/src/models.js @@ -1,18 +1,138 @@ -const { normalizePromptContext, warnOnSnakeCaseKeys } = require('./util'); +const { DEFAULT_PROMPT_CONTEXT } = require("./config"); +const { ChatSession } = require("./chat-session"); +const { prepareMessagesForIngest } = require("./util"); class InferenceModel { llm; + modelName; config; + activeChatSession; constructor(llmodel, config) { this.llm = llmodel; this.config = config; + this.modelName = this.llm.name(); } - async generate(prompt, promptContext,callback) { - warnOnSnakeCaseKeys(promptContext); - const normalizedPromptContext = normalizePromptContext(promptContext); - const result = this.llm.raw_prompt(prompt, normalizedPromptContext,callback); + async createChatSession(options) { + const chatSession = new ChatSession(this, options); + await chatSession.initialize(); + this.activeChatSession = chatSession; + return this.activeChatSession; + } + + async generate(input, options = DEFAULT_PROMPT_CONTEXT) { + const { verbose, ...otherOptions } = options; + const promptContext = { + promptTemplate: this.config.promptTemplate, + temp: + otherOptions.temp ?? + otherOptions.temperature ?? + DEFAULT_PROMPT_CONTEXT.temp, + ...otherOptions, + }; + + if (promptContext.nPast < 0) { + throw new Error("nPast must be a non-negative integer."); + } + + if (verbose) { + console.debug("Generating completion", { + input, + promptContext, + }); + } + + let prompt = input; + let nPast = promptContext.nPast; + let tokensIngested = 0; + + if (Array.isArray(input)) { + // assuming input is a messages array + // -> tailing user message will be used as the final prompt. its required. 
+ // -> leading system message will be ingested as systemPrompt, further system messages will be ignored + // -> all other messages will be ingested with fakeReply + // -> model/context will only be kept for this completion; "stateless" + nPast = 0; + const messages = [...input]; + const lastMessage = input[input.length - 1]; + if (lastMessage.role !== "user") { + // this is most likely a user error + throw new Error("The final message must be of role 'user'."); + } + if (input[0].role === "system") { + // needs to be a pre-templated prompt ala '<|im_start|>system\nYou are an advanced mathematician.\n<|im_end|>\n' + const systemPrompt = input[0].content; + const systemRes = await this.llm.infer(systemPrompt, { + promptTemplate: "%1", + nPredict: 0, + special: true, + }); + nPast = systemRes.nPast; + tokensIngested += systemRes.tokensIngested; + messages.shift(); + } + + prompt = lastMessage.content; + const messagesToIngest = messages.slice(0, input.length - 1); + const turns = prepareMessagesForIngest(messagesToIngest); + + for (const turn of turns) { + const turnRes = await this.llm.infer(turn.user, { + ...promptContext, + nPast, + fakeReply: turn.assistant, + }); + tokensIngested += turnRes.tokensIngested; + nPast = turnRes.nPast; + } + } + + let tokensGenerated = 0; + + const result = await this.llm.infer(prompt, { + ...promptContext, + nPast, + onPromptToken: (tokenId) => { + let continueIngestion = true; + tokensIngested++; + if (options.onPromptToken) { + // catch errors because if they go through cpp they will loose stacktraces + try { + // don't cancel ingestion unless user explicitly returns false + continueIngestion = + options.onPromptToken(tokenId) !== false; + } catch (e) { + console.error("Error in onPromptToken callback", e); + continueIngestion = false; + } + } + return continueIngestion; + }, + onResponseToken: (tokenId, token) => { + let continueGeneration = true; + tokensGenerated++; + if (options.onResponseToken) { + try { + // don't cancel the generation unless user explicitly returns false + continueGeneration = + options.onResponseToken(tokenId, token) !== false; + } catch (err) { + console.error("Error in onResponseToken callback", err); + continueGeneration = false; + } + } + return continueGeneration; + }, + }); + + result.tokensGenerated = tokensGenerated; + result.tokensIngested = tokensIngested; + + if (verbose) { + console.debug("Finished completion:\n", result); + } + return result; } @@ -24,14 +144,14 @@ class InferenceModel { class EmbeddingModel { llm; config; - + MIN_DIMENSIONALITY = 64; constructor(llmodel, config) { this.llm = llmodel; this.config = config; } - embed(text) { - return this.llm.embed(text) + embed(text, prefix, dimensionality, do_mean, atlas) { + return this.llm.embed(text, prefix, dimensionality, do_mean, atlas); } dispose() { @@ -39,7 +159,6 @@ class EmbeddingModel { } } - module.exports = { InferenceModel, EmbeddingModel, diff --git a/gpt4all-bindings/typescript/src/util.js b/gpt4all-bindings/typescript/src/util.js index a0923f93..b9c9979b 100644 --- a/gpt4all-bindings/typescript/src/util.js +++ b/gpt4all-bindings/typescript/src/util.js @@ -1,8 +1,7 @@ -const { createWriteStream, existsSync, statSync } = require("node:fs"); +const { createWriteStream, existsSync, statSync, mkdirSync } = require("node:fs"); const fsp = require("node:fs/promises"); const { performance } = require("node:perf_hooks"); const path = require("node:path"); -const { mkdirp } = require("mkdirp"); const md5File = require("md5-file"); const { DEFAULT_DIRECTORY, 
@@ -50,6 +49,63 @@ function appendBinSuffixIfMissing(name) { return name; } +function prepareMessagesForIngest(messages) { + const systemMessages = messages.filter( + (message) => message.role === "system" + ); + if (systemMessages.length > 0) { + console.warn( + "System messages are currently not supported and will be ignored. Use the systemPrompt option instead." + ); + } + + const userAssistantMessages = messages.filter( + (message) => message.role !== "system" + ); + + // make sure the first message is a user message + // if its not, the turns will be out of order + if (userAssistantMessages[0].role !== "user") { + userAssistantMessages.unshift({ + role: "user", + content: "", + }); + } + + // create turns of user input + assistant reply + const turns = []; + let userMessage = null; + let assistantMessage = null; + + for (const message of userAssistantMessages) { + // consecutive messages of the same role are concatenated into one message + if (message.role === "user") { + if (!userMessage) { + userMessage = message.content; + } else { + userMessage += "\n" + message.content; + } + } else if (message.role === "assistant") { + if (!assistantMessage) { + assistantMessage = message.content; + } else { + assistantMessage += "\n" + message.content; + } + } + + if (userMessage && assistantMessage) { + turns.push({ + user: userMessage, + assistant: assistantMessage, + }); + userMessage = null; + assistantMessage = null; + } + } + + return turns; +} + // readChunks() reads from the provided reader and yields the results into an async iterable // https://css-tricks.com/web-streams-everywhere-and-fetch-for-node-js/ function readChunks(reader) { @@ -64,49 +120,13 @@ function readChunks(reader) { }; } -/** - * Prints a warning if any keys in the prompt context are snake_case. - */ -function warnOnSnakeCaseKeys(promptContext) { - const snakeCaseKeys = Object.keys(promptContext).filter((key) => - key.includes("_") - ); - - if (snakeCaseKeys.length > 0) { - console.warn( - "Prompt context keys should be camelCase. Support for snake_case might be removed in the future. Found keys: " + - snakeCaseKeys.join(", ") - ); - } -} - -/** - * Converts all keys in the prompt context to snake_case - * For duplicate definitions, the value of the last occurrence will be used. - */ -function normalizePromptContext(promptContext) { - const normalizedPromptContext = {}; - - for (const key in promptContext) { - if (promptContext.hasOwnProperty(key)) { - const snakeKey = key.replace( - /[A-Z]/g, - (match) => `_${match.toLowerCase()}` - ); - normalizedPromptContext[snakeKey] = promptContext[key]; - } - } - - return normalizedPromptContext; -} - function downloadModel(modelName, options = {}) { const downloadOptions = { modelPath: DEFAULT_DIRECTORY, verbose: false, ...options, }; - + const modelFileName = appendBinSuffixIfMissing(modelName); const partialModelPath = path.join( downloadOptions.modelPath, @@ -114,16 +134,17 @@ function downloadModel(modelName, options = {}) { ); const finalModelPath = path.join(downloadOptions.modelPath, modelFileName); const modelUrl = - downloadOptions.url ?? `https://gpt4all.io/models/gguf/${modelFileName}`; + downloadOptions.url ?? 
+ `https://gpt4all.io/models/gguf/${modelFileName}`; - mkdirp.sync(downloadOptions.modelPath) + mkdirSync(downloadOptions.modelPath, { recursive: true }); if (existsSync(finalModelPath)) { throw Error(`Model already exists at ${finalModelPath}`); } - + if (downloadOptions.verbose) { - console.log(`Downloading ${modelName} from ${modelUrl}`); + console.debug(`Downloading ${modelName} from ${modelUrl}`); } const headers = { @@ -134,7 +155,9 @@ function downloadModel(modelName, options = {}) { const writeStreamOpts = {}; if (existsSync(partialModelPath)) { - console.log("Partial model exists, resuming download..."); + if (downloadOptions.verbose) { + console.debug("Partial model exists, resuming download..."); + } const startRange = statSync(partialModelPath).size; headers["Range"] = `bytes=${startRange}-`; writeStreamOpts.flags = "a"; @@ -144,15 +167,15 @@ function downloadModel(modelName, options = {}) { const signal = abortController.signal; const finalizeDownload = async () => { - if (options.md5sum) { + if (downloadOptions.md5sum) { const fileHash = await md5File(partialModelPath); - if (fileHash !== options.md5sum) { + if (fileHash !== downloadOptions.md5sum) { await fsp.unlink(partialModelPath); - const message = `Model "${modelName}" failed verification: Hashes mismatch. Expected ${options.md5sum}, got ${fileHash}`; + const message = `Model "${modelName}" failed verification: Hashes mismatch. Expected ${downloadOptions.md5sum}, got ${fileHash}`; throw Error(message); } - if (options.verbose) { - console.log(`MD5 hash verified: ${fileHash}`); + if (downloadOptions.verbose) { + console.debug(`MD5 hash verified: ${fileHash}`); } } @@ -163,8 +186,8 @@ function downloadModel(modelName, options = {}) { const downloadPromise = new Promise((resolve, reject) => { let timestampStart; - if (options.verbose) { - console.log(`Downloading @ ${partialModelPath} ...`); + if (downloadOptions.verbose) { + console.debug(`Downloading @ ${partialModelPath} ...`); timestampStart = performance.now(); } @@ -179,7 +202,7 @@ function downloadModel(modelName, options = {}) { }); writeStream.on("finish", () => { - if (options.verbose) { + if (downloadOptions.verbose) { const elapsed = performance.now() - timestampStart; console.log(`Finished. 
Download took ${elapsed.toFixed(2)} ms`); } @@ -221,10 +244,10 @@ async function retrieveModel(modelName, options = {}) { const retrieveOptions = { modelPath: DEFAULT_DIRECTORY, allowDownload: true, - verbose: true, + verbose: false, ...options, }; - await mkdirp(retrieveOptions.modelPath); + mkdirSync(retrieveOptions.modelPath, { recursive: true }); const modelFileName = appendBinSuffixIfMissing(modelName); const fullModelPath = path.join(retrieveOptions.modelPath, modelFileName); @@ -236,7 +259,7 @@ async function retrieveModel(modelName, options = {}) { file: retrieveOptions.modelConfigFile, url: retrieveOptions.allowDownload && - "https://gpt4all.io/models/models2.json", + "https://gpt4all.io/models/models3.json", }); const loadedModelConfig = availableModels.find( @@ -262,10 +285,9 @@ async function retrieveModel(modelName, options = {}) { config.path = fullModelPath; if (retrieveOptions.verbose) { - console.log(`Found ${modelName} at ${fullModelPath}`); + console.debug(`Found ${modelName} at ${fullModelPath}`); } } else if (retrieveOptions.allowDownload) { - const downloadController = downloadModel(modelName, { modelPath: retrieveOptions.modelPath, verbose: retrieveOptions.verbose, @@ -278,7 +300,7 @@ async function retrieveModel(modelName, options = {}) { config.path = downloadPath; if (retrieveOptions.verbose) { - console.log(`Model downloaded to ${downloadPath}`); + console.debug(`Model downloaded to ${downloadPath}`); } } else { throw Error("Failed to retrieve model."); @@ -288,9 +310,8 @@ async function retrieveModel(modelName, options = {}) { module.exports = { appendBinSuffixIfMissing, + prepareMessagesForIngest, downloadModel, retrieveModel, listModels, - normalizePromptContext, - warnOnSnakeCaseKeys, }; diff --git a/gpt4all-bindings/typescript/test/gpt4all.test.js b/gpt4all-bindings/typescript/test/gpt4all.test.js index f60efdb4..6d987a3f 100644 --- a/gpt4all-bindings/typescript/test/gpt4all.test.js +++ b/gpt4all-bindings/typescript/test/gpt4all.test.js @@ -7,7 +7,6 @@ const { listModels, downloadModel, appendBinSuffixIfMissing, - normalizePromptContext, } = require("../src/util.js"); const { DEFAULT_DIRECTORY, @@ -19,8 +18,6 @@ const { createPrompt, createCompletion, } = require("../src/gpt4all.js"); -const { mock } = require("node:test"); -const { mkdirp } = require("mkdirp"); describe("config", () => { test("default paths constants are available and correct", () => { @@ -87,7 +84,7 @@ describe("listModels", () => { expect(fetch).toHaveBeenCalledTimes(0); expect(models[0]).toEqual(fakeModel); }); - + it("should throw an error if neither url nor file is specified", async () => { await expect(listModels(null)).rejects.toThrow( "No model list source specified. Please specify either a url or a file." 
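A small, hedged sketch of `listModels` and `retrieveModel` as used by the implementation above; the model list URL matches the one in `retrieveModel`, and the local paths are placeholders:

```js
import { listModels, retrieveModel } from "../src/gpt4all.js";

// Fetch the official model list (or pass { file: "./models3.json" } for a local copy).
const models = await listModels({
    url: "https://gpt4all.io/models/models3.json",
});
console.log(models[0]);

// Resolve the config and download the file if needed, without loading the model yet.
const config = await retrieveModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
    modelPath: "/data/models", // placeholder
    allowDownload: true,
    verbose: true,
});
console.log(config.path);
```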
@@ -141,10 +138,10 @@ describe("downloadModel", () => { mockAbortController.mockReset(); mockFetch.mockClear(); global.fetch.mockRestore(); - + const rootDefaultPath = path.resolve(DEFAULT_DIRECTORY), partialPath = path.resolve(rootDefaultPath, fakeModelName+'.part'), - fullPath = path.resolve(rootDefaultPath, fakeModelName+'.bin') + fullPath = path.resolve(rootDefaultPath, fakeModelName+'.bin') //if tests fail, remove the created files // acts as cleanup if tests fail @@ -206,46 +203,3 @@ describe("downloadModel", () => { // test("should be able to cancel and resume a download", async () => { // }); }); - -describe("normalizePromptContext", () => { - it("should convert a dict with camelCased keys to snake_case", () => { - const camelCased = { - topK: 20, - repeatLastN: 10, - }; - - const expectedSnakeCased = { - top_k: 20, - repeat_last_n: 10, - }; - - const result = normalizePromptContext(camelCased); - expect(result).toEqual(expectedSnakeCased); - }); - - it("should convert a mixed case dict to snake_case, last value taking precedence", () => { - const mixedCased = { - topK: 20, - top_k: 10, - repeatLastN: 10, - }; - - const expectedSnakeCased = { - top_k: 10, - repeat_last_n: 10, - }; - - const result = normalizePromptContext(mixedCased); - expect(result).toEqual(expectedSnakeCased); - }); - - it("should not modify already snake cased dict", () => { - const snakeCased = { - top_k: 10, - repeast_last_n: 10, - }; - - const result = normalizePromptContext(snakeCased); - expect(result).toEqual(snakeCased); - }); -}); diff --git a/gpt4all-bindings/typescript/yarn.lock b/gpt4all-bindings/typescript/yarn.lock index 251b5398..f760a3b2 100644 --- a/gpt4all-bindings/typescript/yarn.lock +++ b/gpt4all-bindings/typescript/yarn.lock @@ -2300,7 +2300,6 @@ __metadata: documentation: ^14.0.2 jest: ^29.5.0 md5-file: ^5.0.0 - mkdirp: ^3.0.1 node-addon-api: ^6.1.0 node-gyp: 9.x.x node-gyp-build: ^4.6.0 @@ -4258,15 +4257,6 @@ __metadata: languageName: node linkType: hard -"mkdirp@npm:^3.0.1": - version: 3.0.1 - resolution: "mkdirp@npm:3.0.1" - bin: - mkdirp: dist/cjs/src/bin.js - checksum: 972deb188e8fb55547f1e58d66bd6b4a3623bf0c7137802582602d73e6480c1c2268dcbafbfb1be466e00cc7e56ac514d7fd9334b7cf33e3e2ab547c16f83a8d - languageName: node - linkType: hard - "mri@npm:^1.1.0": version: 1.2.0 resolution: "mri@npm:1.2.0"
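Finally, a hedged sketch of the `createCompletionGenerator` variant declared in the typings above; the model file name is a placeholder and the empty options object is passed only to satisfy the declared signature:

```js
import { loadModel, createCompletionGenerator } from "../src/gpt4all.js";

const model = await loadModel("mistral-7b-openorca.gguf2.Q4_0.gguf", {
    device: "gpu",
});
const chat = await model.createChatSession();

// Yields each generated token; the generator's return value is the CompletionResult.
const gen = createCompletionGenerator(chat, "Write a haiku about streams.", {});
for await (const token of gen) {
    process.stdout.write(token);
}
process.stdout.write("\n");

model.dispose();
```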