mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-19 12:14:20 +00:00)

commit 3313c7de0d (parent dddaf49428)

python: implement close() and context manager interface (#2177)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,7 +9,7 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Generic, Iterable, TypeVar, overload
+from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -200,13 +200,22 @@ class LLModel:
         if model is None:
             s = err.value
             raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
-        self.model = model
+        self.model: ctypes.c_void_p | None = model
 
     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
+            self.close()
+
+    def close(self) -> None:
+        if self.model is not None:
             llmodel.llmodel_model_destroy(self.model)
+            self.model = None
+
+    def _raise_closed(self) -> NoReturn:
+        raise ValueError("Attempted operation on a closed LLModel")
 
     def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
+        assert self.model is not None
         num_devices = ctypes.c_int32(0)
         devices_ptr = llmodel.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
         if not devices_ptr:
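
The core of the change: __del__ now delegates to close(), which frees the C model handle and nulls it out, so an explicit close() followed by garbage collection cannot double-free. A minimal, self-contained sketch of this idempotent-close pattern; _acquire_handle and _free_handle below are hypothetical stand-ins for the llmodel C calls, not names from the diff:

    def _acquire_handle() -> object:
        return object()  # stands in for the C constructor

    def _free_handle(handle: object) -> None:
        pass             # stands in for llmodel_model_destroy

    class Resource:
        def __init__(self) -> None:
            self.handle = _acquire_handle()

        def __del__(self) -> None:
            # __init__ can raise before the attribute is assigned,
            # hence the hasattr guard
            if hasattr(self, 'handle'):
                self.close()

        def close(self) -> None:
            if self.handle is not None:
                _free_handle(self.handle)
                self.handle = None  # second close() or a later __del__ is a no-op

    r = Resource()
    r.close()
    r.close()  # safe: the handle was already freed and cleared

The hasattr check kept in __del__ matters for the same reason here as in the diff: the constructor raises RuntimeError before self.model is ever assigned when model creation fails.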
@@ -214,6 +223,9 @@ class LLModel:
         return devices_ptr[:num_devices.value]
 
     def init_gpu(self, device: str):
+        if self.model is None:
+            self._raise_closed()
+
         mem_required = llmodel.llmodel_required_mem(self.model, self.model_path, self.n_ctx, self.ngl)
 
         if llmodel.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()):
@@ -246,14 +258,21 @@ class LLModel:
         -------
         True if model loaded successfully, False otherwise
         """
+        if self.model is None:
+            self._raise_closed()
+
         return llmodel.llmodel_loadModel(self.model, self.model_path, self.n_ctx, self.ngl)
 
     def set_thread_count(self, n_threads):
+        if self.model is None:
+            self._raise_closed()
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
         llmodel.llmodel_setThreadCount(self.model, n_threads)
 
     def thread_count(self):
+        if self.model is None:
+            self._raise_closed()
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
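
Every public method now checks the handle before touching it. Annotating _raise_closed as NoReturn does double duty: callers get a consistent ValueError, and type checkers treat the guard branch as terminal, narrowing self.model to non-None afterwards. A small illustrative class (not from the diff) showing the same effect:

    from __future__ import annotations

    from typing import NoReturn

    class Closable:
        def __init__(self) -> None:
            self.model: int | None = 42  # stand-in for the ctypes handle

        def _raise_closed(self) -> NoReturn:
            raise ValueError("Attempted operation on a closed LLModel")

        def use(self) -> int:
            if self.model is None:
                self._raise_closed()
            # mypy accepts this: the NoReturn call ends the None branch,
            # so self.model is narrowed to int here
            return self.model + 1

    c = Closable()
    print(c.use())    # 43
    c.model = None    # simulate close()
    try:
        c.use()
    except ValueError as e:
        print(e)      # Attempted operation on a closed LLModel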
@@ -322,6 +341,9 @@ class LLModel:
         if not text:
             raise ValueError("text must not be None or empty")
 
+        if self.model is None:
+            self._raise_closed()
+
         if (single_text := isinstance(text, str)):
             text = [text]
 
@@ -387,6 +409,9 @@ class LLModel:
         None
         """
 
+        if self.model is None:
+            self._raise_closed()
+
         self.buffer.clear()
         self.buff_expecting_cont_bytes = 0
 
@@ -419,6 +444,9 @@ class LLModel:
     def prompt_model_streaming(
         self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
+        if self.model is None:
+            self._raise_closed()
+
         output_queue: Queue[str | Sentinel] = Queue()
 
         # Put response tokens into an output queue
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -11,6 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
+from types import TracebackType
 from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
 
 import requests
@@ -22,7 +23,7 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult
 
 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
+    from typing_extensions import Self, TypeAlias
 
 if sys.platform == 'darwin':
     import fcntl
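
Self is imported only under TYPE_CHECKING, so typing_extensions stays a checker-only dependency while __enter__ can still be annotated to return the subclass-aware Self type. A sketch of that pattern with an illustrative class name (not from the diff):

    from __future__ import annotations  # annotations are never evaluated at runtime

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from typing_extensions import Self  # seen only by type checkers

    class Session:
        def __enter__(self) -> Self:  # subclasses get their own type back
            return self

        def __exit__(self, *exc: object) -> None:
            pass

    with Session() as s:  # checkers infer the concrete type of s
        pass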
@@ -54,6 +55,18 @@ class Embed4All:
             model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
         self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)
 
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Delete the model instance and free associated system resources."""
+        self.gpt4all.close()
+
     # return_dict=False
     @overload
     def embed(
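
With __enter__/__exit__ in place, Embed4All can release the native model deterministically. A usage sketch, assuming the default embedding model shown above is available (it is downloaded on first use):

    from gpt4all import Embed4All

    # close() runs on exit even if embed() raises
    with Embed4All() as embedder:
        vector = embedder.embed("The quick brown fox jumps over the lazy dog")
    print(len(vector))  # embedding dimensionality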
@@ -190,6 +203,18 @@ class GPT4All:
         self._history: list[MessageType] | None = None
         self._current_prompt_template: str = "{0}"
 
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Delete the model instance and free associated system resources."""
+        self.model.close()
+
     @property
     def current_chat_session(self) -> list[MessageType] | None:
         return None if self._history is None else list(self._history)
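
GPT4All gets the same interface, with close() forwarding to LLModel.close(). A usage sketch; the model file name is illustrative, any model from the GPT4All catalog works:

    from gpt4all import GPT4All

    with GPT4All("orca-mini-3b-gguf2-q4_0.gguf") as model:
        print(model.generate("The capital of France is", max_tokens=16))

    # After the block, __exit__ -> close() has freed the native model; a further
    # call such as model.generate(...) hits the guards added above and raises
    # ValueError("Attempted operation on a closed LLModel").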
diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py
@@ -68,7 +68,7 @@ def get_long_description():
 
 setup(
     name=package_name,
-    version="2.3.2",
+    version="2.3.3",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",