diff --git a/libs/community/langchain_community/embeddings/zhipuai.py b/libs/community/langchain_community/embeddings/zhipuai.py index a6c3279010c..b8415aed747 100644 --- a/libs/community/langchain_community/embeddings/zhipuai.py +++ b/libs/community/langchain_community/embeddings/zhipuai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Field, root_validator @@ -70,6 +70,11 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings): """Model name""" api_key: str """Automatically inferred from env var `ZHIPU_API_KEY` if not provided.""" + dimensions: Optional[int] = None + """The number of dimensions the resulting output embeddings should have. + + Only supported in `embedding-3` and later models. + """ @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: @@ -110,6 +115,13 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings): A list of embeddings for each document in the input list. Each embedding is represented as a list of float values. """ - resp = self.client.embeddings.create(model=self.model, input=texts) + if self.dimensions is not None: + resp = self.client.embeddings.create( + model=self.model, + input=texts, + dimensions=self.dimensions, + ) + else: + resp = self.client.embeddings.create(model=self.model, input=texts) embeddings = [r.embedding for r in resp.data] return embeddings diff --git a/libs/community/tests/integration_tests/embeddings/test_zhipuai.py b/libs/community/tests/integration_tests/embeddings/test_zhipuai.py index 3b15bd26440..894088832a5 100644 --- a/libs/community/tests/integration_tests/embeddings/test_zhipuai.py +++ b/libs/community/tests/integration_tests/embeddings/test_zhipuai.py @@ -18,3 +18,14 @@ def test_zhipuai_embedding_query() -> None: embedding = ZhipuAIEmbeddings() # type: ignore[call-arg] res = embedding.embed_query(document) assert len(res) == 1024 # type: ignore[arg-type] + + +def test_zhipuai_embedding_dimensions() -> None: + """Test ZhipuAI Text Embedding for query by assigning dimensions""" + document = "This is a test query." + embedding = ZhipuAIEmbeddings( + model="embedding-3", + dimensions=2048, + ) # type: ignore[call-arg] + res = embedding.embed_query(document) + assert len(res) == 2048 # type: ignore[arg-type]