mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Harrison/batch embeds (#972)
Co-authored-by: John Dagdelen <jdagdelen@users.noreply.github.com> Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>
This commit is contained in:
parent
ba54d36787
commit
91c6cea227
@ -75,20 +75,27 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]
|
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(
|
||||||
|
self, texts: List[str], chunk_size: int = 1000
|
||||||
|
) -> List[List[float]]:
|
||||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: The list of texts to embed.
|
texts: The list of texts to embed.
|
||||||
|
chunk_size: The maximum number of texts to send to OpenAI at once
|
||||||
|
(max 1000).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
responses = [
|
# handle large batches of texts
|
||||||
self._embedding_func(text, engine=self.document_model_name)
|
results = []
|
||||||
for text in texts
|
for i in range(0, len(texts), chunk_size):
|
||||||
]
|
response = self.client.create(
|
||||||
return responses
|
input=texts[i : i + chunk_size], engine=self.document_model_name
|
||||||
|
)
|
||||||
|
results += [r["embedding"] for r in response["data"]]
|
||||||
|
return results
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||||
|
@ -8,7 +8,18 @@ def test_openai_embedding_documents() -> None:
|
|||||||
embedding = OpenAIEmbeddings()
|
embedding = OpenAIEmbeddings()
|
||||||
output = embedding.embed_documents(documents)
|
output = embedding.embed_documents(documents)
|
||||||
assert len(output) == 1
|
assert len(output) == 1
|
||||||
assert len(output[0]) == 2048
|
assert len(output[0]) == 1536
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_embedding_documents_multiple() -> None:
|
||||||
|
"""Test openai embeddings."""
|
||||||
|
documents = ["foo bar", "bar foo", "foo"]
|
||||||
|
embedding = OpenAIEmbeddings()
|
||||||
|
output = embedding.embed_documents(documents, chunk_size=2)
|
||||||
|
assert len(output) == 3
|
||||||
|
assert len(output[0]) == 1536
|
||||||
|
assert len(output[1]) == 1536
|
||||||
|
assert len(output[2]) == 1536
|
||||||
|
|
||||||
|
|
||||||
def test_openai_embedding_query() -> None:
|
def test_openai_embedding_query() -> None:
|
||||||
@ -16,4 +27,4 @@ def test_openai_embedding_query() -> None:
|
|||||||
document = "foo bar"
|
document = "foo bar"
|
||||||
embedding = OpenAIEmbeddings()
|
embedding = OpenAIEmbeddings()
|
||||||
output = embedding.embed_query(document)
|
output = embedding.embed_query(document)
|
||||||
assert len(output) == 2048
|
assert len(output) == 1536
|
||||||
|
Loading…
Reference in New Issue
Block a user