From 2d02c65177d81e9945f54d59606394119078b2c4 Mon Sep 17 00:00:00 2001
From: cosmic-snow <134004613+cosmic-snow@users.noreply.github.com>
Date: Mon, 17 Jul 2023 22:21:03 +0200
Subject: [PATCH] Handle edge cases when generating embeddings (#1215)

* Handle edge cases when generating embeddings
* Improve Python handling & add llmodel_c.h note
  - In the Python bindings, fail fast with a ValueError when text is empty
  - Advise other bindings authors to do likewise in llmodel_c.h
---
 gpt4all-backend/llmodel_c.cpp                         | 6 +++++-
 gpt4all-backend/llmodel_c.h                           | 2 ++
 gpt4all-bindings/python/gpt4all/pyllmodel.py          | 2 ++
 gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py | 7 +++++++
 4 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index fb916d95..58fe27f5 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -168,10 +168,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
 float *llmodel_embedding(llmodel_model model, const char *text, size_t *embedding_size)
 {
+    if (model == nullptr || text == nullptr || !strlen(text)) {
+        *embedding_size = 0;
+        return nullptr;
+    }
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
     std::vector<float> embeddingVector = wrapper->llModel->embedding(text);
     float *embedding = (float *)malloc(embeddingVector.size() * sizeof(float));
-    if(embedding == nullptr) {
+    if (embedding == nullptr) {
         *embedding_size = 0;
         return nullptr;
     }
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index 8d582d08..138a8853 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -173,6 +173,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
 /**
  * Generate an embedding using the model.
+ * NOTE: If given NULL pointers for the model or text, or an empty text, a NULL pointer will be
+ * returned. Bindings should signal an error when NULL is the return value.
  * @param model A pointer to the llmodel_model instance.
  * @param text A string representing the text to generate an embedding for.
  * @param embedding_size A pointer to a size_t type that will be set by the call indicating the length
diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index e8895a9c..91395f53 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -251,6 +251,8 @@ class LLModel:
         self,
         text: str
     ) -> list[float]:
+        if not text:
+            raise ValueError("Text must not be None or empty")
         embedding_size = ctypes.c_size_t()
         c_text = ctypes.c_char_p(text.encode('utf-8'))
         embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
index 6fdaa6cc..fa798c0c 100644
--- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
@@ -3,6 +3,7 @@ from io import StringIO
 
 from gpt4all import GPT4All, Embed4All
 import time
+import pytest
 
 def test_inference():
     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
@@ -107,3 +108,9 @@ def test_embedding():
     #for i, value in enumerate(output):
     #print(f'Value at index {i}: {value}')
     assert len(output) == 384
+
+def test_empty_embedding():
+    text = ''
+    embedder = Embed4All()
+    with pytest.raises(ValueError):
+        output = embedder.embed(text)
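
For reference, a minimal sketch of how this change surfaces through the Python bindings, mirroring the added test: empty text now raises a ValueError before the call ever reaches the C API, while non-empty text embeds as before. The sample sentence below is arbitrary, and the 384 dimension is simply what the existing test asserts for the default embedding model; neither is guaranteed by the API itself.

    from gpt4all import Embed4All

    embedder = Embed4All()

    # Empty text fails fast in the Python binding with a ValueError,
    # instead of being passed through to llmodel_embedding().
    try:
        embedder.embed('')
    except ValueError as err:
        print(f'Refused to embed empty text: {err}')

    # Non-empty text behaves as before and returns a list of floats.
    vector = embedder.embed('The quick brown fox jumps over the lazy dog')
    print(len(vector))  # the existing test expects 384 for the default embedding model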