mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
Improvement[Embeddings] Add dimension support to ZhipuAIEmbeddings
(#25274)
- In the in ` embedding-3 ` and later models of Zhipu AI, it is supported to specify the dimensions parameter of Embedding. Ref: https://bigmodel.cn/dev/api#text_embedding-3 . - Add test case for `embedding-3` model by assigning dimensions.
This commit is contained in:
parent
9cd608efb3
commit
43deed2a95
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
@ -70,6 +70,11 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
|||||||
"""Model name"""
|
"""Model name"""
|
||||||
api_key: str
|
api_key: str
|
||||||
"""Automatically inferred from env var `ZHIPU_API_KEY` if not provided."""
|
"""Automatically inferred from env var `ZHIPU_API_KEY` if not provided."""
|
||||||
|
dimensions: Optional[int] = None
|
||||||
|
"""The number of dimensions the resulting output embeddings should have.
|
||||||
|
|
||||||
|
Only supported in `embedding-3` and later models.
|
||||||
|
"""
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -110,6 +115,13 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
|||||||
A list of embeddings for each document in the input list.
|
A list of embeddings for each document in the input list.
|
||||||
Each embedding is represented as a list of float values.
|
Each embedding is represented as a list of float values.
|
||||||
"""
|
"""
|
||||||
resp = self.client.embeddings.create(model=self.model, input=texts)
|
if self.dimensions is not None:
|
||||||
|
resp = self.client.embeddings.create(
|
||||||
|
model=self.model,
|
||||||
|
input=texts,
|
||||||
|
dimensions=self.dimensions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
resp = self.client.embeddings.create(model=self.model, input=texts)
|
||||||
embeddings = [r.embedding for r in resp.data]
|
embeddings = [r.embedding for r in resp.data]
|
||||||
return embeddings
|
return embeddings
|
||||||
|
@ -18,3 +18,14 @@ def test_zhipuai_embedding_query() -> None:
|
|||||||
embedding = ZhipuAIEmbeddings() # type: ignore[call-arg]
|
embedding = ZhipuAIEmbeddings() # type: ignore[call-arg]
|
||||||
res = embedding.embed_query(document)
|
res = embedding.embed_query(document)
|
||||||
assert len(res) == 1024 # type: ignore[arg-type]
|
assert len(res) == 1024 # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_zhipuai_embedding_dimensions() -> None:
|
||||||
|
"""Test ZhipuAI Text Embedding for query by assigning dimensions"""
|
||||||
|
document = "This is a test query."
|
||||||
|
embedding = ZhipuAIEmbeddings(
|
||||||
|
model="embedding-3",
|
||||||
|
dimensions=2048,
|
||||||
|
) # type: ignore[call-arg]
|
||||||
|
res = embedding.embed_query(document)
|
||||||
|
assert len(res) == 2048 # type: ignore[arg-type]
|
||||||
|
Loading…
Reference in New Issue
Block a user