mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-02 00:57:09 +00:00
Bert
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .gpt4all import GPT4All # noqa
|
||||
from .gpt4all import GPT4All, embed # noqa
|
||||
from .pyllmodel import LLModel # noqa
|
||||
|
@@ -15,6 +15,20 @@ from . import pyllmodel
|
||||
# TODO: move to config
|
||||
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
|
||||
|
||||
def embed(
    text: str,
    model_name: str = 'ggml-all-MiniLM-L6-v2-f16.bin'
) -> list[float]:
    """
    Generate an embedding for a text document using a local GPT4All
    embedding model.

    Args:
        text: The text document to generate an embedding for.
        model_name: Filename of the embedding model to load. Defaults to
            the MiniLM-L6 embedding model, matching the original behavior.

    Returns:
        An embedding of your document of text, as a list of floats.
    """
    # NOTE(review): a fresh model is loaded on every call — callers that
    # embed many documents should construct GPT4All once and reuse it.
    model = GPT4All(model_name=model_name)
    return model.model.generate_embedding(text)
|
||||
|
||||
class GPT4All:
|
||||
"""
|
||||
|
@@ -112,6 +112,19 @@ llmodel.llmodel_prompt.argtypes = [
|
||||
|
||||
llmodel.llmodel_prompt.restype = None

# llmodel_embedding(void *model, const char *text, size_t *embedding_size)
# -> float*   The returned buffer is released via llmodel_free_embedding
# (see generate_embedding), so it is presumably allocated by the C side.
llmodel.llmodel_embedding.argtypes = [
    ctypes.c_void_p,
    ctypes.c_char_p,
    ctypes.POINTER(ctypes.c_size_t),
]
llmodel.llmodel_embedding.restype = ctypes.POINTER(ctypes.c_float)

# llmodel_free_embedding(float *ptr) — frees a buffer from llmodel_embedding.
llmodel.llmodel_free_embedding.argtypes = [ctypes.POINTER(ctypes.c_float)]
llmodel.llmodel_free_embedding.restype = None

# llmodel_setThreadCount(void *model, int32_t n_threads)
llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
llmodel.llmodel_setThreadCount.restype = None
|
||||
|
||||
@@ -233,6 +246,17 @@ class LLModel:
|
||||
self.context.repeat_last_n = repeat_last_n
|
||||
self.context.context_erase = context_erase
|
||||
|
||||
def generate_embedding(
    self,
    text: str
) -> list[float]:
    """
    Generate an embedding vector for *text* via the C embedding API.

    Args:
        text: The text to embed. Must be non-empty.

    Returns:
        The embedding as a list of floats; an empty list if the C call
        yields no data.

    Raises:
        ValueError: If text is empty or None.
    """
    if not text:
        raise ValueError("Text must not be None or empty")

    embedding_size = ctypes.c_size_t()
    c_text = ctypes.c_char_p(text.encode('utf-8'))
    embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))

    # Guard against a NULL pointer or zero-length result: casting NULL and
    # reading .contents would raise, and freeing NULL is pointless.
    if not embedding_ptr or embedding_size.value == 0:
        return []

    embedding_array = ctypes.cast(embedding_ptr, ctypes.POINTER(ctypes.c_float * embedding_size.value)).contents
    # BUGFIX: copy the floats into a Python list BEFORE freeing the C
    # buffer — the original freed first and then read freed memory.
    embedding = list(embedding_array)
    llmodel.llmodel_free_embedding(embedding_ptr)
    return embedding
|
||||
|
||||
def prompt_model(
|
||||
self,
|
||||
prompt: str,
|
||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user