mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
community:qianfan endpoint support init params & remove useless params definition (#15381)
- **Description:** - support custom kwargs in object initialization. For instance, QPS differs across multiple objects (chat/completion/embedding with diverse models), for which a global env var is not a good choice for configuration. - **Issue:** no - **Dependencies:** no - **Twitter handle:** no @baskaryan PTAL
This commit is contained in:
parent
26f84b74d0
commit
7773943a51
@ -83,7 +83,12 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||
"""
|
||||
|
||||
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||
associated with qianfan resource object to limit QPS"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""extra params for model invoke using with `do`."""
|
||||
|
||||
client: Any
|
||||
|
||||
@ -134,6 +139,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
)
|
||||
)
|
||||
params = {
|
||||
**values.get("init_kwargs", {}),
|
||||
"model": values["model"],
|
||||
"stream": values["streaming"],
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -41,8 +41,12 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
||||
client: Any
|
||||
"""Qianfan client"""
|
||||
|
||||
max_retries: int = 5
|
||||
"""Max reties times"""
|
||||
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||
associated with qianfan resource object to limit QPS"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""extra params for model invoke using with `do`."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -88,6 +92,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
||||
import qianfan
|
||||
|
||||
params = {
|
||||
**values.get("init_kwargs", {}),
|
||||
"model": values["model"],
|
||||
}
|
||||
if values["qianfan_ak"].get_secret_value() != "":
|
||||
@ -125,7 +130,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
||||
]
|
||||
lst = []
|
||||
for chunk in text_in_chunks:
|
||||
resp = self.client.do(texts=chunk)
|
||||
resp = self.client.do(texts=chunk, **self.model_kwargs)
|
||||
lst.extend([res["embedding"] for res in resp["data"]])
|
||||
return lst
|
||||
|
||||
@ -140,7 +145,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
||||
]
|
||||
lst = []
|
||||
for chunk in text_in_chunks:
|
||||
resp = await self.client.ado(texts=chunk)
|
||||
resp = await self.client.ado(texts=chunk, **self.model_kwargs)
|
||||
for res in resp["data"]:
|
||||
lst.extend([res["embedding"]])
|
||||
return lst
|
||||
|
@ -40,7 +40,12 @@ class QianfanLLMEndpoint(LLM):
|
||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||
"""
|
||||
|
||||
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||
associated with qianfan resource object to limit QPS"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""extra params for model invoke using with `do`."""
|
||||
|
||||
client: Any
|
||||
|
||||
@ -91,6 +96,7 @@ class QianfanLLMEndpoint(LLM):
|
||||
)
|
||||
|
||||
params = {
|
||||
**values.get("init_kwargs", {}),
|
||||
"model": values["model"],
|
||||
}
|
||||
if values["qianfan_ak"].get_secret_value() != "":
|
||||
|
@ -217,3 +217,18 @@ def test_functions_call() -> None:
|
||||
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
||||
resp = chain.invoke({})
|
||||
assert isinstance(resp, AIMessage)
|
||||
|
||||
|
||||
def test_rate_limit() -> None:
|
||||
chat = QianfanChatEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
|
||||
assert chat.client._client._rate_limiter._sync_limiter._query_per_second == 2
|
||||
responses = chat.batch(
|
||||
[
|
||||
[HumanMessage(content="Hello")],
|
||||
[HumanMessage(content="who are you")],
|
||||
[HumanMessage(content="what is baidu")],
|
||||
]
|
||||
)
|
||||
for res in responses:
|
||||
assert isinstance(res, BaseMessage)
|
||||
assert isinstance(res.content, str)
|
||||
|
@ -25,3 +25,15 @@ def test_model() -> None:
|
||||
embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
|
||||
|
||||
def test_rate_limit() -> None:
|
||||
llm = QianfanEmbeddingsEndpoint(
|
||||
model="Embedding-V1", init_kwargs={"query_per_second": 2}
|
||||
)
|
||||
assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2
|
||||
documents = ["foo", "bar"]
|
||||
output = llm.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 384
|
||||
assert len(output[1]) == 384
|
||||
|
@ -33,3 +33,11 @@ async def test_qianfan_aio() -> None:
|
||||
|
||||
async for token in llm.astream("hi qianfan."):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_rate_limit() -> None:
|
||||
llm = QianfanLLMEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
|
||||
assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2
|
||||
output = llm.generate(["write a joke"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert isinstance(output.generations, list)
|
||||
|
Loading…
Reference in New Issue
Block a user