mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
upstage: Support batch input in embedding request. (#21730)
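**Description:** Upstage embeddings now support batch input. `embed_documents` and `aembed_documents` send texts in batches of `embed_batch_size` (default 10, at most 100) instead of one request per text, and `model` is now a required field with no default value.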
This commit is contained in:
parent
c5a981e3b4
commit
38c297a025
@ -31,6 +31,9 @@ from langchain_core.utils import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_EMBED_BATCH_SIZE = 10
|
||||||
|
MAX_EMBED_BATCH_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
class UpstageEmbeddings(BaseModel, Embeddings):
|
class UpstageEmbeddings(BaseModel, Embeddings):
|
||||||
"""UpstageEmbeddings embedding model.
|
"""UpstageEmbeddings embedding model.
|
||||||
@ -48,9 +51,9 @@ class UpstageEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
model: str = "solar-1-mini-embedding"
|
model: str = Field(...)
|
||||||
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
|
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
|
||||||
Instead, use 'solar-1-mini-embedding' for example.
|
Instead, use 'solar-embedding-1-large' for example.
|
||||||
"""
|
"""
|
||||||
dimensions: Optional[int] = None
|
dimensions: Optional[int] = None
|
||||||
"""The number of dimensions the resulting output embeddings should have.
|
"""The number of dimensions the resulting output embeddings should have.
|
||||||
@ -68,6 +71,7 @@ class UpstageEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
Not yet supported.
|
Not yet supported.
|
||||||
"""
|
"""
|
||||||
|
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
|
||||||
allowed_special: Union[Literal["all"], Set[str]] = set()
|
allowed_special: Union[Literal["all"], Set[str]] = set()
|
||||||
"""Not yet supported."""
|
"""Not yet supported."""
|
||||||
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
|
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
|
||||||
@ -193,16 +197,19 @@ class UpstageEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
embeddings = []
|
assert (
|
||||||
|
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
|
||||||
|
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
|
||||||
params = self._invocation_params
|
params = self._invocation_params
|
||||||
params["model"] = params["model"] + "-passage"
|
params["model"] = params["model"] + "-passage"
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
for text in texts:
|
batch_size = min(self.embed_batch_size, len(texts))
|
||||||
response = self.client.create(input=text, **params)
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
|
data = self.client.create(input=batch, **params).data
|
||||||
|
embeddings.extend([r.embedding for r in data])
|
||||||
|
|
||||||
if not isinstance(response, dict):
|
|
||||||
response = response.model_dump()
|
|
||||||
embeddings.extend([i["embedding"] for i in response["data"]])
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
@ -232,16 +239,18 @@ class UpstageEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
embeddings = []
|
assert (
|
||||||
|
self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
|
||||||
|
), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
|
||||||
params = self._invocation_params
|
params = self._invocation_params
|
||||||
params["model"] = params["model"] + "-passage"
|
params["model"] = params["model"] + "-passage"
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
for text in texts:
|
batch_size = min(self.embed_batch_size, len(texts))
|
||||||
response = await self.async_client.create(input=text, **params)
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
if not isinstance(response, dict):
|
response = await self.async_client.create(input=batch, **params)
|
||||||
response = response.model_dump()
|
embeddings.extend([r.embedding for r in response.data])
|
||||||
embeddings.extend([i["embedding"] for i in response["data"]])
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
|
@ -6,7 +6,7 @@ from langchain_upstage import UpstageEmbeddings
|
|||||||
def test_langchain_upstage_embed_documents() -> None:
|
def test_langchain_upstage_embed_documents() -> None:
|
||||||
"""Test Upstage embeddings."""
|
"""Test Upstage embeddings."""
|
||||||
documents = ["foo bar", "bar foo"]
|
documents = ["foo bar", "bar foo"]
|
||||||
embedding = UpstageEmbeddings()
|
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||||
output = embedding.embed_documents(documents)
|
output = embedding.embed_documents(documents)
|
||||||
assert len(output) == 2
|
assert len(output) == 2
|
||||||
assert len(output[0]) > 0
|
assert len(output[0]) > 0
|
||||||
@ -15,7 +15,7 @@ def test_langchain_upstage_embed_documents() -> None:
|
|||||||
def test_langchain_upstage_embed_query() -> None:
|
def test_langchain_upstage_embed_query() -> None:
|
||||||
"""Test Upstage embeddings."""
|
"""Test Upstage embeddings."""
|
||||||
query = "foo bar"
|
query = "foo bar"
|
||||||
embedding = UpstageEmbeddings()
|
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||||
output = embedding.embed_query(query)
|
output = embedding.embed_query(query)
|
||||||
assert len(output) > 0
|
assert len(output) > 0
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ def test_langchain_upstage_embed_query() -> None:
|
|||||||
async def test_langchain_upstage_aembed_documents() -> None:
|
async def test_langchain_upstage_aembed_documents() -> None:
|
||||||
"""Test Upstage embeddings asynchronous."""
|
"""Test Upstage embeddings asynchronous."""
|
||||||
documents = ["foo bar", "bar foo"]
|
documents = ["foo bar", "bar foo"]
|
||||||
embedding = UpstageEmbeddings()
|
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||||
output = await embedding.aembed_documents(documents)
|
output = await embedding.aembed_documents(documents)
|
||||||
assert len(output) == 2
|
assert len(output) == 2
|
||||||
assert len(output[0]) > 0
|
assert len(output[0]) > 0
|
||||||
@ -32,6 +32,6 @@ async def test_langchain_upstage_aembed_documents() -> None:
|
|||||||
async def test_langchain_upstage_aembed_query() -> None:
|
async def test_langchain_upstage_aembed_query() -> None:
|
||||||
"""Test Upstage embeddings asynchronous."""
|
"""Test Upstage embeddings asynchronous."""
|
||||||
query = "foo bar"
|
query = "foo bar"
|
||||||
embedding = UpstageEmbeddings()
|
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||||
output = await embedding.aembed_query(query)
|
output = await embedding.aembed_query(query)
|
||||||
assert len(output) > 0
|
assert len(output) > 0
|
||||||
|
@ -11,15 +11,22 @@ os.environ["UPSTAGE_API_KEY"] = "foo"
|
|||||||
|
|
||||||
def test_initialization() -> None:
|
def test_initialization() -> None:
|
||||||
"""Test embedding model initialization."""
|
"""Test embedding model initialization."""
|
||||||
UpstageEmbeddings()
|
UpstageEmbeddings(model="solar-embedding-1-large")
|
||||||
|
|
||||||
|
|
||||||
def test_upstage_invalid_model_kwargs() -> None:
|
def test_upstage_invalid_model_kwargs() -> None:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
UpstageEmbeddings(model_kwargs={"model": "foo"})
|
UpstageEmbeddings(
|
||||||
|
model="solar-embedding-1-large", model_kwargs={"model": "foo"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upstage_invalid_model() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
UpstageEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
def test_upstage_incorrect_field() -> None:
|
def test_upstage_incorrect_field() -> None:
|
||||||
with pytest.warns(match="not default parameter"):
|
with pytest.warns(match="not default parameter"):
|
||||||
llm = UpstageEmbeddings(foo="bar")
|
llm = UpstageEmbeddings(model="solar-embedding-1-large", foo="bar")
|
||||||
assert llm.model_kwargs == {"foo": "bar"}
|
assert llm.model_kwargs == {"foo": "bar"}
|
||||||
|
@ -8,6 +8,6 @@ def test_chat_upstage_secrets() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_upstage_embeddings_secrets() -> None:
|
def test_upstage_embeddings_secrets() -> None:
|
||||||
o = UpstageEmbeddings(upstage_api_key="foo")
|
o = UpstageEmbeddings(model="solar-embedding-1-large", upstage_api_key="foo")
|
||||||
s = str(o)
|
s = str(o)
|
||||||
assert "foo" not in s
|
assert "foo" not in s
|
||||||
|
Loading…
Reference in New Issue
Block a user