Improvement[Embeddings] Add dimension support to ZhipuAIEmbeddings (#25274)

- In the in ` embedding-3 ` and later models of Zhipu AI, it is
supported to specify the dimensions parameter of Embedding. Ref:
https://bigmodel.cn/dev/api#text_embedding-3 .
- Add test case for `embedding-3` model by assigning dimensions.
This commit is contained in:
ZhangShenao 2024-08-12 04:20:37 +08:00 committed by GitHub
parent 9cd608efb3
commit 43deed2a95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@ -70,6 +70,11 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""Model name"""
api_key: str
"""Automatically inferred from env var `ZHIPU_API_KEY` if not provided."""
dimensions: Optional[int] = None
"""The number of dimensions the resulting output embeddings should have.
Only supported in `embedding-3` and later models.
"""
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
@ -110,6 +115,13 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
A list of embeddings for each document in the input list.
Each embedding is represented as a list of float values.
"""
if self.dimensions is not None:
resp = self.client.embeddings.create(
model=self.model,
input=texts,
dimensions=self.dimensions,
)
else:
resp = self.client.embeddings.create(model=self.model, input=texts)
embeddings = [r.embedding for r in resp.data]
return embeddings

View File

@ -18,3 +18,14 @@ def test_zhipuai_embedding_query() -> None:
embedding = ZhipuAIEmbeddings() # type: ignore[call-arg]
res = embedding.embed_query(document)
assert len(res) == 1024 # type: ignore[arg-type]
def test_zhipuai_embedding_dimensions() -> None:
"""Test ZhipuAI Text Embedding for query by assigning dimensions"""
document = "This is a test query."
embedding = ZhipuAIEmbeddings(
model="embedding-3",
dimensions=2048,
) # type: ignore[call-arg]
res = embedding.embed_query(document)
assert len(res) == 2048 # type: ignore[arg-type]