mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-12-23 03:55:18 +00:00
mpt bindings
@@ -29,7 +29,7 @@ class GPT4All():
             model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
             model_path: Path to directory containing model file or, if file does not exist, where to download model.
                 Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-            model_type: Model architecture to use - currently, only options are 'llama' or 'gptj'. Only required if model
+            model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
                 is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
                 Default is None.
             allow_download: Allow API to download models from gpt4all.io. Default is True.
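For orientation, a minimal usage sketch of the constructor described above (illustrative only, not part of the commit; the custom filename is a placeholder):

from gpt4all import GPT4All

# Known filename: the backend is inferred from the model lists further
# down in this diff, so model_type can stay None.
model = GPT4All("ggml-mpt-7b-base.bin")

# Custom MPT-architecture file: the new 'mpt' option selects the backend.
custom = GPT4All("my-mpt.bin", model_path="/path/to/models", model_type="mpt")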
@@ -263,6 +263,8 @@ class GPT4All():
             return pyllmodel.GPTJModel()
         elif model_type == "llama":
             return pyllmodel.LlamaModel()
+        elif model_type == "mpt":
+            return pyllmodel.MPTModel()
         else:
             raise ValueError(f"No corresponding model for model_type: {model_type}")
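The elif chain above grows by one branch per backend. An equivalent table-driven dispatch, shown as a sketch (not part of the commit, and assuming pyllmodel is importable from the gpt4all package):

from gpt4all import pyllmodel

MODEL_CLASSES = {
    "gptj": pyllmodel.GPTJModel,
    "llama": pyllmodel.LlamaModel,
    "mpt": pyllmodel.MPTModel,
}

def model_for_type(model_type):
    # Look up the backend class and instantiate it, mirroring the elif chain.
    try:
        return MODEL_CLASSES[model_type]()
    except KeyError:
        raise ValueError(f"No corresponding model for model_type: {model_type}")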
@@ -286,13 +288,20 @@ class GPT4All():
             "ggml-vicuna-7b-1.1-q4_2.bin",
             "ggml-vicuna-13b-1.1-q4_2.bin",
             "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin"
+            "ggml-stable-vicuna-13B.q4_2.bin",
+            "ggml-nous-gpt4-vicuna-13b.bin"
         ]
 
+        MPT_MODELS = [
+            "ggml-mpt-7b-base.bin"
+        ]
+
         if model_name in GPTJ_MODELS:
             return pyllmodel.GPTJModel()
         elif model_name in LLAMA_MODELS:
             return pyllmodel.LlamaModel()
+        elif model_name in MPT_MODELS:
+            return pyllmodel.MPTModel()
         else:
             err_msg = f"""No corresponding model for provided filename {model_name}.
             If this is a custom model, make sure to specify a valid model_type.
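The per-backend filename lists could likewise collapse into a single mapping; a sketch using only filenames visible in this hunk (not part of the commit):

from gpt4all import pyllmodel

KNOWN_MODELS = {
    "ggml-stable-vicuna-13B.q4_2.bin": pyllmodel.LlamaModel,
    "ggml-nous-gpt4-vicuna-13b.bin": pyllmodel.LlamaModel,
    "ggml-mpt-7b-base.bin": pyllmodel.MPTModel,
}

# Unknown names would still fall through to the error branch above.
model = KNOWN_MODELS["ggml-mpt-7b-base.bin"]()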
@@ -46,6 +46,9 @@ llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
 llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_llama_create.restype = ctypes.c_void_p
 llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
+llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
+
 
 llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
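Declaring restype and argtypes up front matters here: ctypes defaults every foreign function's return type to a C int, which can silently truncate a 64-bit pointer returned by a create call. A standalone sketch of the pattern (library name and loading mechanism are assumptions for illustration):

import ctypes

lib = ctypes.CDLL("libllmodel.so")  # assumed name; the real bindings locate the library themselves

# Without the restype line, llmodel_mpt_create() would be treated as
# returning a c_int, truncating the upper half of a 64-bit heap pointer.
lib.llmodel_mpt_create.restype = ctypes.c_void_p
lib.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]

handle = lib.llmodel_mpt_create()   # opaque pointer, preserved in full
lib.llmodel_mpt_destroy(handle)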
@@ -236,3 +239,17 @@ class LlamaModel(LLModel):
         if self.model is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
+
+
+class MPTModel(LLModel):
+
+    model_type = "mpt"
+
+    def __init__(self):
+        super().__init__()
+        self.model = llmodel.llmodel_mpt_create()
+
+    def __del__(self):
+        if self.model is not None:
+            llmodel.llmodel_mpt_destroy(self.model)
+        super().__del__()
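A minimal lifecycle sketch for the new class (illustrative; the load call is commented out because LLModel's loading method is not shown in this diff, and the path is a placeholder):

from gpt4all import pyllmodel

mpt = pyllmodel.MPTModel()    # calls llmodel_mpt_create() under the hood
# mpt.load_model("/path/to/ggml-mpt-7b-base.bin")   # assumed method name
del mpt                       # __del__ calls llmodel_mpt_destroy()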