mirror of
				https://github.com/nomic-ai/gpt4all.git
				synced 2025-10-31 13:51:43 +00:00 
			
		
		
		
	Fixup bert python bindings.
This commit is contained in:
		| @@ -1,2 +1,2 @@ | ||||
| from .gpt4all import GPT4All, embed  # noqa | ||||
| from .gpt4all import GPT4All, Embedder  # noqa | ||||
| from .pyllmodel import LLModel  # noqa | ||||
|   | ||||
| @@ -15,20 +15,26 @@ from . import pyllmodel | ||||
| # TODO: move to config | ||||
| DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") | ||||
|  | ||||
| def embed( | ||||
|     text: str | ||||
| ) -> list[float]: | ||||
|     """ | ||||
|     Generate an embedding for all GPT4All. | ||||
| class Embedder: | ||||
|     def __init__( | ||||
|         self | ||||
|     ): | ||||
|         self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=8) | ||||
|  | ||||
|     Args: | ||||
|         text: The text document to generate an embedding for. | ||||
|     def embed( | ||||
|         self, | ||||
|         text: str | ||||
|     ) -> list[float]: | ||||
|         """ | ||||
|         Generate an embedding using GPT4All. | ||||
|  | ||||
|     Returns: | ||||
|         An embedding of your document of text. | ||||
|     """ | ||||
|     model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin') | ||||
|     return model.model.generate_embedding(text) | ||||
|         Args: | ||||
|             text: The text document to generate an embedding for. | ||||
|  | ||||
|         Returns: | ||||
|             An embedding of your text document. | ||||
|         """ | ||||
|         return self.gpt4all.model.generate_embedding(text) | ||||
|  | ||||
| class GPT4All: | ||||
|     """ | ||||
|   | ||||
| @@ -253,7 +253,7 @@ class LLModel: | ||||
|         embedding_size = ctypes.c_size_t() | ||||
|         c_text = ctypes.c_char_p(text.encode('utf-8')) | ||||
|         embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size)) | ||||
|         embedding_array = ctypes.cast(embedding_ptr, ctypes.POINTER(ctypes.c_float * embedding_size.value)).contents | ||||
|         embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)] | ||||
|         llmodel.llmodel_free_embedding(embedding_ptr) | ||||
|         return list(embedding_array) | ||||
|  | ||||
|   | ||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Reference in New Issue
	
	Block a user