Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-01 09:41:59 +00:00
python bindings: typing fixes, misc fixes (#1131)

* python: do not mutate locals()
* python: fix (some) typing complaints
* python: queue sentinel need not be a str
* python: make long inference tests opt in
This commit is contained in:
parent 01bd3d6802
commit 6987910668
@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List, Union, Optional

 import requests
 from tqdm import tqdm
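The only change here is importing Optional; the hunks below use it to replace implicit-Optional defaults such as `model_path: str = None`, which strict type checkers (e.g. mypy) reject because the default None does not match the annotation. A minimal sketch of the difference, with illustrative function names that are not part of this diff:

from typing import Optional

def load_implicit(model_path: str = None):
    # mypy: incompatible default "None" for argument of type "str"
    return model_path or "default.bin"

def load_explicit(model_path: Optional[str] = None):
    # annotation admits None, so the default is consistent and checkers accept it
    return model_path or "default.bin"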
@@ -22,7 +22,12 @@ class GPT4All:
     """

     def __init__(
-        self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
+        self,
+        model_name: str,
+        model_path: Optional[str] = None,
+        model_type: Optional[str] = None,
+        allow_download: bool = True,
+        n_threads: Optional[int] = None,
     ):
         """
         Constructor
@@ -60,7 +65,7 @@ class GPT4All:

     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
+        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
     ) -> str:
         """
         Find model file, and if it doesn't exist, download the model.
@@ -120,7 +125,7 @@ class GPT4All:
             raise ValueError("Failed to retrieve model")

     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: str = None) -> str:
+    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
         """
         Download model from https://gpt4all.io.

@@ -181,7 +186,7 @@ class GPT4All:
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
-        n_predict: int = None,
+        n_predict: Optional[int] = None,
         streaming: bool = False,
     ) -> Union[str, Iterable]:
         """
@@ -202,13 +207,16 @@ class GPT4All:
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = locals()
-        generate_kwargs.pop('self')
-        generate_kwargs.pop('max_tokens')
-        generate_kwargs.pop('streaming')
-        generate_kwargs['n_predict'] = max_tokens
-        if n_predict is not None:
-            generate_kwargs['n_predict'] = n_predict
+        generate_kwargs = dict(
+            prompt=prompt,
+            temp=temp,
+            top_k=top_k,
+            top_p=top_p,
+            repeat_penalty=repeat_penalty,
+            repeat_last_n=repeat_last_n,
+            n_batch=n_batch,
+            n_predict=n_predict if n_predict is not None else max_tokens,
+        )

         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "user", "content": prompt})
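This is the "do not mutate locals()" fix: the old code snapshotted locals() and popped entries, but mutating the mapping returned by locals() is not guaranteed to affect the real local variables, and any parameter or helper variable added before that line silently leaks into the forwarded kwargs. Building the dict explicitly keeps the forwarded arguments visible. A small sketch of the two styles, using an illustrative wrapper rather than the library code:

def generate_via_locals(prompt, temp=0.7, top_k=40, max_tokens=200):
    # Fragile: captures every local name defined so far, so new parameters
    # or temporaries would leak into the kwargs unnoticed.
    kwargs = dict(locals())
    kwargs.pop('max_tokens')
    kwargs['n_predict'] = max_tokens
    return kwargs

def generate_explicit(prompt, temp=0.7, top_k=40, max_tokens=200):
    # Explicit: only the intended arguments are forwarded, and the rename
    # (max_tokens -> n_predict) is spelled out at the call boundary.
    return dict(prompt=prompt, temp=temp, top_k=top_k, n_predict=max_tokens)

assert generate_via_locals("hi") == generate_explicit("hi")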
@@ -262,8 +262,8 @@ class LLModel:
             Model response str
         """

-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)

         old_stdout = sys.stdout

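Rebinding `prompt` from str to bytes to a C pointer made one variable change type twice in the same function, which is exactly the kind of thing the typing fixes are meant to avoid; distinct names keep each value's type stable and leave the original string usable. A standalone sketch of the pattern, with no gpt4all calls involved:

import ctypes

prompt = "Hello, world"
prompt_bytes = prompt.encode('utf-8')        # str -> bytes
prompt_ptr = ctypes.c_char_p(prompt_bytes)   # bytes -> C char pointer

assert prompt_ptr.value == prompt_bytes      # pointer round-trips to the same bytes
assert prompt == "Hello, world"              # the original str is untouched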
@@ -288,7 +288,7 @@ class LLModel:

         llmodel.llmodel_prompt(
             self.model,
-            prompt,
+            prompt_ptr,
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._response_callback),
             RecalculateCallback(self._recalculate_callback),
@@ -314,12 +314,12 @@ class LLModel:
         reset_context: bool = False,
     ) -> Iterable:
         # Symbol to terminate from generator
-        TERMINATING_SYMBOL = "#TERMINATE#"
+        TERMINATING_SYMBOL = object()

         output_queue = queue.Queue()

-        prompt = prompt.encode('utf-8')
-        prompt = ctypes.c_char_p(prompt)
+        prompt_bytes = prompt.encode('utf-8')
+        prompt_ptr = ctypes.c_char_p(prompt_bytes)

         self._set_context(
             n_predict=n_predict,
@@ -348,7 +348,7 @@ class LLModel:
             target=run_llmodel_prompt,
             args=(
                 self.model,
-                prompt,
+                prompt_ptr,
                 PromptCallback(self._prompt_callback),
                 ResponseCallback(_generator_response_callback),
                 RecalculateCallback(self._recalculate_callback),
@@ -360,7 +360,7 @@ class LLModel:
         # Generator
         while True:
             response = output_queue.get()
-            if response == TERMINATING_SYMBOL:
+            if response is TERMINATING_SYMBOL:
                 break
             yield response

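These two hunks are the "queue sentinel need not be a str" fix: with a plain string sentinel, a model that happened to emit the literal token "#TERMINATE#" would end the stream early, whereas a fresh object() compared with `is` can never collide with real output. A minimal producer/consumer sketch of the pattern, independent of the binding code:

import queue
import threading

TERMINATING_SYMBOL = object()  # unique sentinel; no token can be identical to it

def producer(q: queue.Queue) -> None:
    for token in ["Hello", " ", "#TERMINATE#", "!"]:  # "#TERMINATE#" is ordinary output now
        q.put(token)
    q.put(TERMINATING_SYMBOL)  # signal end of stream

q: queue.Queue = queue.Queue()
threading.Thread(target=producer, args=(q,)).start()

tokens = []
while True:
    item = q.get()
    if item is TERMINATING_SYMBOL:  # identity check, not string equality
        break
    tokens.append(item)

print("".join(tokens))  # -> Hello #TERMINATE#!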
@@ -25,4 +25,7 @@ isort:
	source env/bin/activate; isort --ignore-whitespace --atomic -w 120 gpt4all

test:
-	source env/bin/activate; pytest -s gpt4all/tests
+	source env/bin/activate; pytest -s gpt4all/tests -k "not test_inference_long"
+
+test_all:
+	source env/bin/activate; pytest -s gpt4all/tests
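This is the "make long inference tests opt in" change: the default `make test` now deselects anything matching test_inference_long, while `make test_all` still runs the full suite. A hypothetical test module (not from the gpt4all suite) showing how the -k filter splits the two:

# test_example.py -- illustrative only
def test_tokenize_fast():
    assert "hello world".split() == ["hello", "world"]

def test_inference_long():
    # Imagine a multi-minute generation run here; it is skipped by
    #   pytest -k "not test_inference_long"
    # and included by a plain pytest invocation (make test_all).
    assert True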