documentation and cleanup

Richard Guo
2023-05-11 11:02:44 -04:00
parent 0534ab59ec
commit 9a015e2b66
5 changed files with 46 additions and 26 deletions


@@ -15,7 +15,7 @@ from . import pyllmodel
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
class GPT4All():
"""Python API for retrieving and interacting with GPT4All models
"""Python API for retrieving and interacting with GPT4All models.
Attributes:
model: Pointer to underlying C model.
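For orientation, a minimal usage sketch of this class (the model name is an illustrative assumption, not part of this commit; the constructor is assumed to resolve the model file via retrieve_model below):
from gpt4all import GPT4All
# Resolves the model file under ~/.cache/gpt4all (see
# DEFAULT_MODEL_DIRECTORY above), downloading it on first use.
gpt = GPT4All("ggml-gpt4all-j-v1.3-groovy")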
@@ -50,7 +50,7 @@ class GPT4All():
@staticmethod
def list_models():
"""
Fetch model list from https://gpt4all.io/models/models.json
Fetch model list from https://gpt4all.io/models/models.json.
Returns:
Model list in JSON format.
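In effect this is a single HTTP fetch; a hedged sketch of the equivalent logic (using requests, as the download code below does, though not necessarily the commit's verbatim implementation):
import requests
def list_models():
    # Fetch the registry and parse the JSON body into a list of
    # model-metadata dictionaries.
    response = requests.get("https://gpt4all.io/models/models.json")
    response.raise_for_status()
    return response.json()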
@@ -60,7 +60,7 @@ class GPT4All():
return model_json
@staticmethod
def retrieve_model(model_name: str, model_path: str = None, allow_download = True):
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True) -> str:
"""
Find model file, and if it doesn't exist, download the model.
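A sketch of the check-then-download flow this docstring describes (the .bin filename handling and the error message are assumptions):
import os
def retrieve_model(model_name, model_path=None, allow_download=True):
    filename = model_name if model_name.endswith(".bin") else model_name + ".bin"
    directory = model_path or DEFAULT_MODEL_DIRECTORY
    destination = os.path.join(directory, filename)
    if os.path.exists(destination):
        return destination  # already cached locally
    if allow_download:
        return GPT4All.download_model(filename, directory)
    raise ValueError("Model not found and allow_download is False")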
@@ -113,8 +113,18 @@ class GPT4All():
raise ValueError("Invalid model directory")
@staticmethod
def download_model(model_filename, model_path):
# TODO: Find a good way of safely removing a file whose download was interrupted.
def download_model(model_filename: str, model_path: str) -> str:
"""
Download model from https://gpt4all.io.
Args:
model_filename: Filename of model (with .bin extension).
model_path: Path to download model to.
Returns:
Model file destination.
"""
def get_download_url(model_filename):
return f"https://gpt4all.io/models/{model_filename}"
@@ -122,6 +132,7 @@ class GPT4All():
download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
download_url = get_download_url(model_filename)
# TODO: Find a good way of safely removing a file whose download was interrupted.
response = requests.get(download_url, stream=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1048576 # 1 MB
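The hunk ends before the write loop, but a streamed download of this shape typically continues along these lines (a sketch, not the commit's verbatim code):
with open(download_path, "wb") as f:
    # Stream the body in 1 MB chunks so large model files never have to
    # fit in memory at once; an interrupted run leaves a partial file
    # behind, which is what the TODO above is about.
    for chunk in response.iter_content(block_size):
        f.write(chunk)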
@@ -141,9 +152,16 @@ class GPT4All():
print("Model downloaded at: " + download_path)
return download_path
def generate(self, prompt: str, **generate_kwargs):
def generate(self, prompt: str, **generate_kwargs) -> str:
"""
Surfaced method for running generate without accessing the underlying model object.
Args:
prompt: Raw string to be passed to model.
**generate_kwargs: Optional kwargs to pass to prompt context.
Returns:
Raw string of generated model response.
"""
return self.model.generate(prompt, **generate_kwargs)
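Usage of the wrapper is then a single call (the model name and the n_predict kwarg are illustrative assumptions):
gpt = GPT4All("ggml-gpt4all-j-v1.3-groovy")
output = gpt.generate("Name three primary colors.", n_predict=64)
print(output)  # raw string produced by the underlying C model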
@@ -158,13 +176,13 @@ class GPT4All():
generated content.
Args:
messages: Each dictionary should have a "role" key
messages: List of dictionaries. Each dictionary should have a "role" key
with value of "system", "assistant", or "user" and a "content" key with a
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Reponse: {content}".
default_prompt_header: If True (default), add default prompt header after any user specified system messages and
before user/assistant messages.
default_prompt_header: If True (default), add default prompt header after any system role messages and
before user/assistant role messages.
default_prompt_footer: If True (default), add default footer at end of prompt.
verbose: If True (default), print full prompt and generated response.
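An illustrative call matching this description (assuming the surrounding method is the bindings' chat_completion):
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
response = gpt.chat_completion(messages)
print(response["choices"][0]["content"])  # role is "assistant"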
@@ -175,7 +193,6 @@ class GPT4All():
generated tokens in response, and total tokens.
"choices": List of message dictionary where "content" is generated response and "role" is set
as "assistant". Right now, only one choice is returned by model.
"""
full_prompt = self._build_prompt(messages,
@@ -210,6 +227,7 @@ class GPT4All():
def _build_prompt(messages: List[Dict],
default_prompt_header=True,
default_prompt_footer=False) -> str:
# Helper method to format messages into prompt.
full_prompt = ""
for message in messages:
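A hedged sketch of how that formatting loop plausibly proceeds (the "Prompt:" label and the omitted header/footer handling are assumptions; "Response:" comes from the docstring above):
def _build_prompt(messages, default_prompt_header=True, default_prompt_footer=False):
    full_prompt = ""
    # System messages go at the top of the prompt, per the docstring.
    for message in messages:
        if message["role"] == "system":
            full_prompt += message["content"] + "\n"
    # Default header/footer insertion omitted here for brevity.
    for message in messages:
        if message["role"] == "user":
            full_prompt += "Prompt: " + message["content"] + "\n"
        elif message["role"] == "assistant":
            full_prompt += "Response: " + message["content"] + "\n"
    return full_prompt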
@@ -238,7 +256,7 @@ class GPT4All():
@staticmethod
def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
# This needs to be updated for each new model
# This needs to be updated for each new model type
# TODO: Might be worth converting model_type to enum
if model_type == "gptj":
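The TODO about converting model_type to an enum could be addressed along these lines (a sketch; "gptj" is the only type visible in this hunk, so the member list is an assumption):
from enum import Enum
class ModelType(Enum):
    GPTJ = "gptj"
def parse_model_type(model_type: str) -> ModelType:
    # ModelType(...) raises ValueError on unrecognized strings, replacing a
    # manual if/elif chain; the member would then be mapped to the matching
    # pyllmodel model class.
    return ModelType(model_type)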