diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index ca22d18b..987a21aa 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List, Union, Optional
 
 import requests
 from tqdm import tqdm
@@ -22,7 +22,12 @@ class GPT4All:
     """
 
     def __init__(
-        self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
+        self,
+        model_name: str,
+        model_path: Optional[str] = None,
+        model_type: Optional[str] = None,
+        allow_download: bool = True,
+        n_threads: Optional[int] = None,
     ):
         """
         Constructor
@@ -60,7 +65,7 @@ class GPT4All:
 
     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
+        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
     ) -> str:
         """
         Find model file, and if it doesn't exist, download the model.
@@ -120,7 +125,7 @@ class GPT4All:
         raise ValueError("Failed to retrieve model")
 
     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: str = None) -> str:
+    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
         """
         Download model from https://gpt4all.io.
 
@@ -181,7 +186,7 @@ class GPT4All:
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
-        n_predict: int = None,
+        n_predict: Optional[int] = None,
         streaming: bool = False,
     ) -> Union[str, Iterable]:
         """
@@ -202,13 +207,16 @@ class GPT4All:
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
""" - generate_kwargs = locals() - generate_kwargs.pop('self') - generate_kwargs.pop('max_tokens') - generate_kwargs.pop('streaming') - generate_kwargs['n_predict'] = max_tokens - if n_predict is not None: - generate_kwargs['n_predict'] = n_predict + generate_kwargs = dict( + prompt=prompt, + temp=temp, + top_k=top_k, + top_p=top_p, + repeat_penalty=repeat_penalty, + repeat_last_n=repeat_last_n, + n_batch=n_batch, + n_predict=n_predict if n_predict is not None else max_tokens, + ) if self._is_chat_session_activated: self.current_chat_session.append({"role": "user", "content": prompt}) diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index f5a0c40d..7e091207 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -262,8 +262,8 @@ class LLModel: Model response str """ - prompt = prompt.encode('utf-8') - prompt = ctypes.c_char_p(prompt) + prompt_bytes = prompt.encode('utf-8') + prompt_ptr = ctypes.c_char_p(prompt_bytes) old_stdout = sys.stdout @@ -288,7 +288,7 @@ class LLModel: llmodel.llmodel_prompt( self.model, - prompt, + prompt_ptr, PromptCallback(self._prompt_callback), ResponseCallback(self._response_callback), RecalculateCallback(self._recalculate_callback), @@ -314,12 +314,12 @@ class LLModel: reset_context: bool = False, ) -> Iterable: # Symbol to terminate from generator - TERMINATING_SYMBOL = "#TERMINATE#" + TERMINATING_SYMBOL = object() output_queue = queue.Queue() - prompt = prompt.encode('utf-8') - prompt = ctypes.c_char_p(prompt) + prompt_bytes = prompt.encode('utf-8') + prompt_ptr = ctypes.c_char_p(prompt_bytes) self._set_context( n_predict=n_predict, @@ -348,7 +348,7 @@ class LLModel: target=run_llmodel_prompt, args=( self.model, - prompt, + prompt_ptr, PromptCallback(self._prompt_callback), ResponseCallback(_generator_response_callback), RecalculateCallback(self._recalculate_callback), @@ -360,7 +360,7 @@ class LLModel: # Generator while True: response = output_queue.get() - if response == TERMINATING_SYMBOL: + if response is TERMINATING_SYMBOL: break yield response diff --git a/gpt4all-bindings/python/makefile b/gpt4all-bindings/python/makefile index a32ee93f..0b3395e5 100644 --- a/gpt4all-bindings/python/makefile +++ b/gpt4all-bindings/python/makefile @@ -25,4 +25,7 @@ isort: source env/bin/activate; isort --ignore-whitespace --atomic -w 120 gpt4all test: + source env/bin/activate; pytest -s gpt4all/tests -k "not test_inference_long" + +test_all: source env/bin/activate; pytest -s gpt4all/tests