python bindings: typing fixes, misc fixes (#1131)

* python: do not mutate locals()

* python: fix (some) typing complaints

* python: queue sentinel need not be a str

* python: make long inference tests opt in
Aaron Miller 2023-07-03 18:30:24 -07:00 committed by GitHub
parent 01bd3d6802
commit 6987910668
3 changed files with 31 additions and 20 deletions

gpt4all/gpt4all.py

@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List, Union, Optional
 import requests
 from tqdm import tqdm
@@ -22,7 +22,12 @@ class GPT4All:
     """
 
     def __init__(
-        self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
+        self,
+        model_name: str,
+        model_path: Optional[str] = None,
+        model_type: Optional[str] = None,
+        allow_download: bool = True,
+        n_threads: Optional[int] = None,
     ):
         """
         Constructor
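Why the Optional annotations: a parameter declared `model_path: str = None` is flagged by type checkers, since `None` is not a `str`; PEP 484's implicit-Optional behavior is deprecated, and recent mypy releases reject it by default. A minimal sketch of the pattern (the function name is illustrative, not from the bindings):

from typing import Optional

def load_model(path: Optional[str] = None) -> str:
    # With Optional[str], both load_model() and load_model("ggml.bin") type-check;
    # with the old annotation `path: str = None`, mypy reports an error.
    return path if path is not None else "default.bin"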
@@ -60,7 +65,7 @@ class GPT4All:
     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
+        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
     ) -> str:
         """
         Find model file, and if it doesn't exist, download the model.
@@ -120,7 +125,7 @@ class GPT4All:
             raise ValueError("Failed to retrieve model")
 
     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: str = None) -> str:
+    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
         """
         Download model from https://gpt4all.io.
@@ -181,7 +186,7 @@ class GPT4All:
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
-        n_predict: int = None,
+        n_predict: Optional[int] = None,
         streaming: bool = False,
     ) -> Union[str, Iterable]:
         """
@@ -202,13 +207,16 @@ class GPT4All:
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = locals()
-        generate_kwargs.pop('self')
-        generate_kwargs.pop('max_tokens')
-        generate_kwargs.pop('streaming')
-        generate_kwargs['n_predict'] = max_tokens
-        if n_predict is not None:
-            generate_kwargs['n_predict'] = n_predict
+        generate_kwargs = dict(
+            prompt=prompt,
+            temp=temp,
+            top_k=top_k,
+            top_p=top_p,
+            repeat_penalty=repeat_penalty,
+            repeat_last_n=repeat_last_n,
+            n_batch=n_batch,
+            n_predict=n_predict if n_predict is not None else max_tokens,
+        )
 
         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "user", "content": prompt})
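Why the locals() change matters: the Python docs warn that the dict returned by locals() should not be modified, and building kwargs from it is fragile, since any new local variable introduced into generate() would silently become an extra keyword argument. A standalone sketch of the pitfall (names are hypothetical):

def fragile_kwargs(a: int, b: int = 2) -> dict:
    scale = a * b          # a helper local added during a later refactor
    kwargs = locals()      # the snapshot now includes 'scale' as well
    kwargs.pop('a')
    return kwargs          # {'b': 2, 'scale': 6}; 'scale' leaked in

def explicit_kwargs(a: int, b: int = 2) -> dict:
    return dict(b=b)       # immune to new locals; only listed keys appear

print(fragile_kwargs(3))   # {'b': 2, 'scale': 6}
print(explicit_kwargs(3))  # {'b': 2}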

gpt4all/pyllmodel.py

@@ -262,8 +262,8 @@ class LLModel:
             Model response str
         """
-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)
 
         old_stdout = sys.stdout
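Why the renames: the old code rebound `prompt` from `str` to `bytes` to `ctypes.c_char_p`, giving one name three inferred types, a common source of type-checker complaints. Giving each representation its own name keeps every binding mono-typed. A small sketch (the helper is illustrative, not from the bindings):

import ctypes

def to_c_string(prompt: str) -> ctypes.c_char_p:
    # Each conversion gets a fresh, accurately typed name instead of
    # reusing 'prompt' for str, bytes, and c_char_p in turn.
    prompt_bytes = prompt.encode('utf-8')
    prompt_ptr = ctypes.c_char_p(prompt_bytes)
    return prompt_ptr

print(to_c_string("hello").value)  # b'hello'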
@@ -288,7 +288,7 @@ class LLModel:
         llmodel.llmodel_prompt(
             self.model,
-            prompt,
+            prompt_ptr,
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._response_callback),
             RecalculateCallback(self._recalculate_callback),
@@ -314,12 +314,12 @@
         reset_context: bool = False,
     ) -> Iterable:
         # Symbol to terminate from generator
-        TERMINATING_SYMBOL = "#TERMINATE#"
+        TERMINATING_SYMBOL = object()
 
         output_queue = queue.Queue()
-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)
 
         self._set_context(
             n_predict=n_predict,
@@ -348,7 +348,7 @@
             target=run_llmodel_prompt,
             args=(
                 self.model,
-                prompt,
+                prompt_ptr,
                 PromptCallback(self._prompt_callback),
                 ResponseCallback(_generator_response_callback),
                 RecalculateCallback(self._recalculate_callback),
@@ -360,7 +360,7 @@
         # Generator
         while True:
             response = output_queue.get()
-            if response == TERMINATING_SYMBOL:
+            if response is TERMINATING_SYMBOL:
                 break
             yield response
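Why an object() sentinel: the old string sentinel "#TERMINATE#" could collide with real model output; if the model ever emitted that exact string, the generator would stop early. A fresh object() is unique, so an identity check with `is` can never be satisfied by queue data. A self-contained sketch of the pattern:

import queue
import threading

_DONE = object()                  # unique sentinel; no token can be identical to it

def produce(q: queue.Queue) -> None:
    for token in ("hello", "#TERMINATE#", "world"):
        q.put(token)              # even the literal '#TERMINATE#' passes through safely
    q.put(_DONE)                  # end-of-stream marker, compared by identity

q: queue.Queue = queue.Queue()
threading.Thread(target=produce, args=(q,)).start()
while True:
    item = q.get()
    if item is _DONE:             # identity, not equality
        break
    print(item)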

makefile

@@ -25,4 +25,7 @@ isort:
 	source env/bin/activate; isort --ignore-whitespace --atomic -w 120 gpt4all
 
 test:
 	source env/bin/activate; pytest -s gpt4all/tests -k "not test_inference_long"
+
+test_all:
+	source env/bin/activate; pytest -s gpt4all/tests
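How the opt-in works: pytest's -k expression deselects any test whose name matches `test_inference_long`, so `make test` skips the slow inference tests while `make test_all` runs everything. A hypothetical test module showing the naming convention the filter keys on (not a file from the repo):

# illustrative only: any test whose name contains 'test_inference_long'
# is skipped by `make test` and runs only under `make test_all`
def test_tokenize_fast():
    assert "gpt4all".upper() == "GPT4ALL"

def test_inference_long():
    assert sum(range(10_000)) == 49_995_000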