Dev2049/hf emb encode kwargs (#3925)

Thanks @amogkam for the addition! Refactored slightly

---------

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Davis Chase
2023-05-01 20:27:41 -07:00
committed by GitHub
parent ffc87233a1
commit 5db6b796cf
2 changed files with 7 additions and 8 deletions

View File

@@ -1,5 +1,4 @@
"""Test huggingface embeddings."""
import unittest
from langchain.embeddings.huggingface import (
HuggingFaceEmbeddings,
@@ -7,7 +6,6 @@ from langchain.embeddings.huggingface import (
)
@unittest.skip("This test causes a segfault.")
def test_huggingface_embedding_documents() -> None:
"""Test huggingface embeddings."""
documents = ["foo bar"]
@@ -17,11 +15,10 @@ def test_huggingface_embedding_documents() -> None:
assert len(output[0]) == 768
@unittest.skip("This test causes a segfault.")
def test_huggingface_embedding_query() -> None:
"""Test huggingface embeddings."""
document = "foo bar"
embedding = HuggingFaceEmbeddings()
embedding = HuggingFaceEmbeddings(encode_kwargs={"batch_size": 16})
output = embedding.embed_query(document)
assert len(output) == 768