python: implement close() and context manager interface (#2177)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel
Date:   2024-03-28 16:48:07 -04:00 (committed by GitHub)
Commit: 3313c7de0d
Parent: dddaf49428
3 changed files with 57 additions and 4 deletions

gpt4all-bindings/python/gpt4all/_pyllmodel.py

@@ -9,7 +9,7 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Generic, Iterable, TypeVar, overload
+from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload

 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -200,13 +200,22 @@ class LLModel:
         if model is None:
             s = err.value
             raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
-        self.model = model
+        self.model: ctypes.c_void_p | None = model

     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
-            llmodel.llmodel_model_destroy(self.model)
+            self.close()
+
+    def close(self) -> None:
+        if self.model is not None:
+            llmodel.llmodel_model_destroy(self.model)
+            self.model = None
+
+    def _raise_closed(self) -> NoReturn:
+        raise ValueError("Attempted operation on a closed LLModel")

     def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
+        assert self.model is not None
         num_devices = ctypes.c_int32(0)
         devices_ptr = llmodel.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
         if not devices_ptr:
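The new close() is idempotent: the first call destroys the native handle and sets self.model to None, so any later call, including the one __del__ makes during garbage collection, is a no-op. A minimal runnable sketch of the same pattern, with a stub standing in for llmodel.llmodel_model_destroy:

    def destroy(ptr):
        # Stub standing in for llmodel.llmodel_model_destroy; illustration only.
        print(f"destroying native handle {ptr:#x}")

    class Handle:
        def __init__(self, ptr):
            self.ptr = ptr  # native resource, or None once released

        def close(self):
            if self.ptr is not None:
                destroy(self.ptr)  # release the resource exactly once
                self.ptr = None    # repeated close() calls become no-ops

        def __del__(self):
            self.close()  # finalizer and explicit close share one code path

    h = Handle(0xdead)
    h.close()  # prints once
    h.close()  # no-op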
@@ -214,6 +223,9 @@ class LLModel:
         return devices_ptr[:num_devices.value]

     def init_gpu(self, device: str):
+        if self.model is None:
+            self._raise_closed()
+
         mem_required = llmodel.llmodel_required_mem(self.model, self.model_path, self.n_ctx, self.ngl)

         if llmodel.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()):
@@ -246,14 +258,21 @@ class LLModel:
         -------
         True if model loaded successfully, False otherwise
         """
+        if self.model is None:
+            self._raise_closed()
+
         return llmodel.llmodel_loadModel(self.model, self.model_path, self.n_ctx, self.ngl)

     def set_thread_count(self, n_threads):
+        if self.model is None:
+            self._raise_closed()
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
         llmodel.llmodel_setThreadCount(self.model, n_threads)

     def thread_count(self):
+        if self.model is None:
+            self._raise_closed()
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
@@ -322,6 +341,9 @@ class LLModel:
         if not text:
             raise ValueError("text must not be None or empty")
+        if self.model is None:
+            self._raise_closed()
+
         if (single_text := isinstance(text, str)):
             text = [text]
@@ -387,6 +409,9 @@ class LLModel:
         None
         """
+        if self.model is None:
+            self._raise_closed()
+
         self.buffer.clear()
         self.buff_expecting_cont_bytes = 0
@@ -419,6 +444,9 @@ class LLModel:
     def prompt_model_streaming(
         self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
+        if self.model is None:
+            self._raise_closed()
+
         output_queue: Queue[str | Sentinel] = Queue()

         # Put response tokens into an output queue
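With these guards in place, every public entry point that touches the C handle checks for None first, so calls on a closed LLModel fail fast with a ValueError instead of handing a null pointer to the native library. Roughly:

    # Hypothetical usage; LLModel construction is elided because its
    # arguments differ across versions of the bindings.
    llmodel = LLModel(...)
    llmodel.close()
    llmodel.close()          # safe: the second close() is a no-op
    llmodel.thread_count()   # ValueError: Attempted operation on a closed LLModel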

gpt4all-bindings/python/gpt4all/gpt4all.py

@@ -11,6 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
+from types import TracebackType
 from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload

 import requests
@@ -22,7 +23,7 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult

 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
+    from typing_extensions import Self, TypeAlias

 if sys.platform == 'darwin':
     import fcntl
@@ -54,6 +55,18 @@ class Embed4All:
             model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
         self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)

+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Delete the model instance and free associated system resources."""
+        self.gpt4all.close()
+
     # return_dict=False
     @overload
     def embed(
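With __enter__ and __exit__ in place, Embed4All can be used in a with statement, and the underlying model is freed deterministically when the block exits, on error or not. A usage sketch, relying on the default embedding model named in the constructor above:

    from gpt4all import Embed4All

    with Embed4All() as embedder:
        vector = embedder.embed("The quick brown fox jumps over the lazy dog")
    print(len(vector))  # dimensionality of the embedding
    # embedder.close() has already run here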
@@ -190,6 +203,18 @@ class GPT4All:
         self._history: list[MessageType] | None = None
         self._current_prompt_template: str = "{0}"

+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Delete the model instance and free associated system resources."""
+        self.model.close()
+
     @property
     def current_chat_session(self) -> list[MessageType] | None:
         return None if self._history is None else list(self._history)
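GPT4All gets the same interface, so a one-off generation no longer has to rely on garbage collection to release the native model. A sketch, where the model file name is illustrative and any downloadable GPT4All model would do:

    from gpt4all import GPT4All

    with GPT4All("orca-mini-3b-gguf2-q4_0.gguf") as model:  # illustrative model name
        print(model.generate("Name three colors.", max_tokens=32))
    # on exit, close() destroyed the native handle; further calls raise ValueError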

gpt4all-bindings/python/setup.py

@@ -68,7 +68,7 @@ def get_long_description():

 setup(
     name=package_name,
-    version="2.3.2",
+    version="2.3.3",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",