mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-04 19:17:59 +00:00
Handle edge cases when generating embeddings (#1215)
* Handle edge cases when generating embeddings * Improve Python handling & add llmodel_c.h note - In the Python bindings, fail fast with a ValueError when text is empty - Advise authors of other bindings to do likewise in llmodel_c.h
This commit is contained in:
parent
1e74171a7b
commit
2d02c65177
@ -168,10 +168,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
|
|||||||
|
|
||||||
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
|
float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
|
||||||
{
|
{
|
||||||
|
if (model == nullptr || text == nullptr || !strlen(text)) {
|
||||||
|
*embedding_size = 0;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||||
std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
|
std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
|
||||||
float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
|
float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
|
||||||
if(embedding == nullptr) {
|
if (embedding == nullptr) {
|
||||||
*embedding_size = 0;
|
*embedding_size = 0;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -173,6 +173,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate an embedding using the model.
|
* Generate an embedding using the model.
|
||||||
|
* NOTE: If given NULL pointers for the model or text, or an empty text, a NULL pointer will be
|
||||||
|
* returned. Bindings should signal an error when NULL is the return value.
|
||||||
* @param model A pointer to the llmodel_model instance.
|
* @param model A pointer to the llmodel_model instance.
|
||||||
* @param text A string representing the text to generate an embedding for.
|
* @param text A string representing the text to generate an embedding for.
|
||||||
* @param embedding_size A pointer to a size_t type that will be set by the call indicating the length
|
* @param embedding_size A pointer to a size_t type that will be set by the call indicating the length
|
||||||
|
@ -251,6 +251,8 @@ class LLModel:
|
|||||||
self,
|
self,
|
||||||
text: str
|
text: str
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
|
if not text:
|
||||||
|
raise ValueError("Text must not be None or empty")
|
||||||
embedding_size = ctypes.c_size_t()
|
embedding_size = ctypes.c_size_t()
|
||||||
c_text = ctypes.c_char_p(text.encode('utf-8'))
|
c_text = ctypes.c_char_p(text.encode('utf-8'))
|
||||||
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
|
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
|
||||||
|
@ -3,6 +3,7 @@ from io import StringIO
|
|||||||
|
|
||||||
from gpt4all import GPT4All, Embed4All
|
from gpt4all import GPT4All, Embed4All
|
||||||
import time
|
import time
|
||||||
|
import pytest
|
||||||
|
|
||||||
def test_inference():
|
def test_inference():
|
||||||
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
|
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
|
||||||
@ -107,3 +108,9 @@ def test_embedding():
|
|||||||
#for i, value in enumerate(output):
|
#for i, value in enumerate(output):
|
||||||
#print(f'Value at index {i}: {value}')
|
#print(f'Value at index {i}: {value}')
|
||||||
assert len(output) == 384
|
assert len(output) == 384
|
||||||
|
|
||||||
|
def test_empty_embedding():
    # Embedding an empty string is invalid input: the bindings are expected
    # to fail fast with a ValueError rather than return a bogus vector.
    embedder = Embed4All()
    with pytest.raises(ValueError):
        output = embedder.embed('')
|
||||||
|
Loading…
Reference in New Issue
Block a user