Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 23:54:14 +00:00)
upstage: Support batch input in embedding request. (#21730)
**Description:** Upstage embeddings now support batch input, so multiple texts can be embedded in a single API request.
parent: c5a981e3b4
commit: 38c297a025
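A minimal usage sketch of the behavior this change enables (not part of the commit; the input texts and API-key handling are placeholders, while the model name, `embed_batch_size`, and the 100-item cap come from the diff below):

```python
# Hedged usage sketch for the batched embedding request added in this PR.
import os
from langchain_upstage import UpstageEmbeddings

os.environ["UPSTAGE_API_KEY"] = "..."  # placeholder; a real key is required

embeddings = UpstageEmbeddings(
    model="solar-embedding-1-large",  # model is now a required field
    embed_batch_size=10,              # new parameter; must not exceed 100
)

texts = ["hello world", "goodbye world"]
# Texts are sent to the API in batches of up to embed_batch_size
# instead of one request per text.
vectors = embeddings.embed_documents(texts)
assert len(vectors) == len(texts)
```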
@@ -31,6 +31,9 @@ from langchain_core.utils import (
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_EMBED_BATCH_SIZE = 10
+MAX_EMBED_BATCH_SIZE = 100
+
 
 class UpstageEmbeddings(BaseModel, Embeddings):
     """UpstageEmbeddings embedding model.
@@ -48,9 +51,9 @@ class UpstageEmbeddings(BaseModel, Embeddings):
 
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    model: str = "solar-1-mini-embedding"
+    model: str = Field(...)
     """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
-    Instead, use 'solar-1-mini-embedding' for example.
+    Instead, use 'solar-embedding-1-large' for example.
     """
     dimensions: Optional[int] = None
     """The number of dimensions the resulting output embeddings should have.
@@ -68,6 +71,7 @@ class UpstageEmbeddings(BaseModel, Embeddings):
 
     Not yet supported.
     """
+    embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
     allowed_special: Union[Literal["all"], Set[str]] = set()
     """Not yet supported."""
     disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
@@ -193,16 +197,19 @@ class UpstageEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = []
+        assert (
+            self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
+        ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
         params = self._invocation_params
         params["model"] = params["model"] + "-passage"
+        embeddings = []
 
-        for text in texts:
-            response = self.client.create(input=text, **params)
-
-            if not isinstance(response, dict):
-                response = response.model_dump()
-            embeddings.extend([i["embedding"] for i in response["data"]])
+        batch_size = min(self.embed_batch_size, len(texts))
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i : i + batch_size]
+            data = self.client.create(input=batch, **params).data
+            embeddings.extend([r.embedding for r in data])
+
         return embeddings
 
     def embed_query(self, text: str) -> List[float]:
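For clarity, a standalone sketch (not part of the diff) of the chunking performed by the new loop above: texts are split into slices of at most `embed_batch_size`, and each slice is sent as one request.

```python
# Illustrative only: how the new loop partitions the input texts.
texts = ["t1", "t2", "t3", "t4", "t5"]
embed_batch_size = 2

batch_size = min(embed_batch_size, len(texts))
batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
print(batches)  # [['t1', 't2'], ['t3', 't4'], ['t5']]
```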
@@ -232,16 +239,18 @@ class UpstageEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = []
+        assert (
+            self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
+        ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
         params = self._invocation_params
         params["model"] = params["model"] + "-passage"
+        embeddings = []
 
-        for text in texts:
-            response = await self.async_client.create(input=text, **params)
-
-            if not isinstance(response, dict):
-                response = response.model_dump()
-            embeddings.extend([i["embedding"] for i in response["data"]])
+        batch_size = min(self.embed_batch_size, len(texts))
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i : i + batch_size]
+            response = await self.async_client.create(input=batch, **params)
+            embeddings.extend([r.embedding for r in response.data])
         return embeddings
 
     async def aembed_query(self, text: str) -> List[float]:
@@ -6,7 +6,7 @@ from langchain_upstage import UpstageEmbeddings
 def test_langchain_upstage_embed_documents() -> None:
     """Test Upstage embeddings."""
     documents = ["foo bar", "bar foo"]
-    embedding = UpstageEmbeddings()
+    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
     output = embedding.embed_documents(documents)
     assert len(output) == 2
     assert len(output[0]) > 0
@@ -15,7 +15,7 @@ def test_langchain_upstage_embed_documents() -> None:
 
 def test_langchain_upstage_embed_query() -> None:
     """Test Upstage embeddings."""
     query = "foo bar"
-    embedding = UpstageEmbeddings()
+    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
     output = embedding.embed_query(query)
     assert len(output) > 0
@@ -23,7 +23,7 @@ def test_langchain_upstage_embed_query() -> None:
 async def test_langchain_upstage_aembed_documents() -> None:
     """Test Upstage embeddings asynchronous."""
     documents = ["foo bar", "bar foo"]
-    embedding = UpstageEmbeddings()
+    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
     output = await embedding.aembed_documents(documents)
     assert len(output) == 2
     assert len(output[0]) > 0
@@ -32,6 +32,6 @@ async def test_langchain_upstage_aembed_documents() -> None:
 async def test_langchain_upstage_aembed_query() -> None:
     """Test Upstage embeddings asynchronous."""
     query = "foo bar"
-    embedding = UpstageEmbeddings()
+    embedding = UpstageEmbeddings(model="solar-embedding-1-large")
     output = await embedding.aembed_query(query)
     assert len(output) > 0
@@ -11,15 +11,22 @@ os.environ["UPSTAGE_API_KEY"] = "foo"
 
 def test_initialization() -> None:
     """Test embedding model initialization."""
-    UpstageEmbeddings()
+    UpstageEmbeddings(model="solar-embedding-1-large")
 
 
 def test_upstage_invalid_model_kwargs() -> None:
     with pytest.raises(ValueError):
-        UpstageEmbeddings(model_kwargs={"model": "foo"})
+        UpstageEmbeddings(
+            model="solar-embedding-1-large", model_kwargs={"model": "foo"}
+        )
+
+
+def test_upstage_invalid_model() -> None:
+    with pytest.raises(ValueError):
+        UpstageEmbeddings()
 
 
 def test_upstage_incorrect_field() -> None:
     with pytest.warns(match="not default parameter"):
-        llm = UpstageEmbeddings(foo="bar")
+        llm = UpstageEmbeddings(model="solar-embedding-1-large", foo="bar")
     assert llm.model_kwargs == {"foo": "bar"}
@@ -8,6 +8,6 @@ def test_chat_upstage_secrets() -> None:
 
 
 def test_upstage_embeddings_secrets() -> None:
-    o = UpstageEmbeddings(upstage_api_key="foo")
+    o = UpstageEmbeddings(model="solar-embedding-1-large", upstage_api_key="foo")
     s = str(o)
     assert "foo" not in s