Fixup bert python bindings.

Adam Treat
2023-07-13 17:57:48 -04:00
committed by AT
parent 6200900677
commit ee4186d579
5 changed files with 37 additions and 23 deletions

View File: gpt4all/__init__.py

@@ -1,2 +1,2 @@
-from .gpt4all import GPT4All, embed # noqa
+from .gpt4all import GPT4All, Embedder # noqa
 from .pyllmodel import LLModel # noqa
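The package now exports an `Embedder` class in place of the one-shot `embed` helper, so existing imports need a small migration. A sketch of the change at a call site (the document string is illustrative):

```python
# Before this commit: module-level helper that reloaded the model on every call.
#   from gpt4all import embed
#   vector = embed("some document text")

# After this commit: construct the embedder once, then call embed() on it.
from gpt4all import Embedder

embedder = Embedder()
vector = embedder.embed("some document text")
```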

View File: gpt4all/gpt4all.py

@@ -15,20 +15,26 @@ from . import pyllmodel
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")

-def embed(
-    text: str
-) -> list[float]:
-    """
-    Generate an embedding for all GPT4All.
+class Embedder:
+    def __init__(
+        self
+    ):
+        self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=8)

-    Args:
-        text: The text document to generate an embedding for.
+    def embed(
+        self,
+        text: str
+    ) -> list[float]:
+        """
+        Generate an embedding for all GPT4All.

-    Returns:
-        An embedding of your document of text.
-    """
-    model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
-    return model.model.generate_embedding(text)
+        Args:
+            text: The text document to generate an embedding for.
+
+        Returns:
+            An embedding of your document of text.
+        """
+        return self.gpt4all.model.generate_embedding(text)

 class GPT4All:
     """

View File: gpt4all/pyllmodel.py

@@ -253,7 +253,7 @@ class LLModel:
embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8'))
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
embedding_array = ctypes.cast(embedding_ptr, ctypes.POINTER(ctypes.c_float * embedding_size.value)).contents
embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
llmodel.llmodel_free_embedding(embedding_ptr)
return list(embedding_array)
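This one-line change fixes a use-after-free: `ctypes.cast(...).contents` does not copy, it returns a view aliasing the C buffer, so the old code's `list(embedding_array)` ran after `llmodel_free_embedding()` had already released that memory. The fix copies every float into a Python-owned list while the buffer is still alive. A self-contained sketch of the two patterns, with the C library stubbed out by a local ctypes array (llmodel itself loads from a shared library, so it is not reproduced here):

```python
import ctypes

# Stand-in for the float* buffer that llmodel_embedding() returns.
backing = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
embedding_ptr = ctypes.cast(backing, ctypes.POINTER(ctypes.c_float))
embedding_size = ctypes.c_size_t(4)

# Old pattern (buggy): .contents is a zero-copy view over the C memory.
# Reading it after the buffer is freed is undefined behavior.
view = ctypes.cast(
    embedding_ptr, ctypes.POINTER(ctypes.c_float * embedding_size.value)
).contents

# New pattern (the fix): copy element by element into a Python list
# *before* calling llmodel_free_embedding(embedding_ptr).
embedding = [embedding_ptr[i] for i in range(embedding_size.value)]
```

With the fix in place, the trailing `list(embedding_array)` is just a cheap copy of an existing list rather than a read through a dangling pointer.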

File diff suppressed because one or more lines are too long