From 38c297a0256d35bc64ea8c652786daa0e34b860d Mon Sep 17 00:00:00 2001 From: JuHyung Son Date: Thu, 16 May 2024 03:13:44 +0200 Subject: [PATCH] upstage: Support batch input in embedding request. (#21730) **Description:** upstage embedding now supports batch input. --- .../upstage/langchain_upstage/embeddings.py | 39 ++++++++++++------- .../integration_tests/test_embeddings.py | 8 ++-- .../tests/unit_tests/test_embeddings.py | 13 +++++-- .../upstage/tests/unit_tests/test_secrets.py | 2 +- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/libs/partners/upstage/langchain_upstage/embeddings.py b/libs/partners/upstage/langchain_upstage/embeddings.py index 31708239be2..08976c608f7 100644 --- a/libs/partners/upstage/langchain_upstage/embeddings.py +++ b/libs/partners/upstage/langchain_upstage/embeddings.py @@ -31,6 +31,9 @@ from langchain_core.utils import ( logger = logging.getLogger(__name__) +DEFAULT_EMBED_BATCH_SIZE = 10 +MAX_EMBED_BATCH_SIZE = 100 + class UpstageEmbeddings(BaseModel, Embeddings): """UpstageEmbeddings embedding model. @@ -48,9 +51,9 @@ class UpstageEmbeddings(BaseModel, Embeddings): client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private: - model: str = "solar-1-mini-embedding" + model: str = Field(...) """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`. - Instead, use 'solar-1-mini-embedding' for example. + Instead, use 'solar-embedding-1-large' for example. """ dimensions: Optional[int] = None """The number of dimensions the resulting output embeddings should have. @@ -68,6 +71,7 @@ class UpstageEmbeddings(BaseModel, Embeddings): Not yet supported. """ + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE allowed_special: Union[Literal["all"], Set[str]] = set() """Not yet supported.""" disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" @@ -193,16 +197,19 @@ class UpstageEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - embeddings = [] + assert ( + self.embed_batch_size <= MAX_EMBED_BATCH_SIZE + ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}." params = self._invocation_params params["model"] = params["model"] + "-passage" + embeddings = [] - for text in texts: - response = self.client.create(input=text, **params) + batch_size = min(self.embed_batch_size, len(texts)) + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + data = self.client.create(input=batch, **params).data + embeddings.extend([r.embedding for r in data]) - if not isinstance(response, dict): - response = response.model_dump() - embeddings.extend([i["embedding"] for i in response["data"]]) return embeddings def embed_query(self, text: str) -> List[float]: @@ -232,16 +239,18 @@ class UpstageEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - embeddings = [] + assert ( + self.embed_batch_size <= MAX_EMBED_BATCH_SIZE + ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}." params = self._invocation_params params["model"] = params["model"] + "-passage" + embeddings = [] - for text in texts: - response = await self.async_client.create(input=text, **params) - - if not isinstance(response, dict): - response = response.model_dump() - embeddings.extend([i["embedding"] for i in response["data"]]) + batch_size = min(self.embed_batch_size, len(texts)) + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + response = await self.async_client.create(input=batch, **params) + embeddings.extend([r.embedding for r in response.data]) return embeddings async def aembed_query(self, text: str) -> List[float]: diff --git a/libs/partners/upstage/tests/integration_tests/test_embeddings.py b/libs/partners/upstage/tests/integration_tests/test_embeddings.py index 81730275dfd..bd056d2d40b 100644 --- a/libs/partners/upstage/tests/integration_tests/test_embeddings.py +++ b/libs/partners/upstage/tests/integration_tests/test_embeddings.py @@ -6,7 +6,7 @@ from langchain_upstage import UpstageEmbeddings def test_langchain_upstage_embed_documents() -> None: """Test Upstage embeddings.""" documents = ["foo bar", "bar foo"] - embedding = UpstageEmbeddings() + embedding = UpstageEmbeddings(model="solar-embedding-1-large") output = embedding.embed_documents(documents) assert len(output) == 2 assert len(output[0]) > 0 @@ -15,7 +15,7 @@ def test_langchain_upstage_embed_documents() -> None: def test_langchain_upstage_embed_query() -> None: """Test Upstage embeddings.""" query = "foo bar" - embedding = UpstageEmbeddings() + embedding = UpstageEmbeddings(model="solar-embedding-1-large") output = embedding.embed_query(query) assert len(output) > 0 @@ -23,7 +23,7 @@ def test_langchain_upstage_embed_query() -> None: async def test_langchain_upstage_aembed_documents() -> None: """Test Upstage embeddings asynchronous.""" documents = ["foo bar", "bar foo"] - embedding = UpstageEmbeddings() + embedding = UpstageEmbeddings(model="solar-embedding-1-large") output = await embedding.aembed_documents(documents) assert len(output) == 2 assert len(output[0]) > 0 @@ -32,6 +32,6 @@ async def test_langchain_upstage_aembed_documents() -> None: async def test_langchain_upstage_aembed_query() -> None: """Test Upstage embeddings asynchronous.""" query = "foo bar" - embedding = UpstageEmbeddings() + embedding = UpstageEmbeddings(model="solar-embedding-1-large") output = await embedding.aembed_query(query) assert len(output) > 0 diff --git a/libs/partners/upstage/tests/unit_tests/test_embeddings.py b/libs/partners/upstage/tests/unit_tests/test_embeddings.py index 08627bf5f59..8a838a6c3b4 100644 --- a/libs/partners/upstage/tests/unit_tests/test_embeddings.py +++ b/libs/partners/upstage/tests/unit_tests/test_embeddings.py @@ -11,15 +11,22 @@ os.environ["UPSTAGE_API_KEY"] = "foo" def test_initialization() -> None: """Test embedding model initialization.""" - UpstageEmbeddings() + UpstageEmbeddings(model="solar-embedding-1-large") def test_upstage_invalid_model_kwargs() -> None: with pytest.raises(ValueError): - UpstageEmbeddings(model_kwargs={"model": "foo"}) + UpstageEmbeddings( + model="solar-embedding-1-large", model_kwargs={"model": "foo"} + ) + + +def test_upstage_invalid_model() -> None: + with pytest.raises(ValueError): + UpstageEmbeddings() def test_upstage_incorrect_field() -> None: with pytest.warns(match="not default parameter"): - llm = UpstageEmbeddings(foo="bar") + llm = UpstageEmbeddings(model="solar-embedding-1-large", foo="bar") assert llm.model_kwargs == {"foo": "bar"} diff --git a/libs/partners/upstage/tests/unit_tests/test_secrets.py b/libs/partners/upstage/tests/unit_tests/test_secrets.py index 23e72cb86c5..5a023cfeb9e 100644 --- a/libs/partners/upstage/tests/unit_tests/test_secrets.py +++ b/libs/partners/upstage/tests/unit_tests/test_secrets.py @@ -8,6 +8,6 @@ def test_chat_upstage_secrets() -> None: def test_upstage_embeddings_secrets() -> None: - o = UpstageEmbeddings(upstage_api_key="foo") + o = UpstageEmbeddings(model="solar-embedding-1-large", upstage_api_key="foo") s = str(o) assert "foo" not in s