upstage: Support batch input in embedding request. (#21730)

**Description:** upstage embedding now supports batch input.
This commit is contained in:
JuHyung Son 2024-05-16 03:13:44 +02:00 committed by GitHub
parent c5a981e3b4
commit 38c297a025
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 23 deletions

View File

@@ -31,6 +31,9 @@ from langchain_core.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_EMBED_BATCH_SIZE = 10
MAX_EMBED_BATCH_SIZE = 100
class UpstageEmbeddings(BaseModel, Embeddings): class UpstageEmbeddings(BaseModel, Embeddings):
"""UpstageEmbeddings embedding model. """UpstageEmbeddings embedding model.
@@ -48,9 +51,9 @@ class UpstageEmbeddings(BaseModel, Embeddings):
client: Any = Field(default=None, exclude=True) #: :meta private: client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "solar-1-mini-embedding" model: str = Field(...)
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`. """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
Instead, use 'solar-1-mini-embedding' for example. Instead, use 'solar-embedding-1-large' for example.
""" """
dimensions: Optional[int] = None dimensions: Optional[int] = None
"""The number of dimensions the resulting output embeddings should have. """The number of dimensions the resulting output embeddings should have.
@@ -68,6 +71,7 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Not yet supported. Not yet supported.
""" """
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
allowed_special: Union[Literal["all"], Set[str]] = set() allowed_special: Union[Literal["all"], Set[str]] = set()
"""Not yet supported.""" """Not yet supported."""
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
@@ -193,16 +197,19 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
embeddings = [] assert (
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
params = self._invocation_params params = self._invocation_params
params["model"] = params["model"] + "-passage" params["model"] = params["model"] + "-passage"
embeddings = []
for text in texts: batch_size = min(self.embed_batch_size, len(texts))
response = self.client.create(input=text, **params) for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
data = self.client.create(input=batch, **params).data
embeddings.extend([r.embedding for r in data])
if not isinstance(response, dict):
response = response.model_dump()
embeddings.extend([i["embedding"] for i in response["data"]])
return embeddings return embeddings
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
@@ -232,16 +239,18 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Returns: Returns:
List of embeddings, one for each text. List of embeddings, one for each text.
""" """
embeddings = [] assert (
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
params = self._invocation_params params = self._invocation_params
params["model"] = params["model"] + "-passage" params["model"] = params["model"] + "-passage"
embeddings = []
for text in texts: batch_size = min(self.embed_batch_size, len(texts))
response = await self.async_client.create(input=text, **params) for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
if not isinstance(response, dict): response = await self.async_client.create(input=batch, **params)
response = response.model_dump() embeddings.extend([r.embedding for r in response.data])
embeddings.extend([i["embedding"] for i in response["data"]])
return embeddings return embeddings
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> List[float]:

View File

@@ -6,7 +6,7 @@ from langchain_upstage import UpstageEmbeddings
def test_langchain_upstage_embed_documents() -> None: def test_langchain_upstage_embed_documents() -> None:
"""Test Upstage embeddings.""" """Test Upstage embeddings."""
documents = ["foo bar", "bar foo"] documents = ["foo bar", "bar foo"]
embedding = UpstageEmbeddings() embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = embedding.embed_documents(documents) output = embedding.embed_documents(documents)
assert len(output) == 2 assert len(output) == 2
assert len(output[0]) > 0 assert len(output[0]) > 0
@@ -15,7 +15,7 @@ def test_langchain_upstage_embed_documents() -> None:
def test_langchain_upstage_embed_query() -> None: def test_langchain_upstage_embed_query() -> None:
"""Test Upstage embeddings.""" """Test Upstage embeddings."""
query = "foo bar" query = "foo bar"
embedding = UpstageEmbeddings() embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = embedding.embed_query(query) output = embedding.embed_query(query)
assert len(output) > 0 assert len(output) > 0
@@ -23,7 +23,7 @@ def test_langchain_upstage_embed_query() -> None:
async def test_langchain_upstage_aembed_documents() -> None: async def test_langchain_upstage_aembed_documents() -> None:
"""Test Upstage embeddings asynchronous.""" """Test Upstage embeddings asynchronous."""
documents = ["foo bar", "bar foo"] documents = ["foo bar", "bar foo"]
embedding = UpstageEmbeddings() embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = await embedding.aembed_documents(documents) output = await embedding.aembed_documents(documents)
assert len(output) == 2 assert len(output) == 2
assert len(output[0]) > 0 assert len(output[0]) > 0
@@ -32,6 +32,6 @@ async def test_langchain_upstage_aembed_documents() -> None:
async def test_langchain_upstage_aembed_query() -> None: async def test_langchain_upstage_aembed_query() -> None:
"""Test Upstage embeddings asynchronous.""" """Test Upstage embeddings asynchronous."""
query = "foo bar" query = "foo bar"
embedding = UpstageEmbeddings() embedding = UpstageEmbeddings(model="solar-embedding-1-large")
output = await embedding.aembed_query(query) output = await embedding.aembed_query(query)
assert len(output) > 0 assert len(output) > 0

View File

@@ -11,15 +11,22 @@ os.environ["UPSTAGE_API_KEY"] = "foo"
def test_initialization() -> None: def test_initialization() -> None:
"""Test embedding model initialization.""" """Test embedding model initialization."""
UpstageEmbeddings() UpstageEmbeddings(model="solar-embedding-1-large")
def test_upstage_invalid_model_kwargs() -> None: def test_upstage_invalid_model_kwargs() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
UpstageEmbeddings(model_kwargs={"model": "foo"}) UpstageEmbeddings(
model="solar-embedding-1-large", model_kwargs={"model": "foo"}
)
def test_upstage_invalid_model() -> None:
with pytest.raises(ValueError):
UpstageEmbeddings()
def test_upstage_incorrect_field() -> None: def test_upstage_incorrect_field() -> None:
with pytest.warns(match="not default parameter"): with pytest.warns(match="not default parameter"):
llm = UpstageEmbeddings(foo="bar") llm = UpstageEmbeddings(model="solar-embedding-1-large", foo="bar")
assert llm.model_kwargs == {"foo": "bar"} assert llm.model_kwargs == {"foo": "bar"}

View File

@@ -8,6 +8,6 @@ def test_chat_upstage_secrets() -> None:
def test_upstage_embeddings_secrets() -> None: def test_upstage_embeddings_secrets() -> None:
o = UpstageEmbeddings(upstage_api_key="foo") o = UpstageEmbeddings(model="solar-embedding-1-large", upstage_api_key="foo")
s = str(o) s = str(o)
assert "foo" not in s assert "foo" not in s