upstage: Support batch input in embedding request. (#21730)

**Description:** upstage embedding now supports batch input.
This commit is contained in:
JuHyung Son 2024-05-16 03:13:44 +02:00 committed by GitHub
parent c5a981e3b4
commit 38c297a025
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 23 deletions

View File

@ -31,6 +31,9 @@ from langchain_core.utils import (
logger = logging.getLogger(__name__)
DEFAULT_EMBED_BATCH_SIZE = 10
MAX_EMBED_BATCH_SIZE = 100
class UpstageEmbeddings(BaseModel, Embeddings):
"""UpstageEmbeddings embedding model.
@ -48,9 +51,9 @@ class UpstageEmbeddings(BaseModel, Embeddings):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "solar-1-mini-embedding"
model: str = Field(...)
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
Instead, use 'solar-1-mini-embedding' for example.
Instead, use 'solar-embedding-1-large' for example.
"""
dimensions: Optional[int] = None
"""The number of dimensions the resulting output embeddings should have.
@ -68,6 +71,7 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Not yet supported.
"""
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
allowed_special: Union[Literal["all"], Set[str]] = set()
"""Not yet supported."""
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
@ -193,16 +197,19 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = []
assert (
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
params = self._invocation_params
params["model"] = params["model"] + "-passage"
embeddings = []
for text in texts:
response = self.client.create(input=text, **params)
batch_size = min(self.embed_batch_size, len(texts))
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
data = self.client.create(input=batch, **params).data
embeddings.extend([r.embedding for r in data])
if not isinstance(response, dict):
response = response.model_dump()
embeddings.extend([i["embedding"] for i in response["data"]])
return embeddings
def embed_query(self, text: str) -> List[float]:
@ -232,16 +239,18 @@ class UpstageEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
embeddings = []
assert (
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
params = self._invocation_params
params["model"] = params["model"] + "-passage"
embeddings = []
for text in texts:
response = await self.async_client.create(input=text, **params)
if not isinstance(response, dict):
response = response.model_dump()
embeddings.extend([i["embedding"] for i in response["data"]])
batch_size = min(self.embed_batch_size, len(texts))
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
response = await self.async_client.create(input=batch, **params)
embeddings.extend([r.embedding for r in response.data])
return embeddings
async def aembed_query(self, text: str) -> List[float]:

View File

@ -6,7 +6,7 @@ from langchain_upstage import UpstageEmbeddings
def test_langchain_upstage_embed_documents() -> None:
    """Embed a two-document batch and check one embedding comes back per doc.

    Integration test: performs a real API call via ``embed_documents``.
    """
    documents = ["foo bar", "bar foo"]
    # `model` is a required field (omitting it raises a validation error).
    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) > 0
@ -15,7 +15,7 @@ def test_langchain_upstage_embed_documents() -> None:
def test_langchain_upstage_embed_query() -> None:
    """Embed a single query string and check a non-empty vector is returned.

    Integration test: performs a real API call via ``embed_query``.
    """
    query = "foo bar"
    # `model` is a required field (omitting it raises a validation error).
    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
    output = embedding.embed_query(query)
    assert len(output) > 0
@ -23,7 +23,7 @@ def test_langchain_upstage_embed_query() -> None:
async def test_langchain_upstage_aembed_documents() -> None:
    """Async variant: embed a two-document batch via ``aembed_documents``.

    Integration test: performs a real API call.
    """
    documents = ["foo bar", "bar foo"]
    # `model` is a required field (omitting it raises a validation error).
    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
    output = await embedding.aembed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) > 0
@ -32,6 +32,6 @@ async def test_langchain_upstage_aembed_documents() -> None:
async def test_langchain_upstage_aembed_query() -> None:
    """Async variant: embed a single query via ``aembed_query``.

    Integration test: performs a real API call.
    """
    query = "foo bar"
    # `model` is a required field (omitting it raises a validation error).
    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
    output = await embedding.aembed_query(query)
    assert len(output) > 0

View File

@ -11,15 +11,22 @@ os.environ["UPSTAGE_API_KEY"] = "foo"
def test_initialization() -> None:
    """Constructing the model with a valid `model` name must not raise."""
    # `model` has no default, so it must be passed explicitly.
    UpstageEmbeddings(model="solar-embedding-1-large")
def test_upstage_invalid_model_kwargs() -> None:
    """Passing `model` inside `model_kwargs` must be rejected.

    The model name is a dedicated field, so smuggling it through
    `model_kwargs` is a validation error.
    """
    with pytest.raises(ValueError):
        UpstageEmbeddings(
            model="solar-embedding-1-large", model_kwargs={"model": "foo"}
        )
def test_upstage_invalid_model() -> None:
    """Omitting the required `model` field raises a validation ValueError."""
    # Callable form of pytest.raises: invoke the constructor with no args.
    pytest.raises(ValueError, UpstageEmbeddings)
def test_upstage_incorrect_field() -> None:
    """Unknown constructor fields warn and are routed into `model_kwargs`."""
    with pytest.warns(match="not default parameter"):
        llm = UpstageEmbeddings(model="solar-embedding-1-large", foo="bar")
    assert llm.model_kwargs == {"foo": "bar"}

View File

@ -8,6 +8,6 @@ def test_chat_upstage_secrets() -> None:
def test_upstage_embeddings_secrets() -> None:
    """The API key must not leak into the model's string representation."""
    o = UpstageEmbeddings(model="solar-embedding-1-large", upstage_api_key="foo")
    s = str(o)
    assert "foo" not in s