From 3313c7de0d26fb1ed602f572247f5aad980b59e4 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Thu, 28 Mar 2024 16:48:07 -0400
Subject: [PATCH] python: implement close() and context manager interface
 (#2177)

Signed-off-by: Jared Van Bortel
---
 gpt4all-bindings/python/gpt4all/_pyllmodel.py | 32 +++++++++++++++++--
 gpt4all-bindings/python/gpt4all/gpt4all.py    | 27 +++++++++++++++-
 gpt4all-bindings/python/setup.py              |  2 +-
 3 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index e8c6266d..1c50d0aa 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,7 +9,7 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Generic, Iterable, TypeVar, overload
+from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
@@ -200,13 +200,22 @@ class LLModel:
         if model is None:
             s = err.value
             raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
-        self.model = model
+        self.model: ctypes.c_void_p | None = model
 
     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
+            self.close()
+
+    def close(self) -> None:
+        if self.model is not None:
             llmodel.llmodel_model_destroy(self.model)
+            self.model = None
+
+    def _raise_closed(self) -> NoReturn:
+        raise ValueError("Attempted operation on a closed LLModel")
 
     def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
+        assert self.model is not None
         num_devices = ctypes.c_int32(0)
         devices_ptr = llmodel.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
         if not devices_ptr:
@@ -214,6 +223,9 @@ class LLModel:
         return devices_ptr[:num_devices.value]
 
     def init_gpu(self, device: str):
+        if self.model is None:
+            self._raise_closed()
+
         mem_required = llmodel.llmodel_required_mem(self.model, self.model_path, self.n_ctx, self.ngl)
 
         if llmodel.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()):
@@ -246,14 +258,21 @@ class LLModel:
         -------
         True if model loaded successfully, False otherwise
         """
+        if self.model is None:
+            self._raise_closed()
+
         return llmodel.llmodel_loadModel(self.model, self.model_path, self.n_ctx, self.ngl)
 
     def set_thread_count(self, n_threads):
+        if self.model is None:
+            self._raise_closed()
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
         llmodel.llmodel_setThreadCount(self.model, n_threads)
 
     def thread_count(self):
+        if self.model is None:
+            self._raise_closed()
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
@@ -322,6 +341,9 @@ class LLModel:
         if not text:
             raise ValueError("text must not be None or empty")
 
+        if self.model is None:
+            self._raise_closed()
+
         if (single_text := isinstance(text, str)):
             text = [text]
 
@@ -387,6 +409,9 @@ class LLModel:
         None
         """
 
+        if self.model is None:
+            self._raise_closed()
+
         self.buffer.clear()
         self.buff_expecting_cont_bytes = 0
 
@@ -419,6 +444,9 @@ class LLModel:
     def prompt_model_streaming(
         self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
+        if self.model is None:
+            self._raise_closed()
+
         output_queue: Queue[str | Sentinel] = Queue()
 
         # Put response tokens into an output queue
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index fadcc684..b7c11b07 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -11,6 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
+from types import TracebackType
 from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
 
 import requests
@@ -22,7 +23,7 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult
 
 if TYPE_CHECKING:
-    from typing_extensions import TypeAlias
+    from typing_extensions import Self, TypeAlias
 
 if sys.platform == 'darwin':
     import fcntl
@@ -54,6 +55,18 @@ class Embed4All:
             model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
         self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)
 
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Delete the model instance and free associated system resources."""
+        self.gpt4all.close()
+
     # return_dict=False
     @overload
     def embed(
@@ -190,6 +203,18 @@ class GPT4All:
         self._history: list[MessageType] | None = None
         self._current_prompt_template: str = "{0}"
 
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """Delete the model instance and free associated system resources."""
+        self.model.close()
+
     @property
     def current_chat_session(self) -> list[MessageType] | None:
         return None if self._history is None else list(self._history)
diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py
index 9c6b389b..c8b50e86 100644
--- a/gpt4all-bindings/python/setup.py
+++ b/gpt4all-bindings/python/setup.py
@@ -68,7 +68,7 @@ def get_long_description():
 
 setup(
     name=package_name,
-    version="2.3.2",
+    version="2.3.3",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",
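
Usage sketch (illustration only, not part of the patch): how the context-manager
interface added above behaves from the caller's side. The model filename is a
placeholder; substitute any model file the GPT4All bindings can load.

    from gpt4all import GPT4All

    # __exit__ delegates to close(), so the native llmodel handle is freed
    # deterministically when the block ends instead of whenever __del__ runs.
    with GPT4All("orca-mini-3b-gguf2-q4_0.gguf") as model:  # placeholder filename
        print(model.generate("The capital of France is", max_tokens=3))

    # close() is idempotent: LLModel.close() destroys the handle only when
    # self.model is not None, so a second call is a no-op.
    model.close()

    # After close(), any operation that touches the native handle raises
    # ValueError("Attempted operation on a closed LLModel") via _raise_closed().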