Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-07 11:30:05 +00:00)
Implement configurable context length (#1749)
@@ -188,7 +188,7 @@ public class LLModel : ILLModel
     /// <returns>true if the model was loaded successfully, false otherwise.</returns>
     public bool Load(string modelPath)
     {
-        return NativeMethods.llmodel_loadModel(_handle, modelPath);
+        return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048);
     }

     protected void Destroy()
@@ -70,7 +70,8 @@ internal static unsafe partial class NativeMethods
     [return: MarshalAs(UnmanagedType.I1)]
     public static extern bool llmodel_loadModel(
         [NativeTypeName("llmodel_model")] IntPtr model,
-        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
+        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
+        [NativeTypeName("int32_t")] int n_ctx);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
@@ -39,7 +39,7 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
         var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
         _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
         _logger.LogInformation("Model loading started");
-        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
         _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
         if (!loadedSuccessfully)
         {
@@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) {
         fprintf(stderr, "%s: error '%s'\n", __func__, new_error);
         return nullptr;
     }
-    if (!llmodel_loadModel(model, fname)) {
+    if (!llmodel_loadModel(model, fname, 2048)) {
         llmodel_model_destroy(model);
         return nullptr;
     }
@@ -195,7 +195,7 @@ public class LLModel implements AutoCloseable {
         if(model == null) {
             throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0));
         }
-        library.llmodel_loadModel(model, modelPathAbs);
+        library.llmodel_loadModel(model, modelPathAbs, 2048);

         if(!library.llmodel_isModelLoaded(model)){
             throw new IllegalStateException("The model " + modelName + " could not be loaded");
@@ -61,7 +61,7 @@ public interface LLModelLibrary {

    Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error);
    void llmodel_model_destroy(Pointer model);
-   boolean llmodel_loadModel(Pointer model, String model_path);
+   boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx);
    boolean llmodel_isModelLoaded(Pointer model);
    @u_int64_t long llmodel_get_state_size(Pointer model);
    @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
@@ -1,2 +1,2 @@
-from .gpt4all import Embed4All, GPT4All # noqa
-from .pyllmodel import LLModel # noqa
+from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
+from .pyllmodel import LLModel as LLModel
@@ -69,6 +69,7 @@ class GPT4All:
         allow_download: bool = True,
         n_threads: Optional[int] = None,
         device: Optional[str] = "cpu",
+        n_ctx: int = 2048,
         verbose: bool = False,
     ):
         """
@@ -90,15 +91,16 @@ class GPT4All:
                 Default is "cpu".

                 Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
+            n_ctx: Maximum size of context window
             verbose: If True, print debug messages.
         """
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
         self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
-        if device is not None:
-            if device != "cpu":
-                self.model.init_gpu(model_path=self.config["path"], device=device)
-        self.model.load_model(self.config["path"])
+        if device is not None and device != "cpu":
+            self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx)
+        self.model.load_model(self.config["path"], n_ctx)
         # Set n_threads
         if n_threads is not None:
             self.model.set_thread_count(n_threads)
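For illustration, a minimal usage sketch of the new constructor parameter; the model file name below is an assumption for the example, not part of this change:

    from gpt4all import GPT4All

    # Ask for a 4096-token context window instead of the 2048 default.
    model = GPT4All("orca-mini-3b-gguf2-q4_0.gguf", device="cpu", n_ctx=4096)
    print(model.generate("Summarize this change in one sentence.", max_tokens=64))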
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import ctypes
 import importlib.resources
 import logging
@@ -7,6 +9,7 @@ import re
 import subprocess
 import sys
 import threading
+from enum import Enum
 from queue import Queue
 from typing import Callable, Iterable, List

@@ -72,9 +75,9 @@ llmodel.llmodel_model_create2.restype = ctypes.c_void_p
 llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_model_destroy.restype = None

-llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
-llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
 llmodel.llmodel_required_mem.restype = ctypes.c_size_t
 llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
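A standalone ctypes sketch of the updated native signatures, assuming the backend library can be loaded by name (the real binding resolves it from its packaged search path):

    import ctypes

    lib = ctypes.CDLL("libllmodel.so")  # assumed name; use .dylib/.dll on other platforms

    # The trailing c_int carries the new int32_t n_ctx argument.
    lib.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
    lib.llmodel_loadModel.restype = ctypes.c_bool
    lib.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
    lib.llmodel_required_mem.restype = ctypes.c_size_t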
@@ -114,7 +117,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
 llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_threadCount.restype = ctypes.c_int32

-llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode("utf-8"))
+llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())

 llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
 llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
@@ -143,10 +146,16 @@ def _create_model(model_path: bytes) -> ctypes.c_void_p:
     err = ctypes.c_char_p()
     model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err))
     if model is None:
-        raise ValueError(f"Unable to instantiate model: {err.decode()}")
+        s = err.value
+        raise ValueError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
     return model


+# Symbol to terminate from generator
+class Sentinel(Enum):
+    TERMINATING_SYMBOL = 0
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -173,12 +182,16 @@ class LLModel:
         if self.model is not None:
             self.llmodel_lib.llmodel_model_destroy(self.model)

-    def memory_needed(self, model_path: str) -> int:
-        model_path_enc = model_path.encode("utf-8")
-        self.model = _create_model(model_path_enc)
-        return llmodel.llmodel_required_mem(self.model, model_path_enc)
+    def memory_needed(self, model_path: str, n_ctx: int) -> int:
+        self.model = None
+        return self._memory_needed(model_path, n_ctx)

-    def list_gpu(self, model_path: str) -> list:
+    def _memory_needed(self, model_path: str, n_ctx: int) -> int:
+        if self.model is None:
+            self.model = _create_model(model_path.encode())
+        return llmodel.llmodel_required_mem(self.model, model_path.encode(), n_ctx)
+
+    def list_gpu(self, model_path: str, n_ctx: int) -> list[LLModelGPUDevice]:
         """
         Lists available GPU devices that satisfy the model's memory requirements.
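A short sketch of how the reworked memory helper might be called; the model path is a placeholder, and the comparison only illustrates that a larger context window raises the reported requirement:

    from gpt4all import pyllmodel

    llm = pyllmodel.LLModel()
    # memory_needed() now forwards n_ctx, since the KV cache grows with the context size.
    mem_2k = llm.memory_needed("/path/to/model.gguf", n_ctx=2048)
    mem_8k = llm.memory_needed("/path/to/model.gguf", n_ctx=8192)
    print(f"required bytes: {mem_2k} at 2k ctx, {mem_8k} at 8k ctx")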
@@ -186,45 +199,41 @@ class LLModel:
         ----------
         model_path : str
             Path to the model.
+        n_ctx : int
+            Maximum size of context window

         Returns
         -------
         list
             A list of LLModelGPUDevice structures representing available GPU devices.
         """
-        if self.model is not None:
-            model_path_enc = model_path.encode("utf-8")
-            mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
-        else:
-            mem_required = self.memory_needed(model_path)
+        mem_required = self._memory_needed(model_path, n_ctx)
+        return self._list_gpu(mem_required)
+
+    def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
         num_devices = ctypes.c_int32(0)
         devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
         if not devices_ptr:
             raise ValueError("Unable to retrieve available GPU devices")
-        devices = [devices_ptr[i] for i in range(num_devices.value)]
-        return devices
+        return devices_ptr[:num_devices.value]

-    def init_gpu(self, model_path: str, device: str):
-        if self.model is not None:
-            model_path_enc = model_path.encode("utf-8")
-            mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
-        else:
-            mem_required = self.memory_needed(model_path)
-        device_enc = device.encode("utf-8")
-        success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device_enc)
+    def init_gpu(self, model_path: str, device: str, n_ctx: int):
+        mem_required = self._memory_needed(model_path, n_ctx)
+
+        success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode())
         if not success:
             # Retrieve all GPUs without considering memory requirements.
             num_devices = ctypes.c_int32(0)
             all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices))
             if not all_devices_ptr:
                 raise ValueError("Unable to retrieve list of all GPU devices")
-            all_gpus = [all_devices_ptr[i].name.decode('utf-8') for i in range(num_devices.value)]
+            all_gpus = [d.name.decode() for d in all_devices_ptr[:num_devices.value]]

             # Retrieve GPUs that meet the memory requirements using list_gpu
-            available_gpus = [device.name.decode('utf-8') for device in self.list_gpu(model_path)]
+            available_gpus = [device.name.decode() for device in self._list_gpu(mem_required)]

             # Identify GPUs that are unavailable due to insufficient memory or features
-            unavailable_gpus = set(all_gpus) - set(available_gpus)
+            unavailable_gpus = set(all_gpus).difference(available_gpus)

             # Formulate the error message
             error_msg = "Unable to initialize model on GPU: '{}'.".format(device)
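Because init_gpu() still raises ValueError when no device has enough memory for the requested context, a caller can fall back to CPU; a hedged sketch at the public-API level (the model name is illustrative):

    from gpt4all import GPT4All

    try:
        model = GPT4All("orca-mini-3b-gguf2-q4_0.gguf", device="gpu", n_ctx=8192)
    except ValueError:
        # The error message lists GPUs lacking memory or features; retry on CPU.
        model = GPT4All("orca-mini-3b-gguf2-q4_0.gguf", device="cpu", n_ctx=8192)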
@@ -232,7 +241,7 @@ class LLModel:
             error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus)
             raise ValueError(error_msg)

-    def load_model(self, model_path: str) -> bool:
+    def load_model(self, model_path: str, n_ctx: int) -> bool:
         """
         Load model from a file.

@@ -240,15 +249,16 @@ class LLModel:
         ----------
         model_path : str
             Model filepath
+        n_ctx : int
+            Maximum size of context window

         Returns
         -------
         True if model loaded successfully, False otherwise
         """
-        model_path_enc = model_path.encode("utf-8")
-        self.model = _create_model(model_path_enc)
+        self.model = _create_model(model_path.encode())

-        llmodel.llmodel_loadModel(self.model, model_path_enc)
+        llmodel.llmodel_loadModel(self.model, model_path.encode(), n_ctx)

         filename = os.path.basename(model_path)
         self.model_name = os.path.splitext(filename)[0]
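At the pyllmodel level the new argument is passed explicitly; a minimal sketch, assuming a local model file at a placeholder path and relying on the documented True/False return value:

    from gpt4all import pyllmodel

    llm = pyllmodel.LLModel()
    if not llm.load_model("/path/to/model.gguf", n_ctx=4096):
        raise RuntimeError("backend failed to load the model")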
@@ -312,7 +322,7 @@ class LLModel:
             raise ValueError("Text must not be None or empty")

         embedding_size = ctypes.c_size_t()
-        c_text = ctypes.c_char_p(text.encode('utf-8'))
+        c_text = ctypes.c_char_p(text.encode())
         embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
         embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
         llmodel.llmodel_free_embedding(embedding_ptr)
@@ -357,7 +367,7 @@ class LLModel:
             prompt,
         )

-        prompt_bytes = prompt.encode("utf-8")
+        prompt_bytes = prompt.encode()
         prompt_ptr = ctypes.c_char_p(prompt_bytes)

         self._set_context(
@@ -385,10 +395,7 @@ class LLModel:
     def prompt_model_streaming(
         self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
-        # Symbol to terminate from generator
-        TERMINATING_SYMBOL = object()
-
-        output_queue: Queue = Queue()
+        output_queue: Queue[str | Sentinel] = Queue()

         # Put response tokens into an output queue
         def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
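The Queue[str | Sentinel] annotation reflects a producer/consumer hand-off where a Sentinel member marks end-of-stream; a self-contained sketch of the pattern with a stand-in worker (not the binding's own prompt loop):

    from __future__ import annotations

    import threading
    from enum import Enum
    from queue import Queue

    class Sentinel(Enum):
        TERMINATING_SYMBOL = 0

    def stream_words(text: str):
        output_queue: Queue[str | Sentinel] = Queue()

        def worker():
            for word in text.split():
                output_queue.put(word)  # stand-in for the model's token callback
            output_queue.put(Sentinel.TERMINATING_SYMBOL)

        threading.Thread(target=worker, daemon=True).start()
        while True:
            item = output_queue.get()
            if isinstance(item, Sentinel):  # end-of-stream marker, never yielded
                break
            yield item

    print(list(stream_words("tokens arrive until the sentinel")))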
@@ -405,7 +412,7 @@ class LLModel:

         def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
             self.prompt_model(prompt, callback, **kwargs)
-            output_queue.put(TERMINATING_SYMBOL)
+            output_queue.put(Sentinel.TERMINATING_SYMBOL)

         # Kick off llmodel_prompt in separate thread so we can return generator
         # immediately
@@ -419,7 +426,7 @@ class LLModel:
         # Generator
         while True:
             response = output_queue.get()
-            if response is TERMINATING_SYMBOL:
+            if isinstance(response, Sentinel):
                 break
             yield response

@@ -442,7 +449,7 @@ class LLModel:
             else:
                 # beginning of a byte sequence
                 if len(self.buffer) > 0:
-                    decoded.append(self.buffer.decode('utf-8', 'replace'))
+                    decoded.append(self.buffer.decode(errors='replace'))

                     self.buffer.clear()

@@ -451,7 +458,7 @@ class LLModel:

             if self.buff_expecting_cont_bytes <= 0:
                 # received the whole sequence or an out of place continuation byte
-                decoded.append(self.buffer.decode('utf-8', 'replace'))
+                decoded.append(self.buffer.decode(errors='replace'))

                 self.buffer.clear()
                 self.buff_expecting_cont_bytes = 0
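Note that bytes.decode() defaults to UTF-8, so dropping the explicit 'utf-8' argument is behavior-preserving; a tiny standalone sketch of the split multi-byte sequence this buffering exists to handle:

    data = "café".encode()           # b'caf\xc3\xa9'; the é is a two-byte sequence
    chunk_a, chunk_b = data[:4], data[4:]

    buffer = bytearray()
    buffer += chunk_a                # ends mid-sequence; decoding now would give U+FFFD
    buffer += chunk_b                # continuation byte completes the character
    print(buffer.decode(errors="replace"))  # prints: café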
@@ -117,7 +117,7 @@ def test_empty_embedding():
 def test_download_model(tmp_path: Path):
     import gpt4all.gpt4all
     old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
-    gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path  # temporary pytest directory to ensure a download happens
+    gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path)  # temporary pytest directory to ensure a download happens
     try:
         model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
         model_path = tmp_path / model.config['filename']
@@ -28,7 +28,7 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
 Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 {
     auto env = info.Env();
-    return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str()) ));
+    return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048) ));

 }
 Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
@@ -161,7 +161,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
         }
     }

-    auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str());
+    auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048);
     if(!success) {
         Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
         return;