diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 4b617a10f66..ecf00a3982a 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -83,7 +83,12 @@ class QianfanChatEndpoint(BaseChatModel): endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk") """ + init_kwargs: Dict[str, Any] = Field(default_factory=dict) + """init kwargs for qianfan client init, such as `query_per_second` which is + associated with qianfan resource object to limit QPS""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Extra parameters for model invocation, used with `do`.""" client: Any @@ -134,6 +139,7 @@ class QianfanChatEndpoint(BaseChatModel): ) ) params = { + **values.get("init_kwargs", {}), "model": values["model"], "stream": values["streaming"], } diff --git a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py index 2920447c23b..41bbd96984d 100644 --- a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) @@ -41,8 +41,12 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings): client: Any """Qianfan client""" - max_retries: int = 5 - """Max reties times""" + init_kwargs: Dict[str, Any] = Field(default_factory=dict) + """init kwargs for qianfan client init, such as 
`query_per_second` which is + associated with qianfan resource object to limit QPS""" + + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Extra parameters for model invocation, used with `do`.""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -88,6 +92,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings): import qianfan params = { + **values.get("init_kwargs", {}), "model": values["model"], } if values["qianfan_ak"].get_secret_value() != "": @@ -125,7 +130,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings): ] lst = [] for chunk in text_in_chunks: - resp = self.client.do(texts=chunk) + resp = self.client.do(texts=chunk, **self.model_kwargs) lst.extend([res["embedding"] for res in resp["data"]]) return lst @@ -140,7 +145,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings): ] lst = [] for chunk in text_in_chunks: - resp = await self.client.ado(texts=chunk) + resp = await self.client.ado(texts=chunk, **self.model_kwargs) for res in resp["data"]: lst.extend([res["embedding"]]) return lst diff --git a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py index 36607f1b1e4..48844c2d22d 100644 --- a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py @@ -40,7 +40,12 @@ class QianfanLLMEndpoint(LLM): endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk") """ + init_kwargs: Dict[str, Any] = Field(default_factory=dict) + """init kwargs for qianfan client init, such as `query_per_second` which is + associated with qianfan resource object to limit QPS""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Extra parameters for model invocation, used with `do`.""" client: Any @@ -91,6 +96,7 @@ class QianfanLLMEndpoint(LLM): ) params = { + **values.get("init_kwargs", {}), "model": values["model"], } if 
values["qianfan_ak"].get_secret_value() != "": diff --git a/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py b/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py index 88bfc66a382..f3bc4bb7746 100644 --- a/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py +++ b/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py @@ -217,3 +217,18 @@ def test_functions_call() -> None: chain = prompt | chat.bind(functions=_FUNCTIONS) resp = chain.invoke({}) assert isinstance(resp, AIMessage) + + +def test_rate_limit() -> None: + chat = QianfanChatEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2}) + assert chat.client._client._rate_limiter._sync_limiter._query_per_second == 2 + responses = chat.batch( + [ + [HumanMessage(content="Hello")], + [HumanMessage(content="who are you")], + [HumanMessage(content="what is baidu")], + ] + ) + for res in responses: + assert isinstance(res, BaseMessage) + assert isinstance(res.content, str) diff --git a/libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py b/libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py index f257f61a021..c575f8475cb 100644 --- a/libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py +++ b/libs/community/tests/integration_tests/embeddings/test_qianfan_endpoint.py @@ -25,3 +25,15 @@ def test_model() -> None: embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1") output = embedding.embed_documents(documents) assert len(output) == 2 + + +def test_rate_limit() -> None: + llm = QianfanEmbeddingsEndpoint( + model="Embedding-V1", init_kwargs={"query_per_second": 2} + ) + assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2 + documents = ["foo", "bar"] + output = llm.embed_documents(documents) + assert len(output) == 2 + assert len(output[0]) == 384 + assert len(output[1]) == 384 diff --git 
a/libs/community/tests/integration_tests/llms/test_qianfan_endpoint.py b/libs/community/tests/integration_tests/llms/test_qianfan_endpoint.py index acafba2c451..30a9e135e18 100644 --- a/libs/community/tests/integration_tests/llms/test_qianfan_endpoint.py +++ b/libs/community/tests/integration_tests/llms/test_qianfan_endpoint.py @@ -33,3 +33,11 @@ async def test_qianfan_aio() -> None: async for token in llm.astream("hi qianfan."): assert isinstance(token, str) + + +def test_rate_limit() -> None: + llm = QianfanLLMEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2}) + assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2 + output = llm.generate(["write a joke"]) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list)