mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
community:qianfan endpoint support init params & remove useless params definition (#15381)
- **Description:** - support custom kwargs in object initialization. For instance, QPS differs across multiple objects (chat/completion/embedding with diverse models), for which a global env variable is not a good choice for configuration. - **Issue:** no - **Dependencies:** no - **Twitter handle:** no @baskaryan PTAL
This commit is contained in:
parent
26f84b74d0
commit
7773943a51
@ -83,7 +83,12 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||||
|
associated with qianfan resource object to limit QPS"""
|
||||||
|
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""extra params for model invoke using with `do`."""
|
||||||
|
|
||||||
client: Any
|
client: Any
|
||||||
|
|
||||||
@ -134,6 +139,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
params = {
|
params = {
|
||||||
|
**values.get("init_kwargs", {}),
|
||||||
"model": values["model"],
|
"model": values["model"],
|
||||||
"stream": values["streaming"],
|
"stream": values["streaming"],
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -41,8 +41,12 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
|||||||
client: Any
|
client: Any
|
||||||
"""Qianfan client"""
|
"""Qianfan client"""
|
||||||
|
|
||||||
max_retries: int = 5
|
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
"""Max reties times"""
|
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||||
|
associated with qianfan resource object to limit QPS"""
|
||||||
|
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""extra params for model invoke using with `do`."""
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -88,6 +92,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
|||||||
import qianfan
|
import qianfan
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
**values.get("init_kwargs", {}),
|
||||||
"model": values["model"],
|
"model": values["model"],
|
||||||
}
|
}
|
||||||
if values["qianfan_ak"].get_secret_value() != "":
|
if values["qianfan_ak"].get_secret_value() != "":
|
||||||
@ -125,7 +130,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
|||||||
]
|
]
|
||||||
lst = []
|
lst = []
|
||||||
for chunk in text_in_chunks:
|
for chunk in text_in_chunks:
|
||||||
resp = self.client.do(texts=chunk)
|
resp = self.client.do(texts=chunk, **self.model_kwargs)
|
||||||
lst.extend([res["embedding"] for res in resp["data"]])
|
lst.extend([res["embedding"] for res in resp["data"]])
|
||||||
return lst
|
return lst
|
||||||
|
|
||||||
@ -140,7 +145,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
|
|||||||
]
|
]
|
||||||
lst = []
|
lst = []
|
||||||
for chunk in text_in_chunks:
|
for chunk in text_in_chunks:
|
||||||
resp = await self.client.ado(texts=chunk)
|
resp = await self.client.ado(texts=chunk, **self.model_kwargs)
|
||||||
for res in resp["data"]:
|
for res in resp["data"]:
|
||||||
lst.extend([res["embedding"]])
|
lst.extend([res["embedding"]])
|
||||||
return lst
|
return lst
|
||||||
|
@ -40,7 +40,12 @@ class QianfanLLMEndpoint(LLM):
|
|||||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||||
|
associated with qianfan resource object to limit QPS"""
|
||||||
|
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""extra params for model invoke using with `do`."""
|
||||||
|
|
||||||
client: Any
|
client: Any
|
||||||
|
|
||||||
@ -91,6 +96,7 @@ class QianfanLLMEndpoint(LLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
**values.get("init_kwargs", {}),
|
||||||
"model": values["model"],
|
"model": values["model"],
|
||||||
}
|
}
|
||||||
if values["qianfan_ak"].get_secret_value() != "":
|
if values["qianfan_ak"].get_secret_value() != "":
|
||||||
|
@ -217,3 +217,18 @@ def test_functions_call() -> None:
|
|||||||
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
chain = prompt | chat.bind(functions=_FUNCTIONS)
|
||||||
resp = chain.invoke({})
|
resp = chain.invoke({})
|
||||||
assert isinstance(resp, AIMessage)
|
assert isinstance(resp, AIMessage)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit() -> None:
|
||||||
|
chat = QianfanChatEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
|
||||||
|
assert chat.client._client._rate_limiter._sync_limiter._query_per_second == 2
|
||||||
|
responses = chat.batch(
|
||||||
|
[
|
||||||
|
[HumanMessage(content="Hello")],
|
||||||
|
[HumanMessage(content="who are you")],
|
||||||
|
[HumanMessage(content="what is baidu")],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for res in responses:
|
||||||
|
assert isinstance(res, BaseMessage)
|
||||||
|
assert isinstance(res.content, str)
|
||||||
|
@ -25,3 +25,15 @@ def test_model() -> None:
|
|||||||
embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
|
embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
|
||||||
output = embedding.embed_documents(documents)
|
output = embedding.embed_documents(documents)
|
||||||
assert len(output) == 2
|
assert len(output) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit() -> None:
|
||||||
|
llm = QianfanEmbeddingsEndpoint(
|
||||||
|
model="Embedding-V1", init_kwargs={"query_per_second": 2}
|
||||||
|
)
|
||||||
|
assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2
|
||||||
|
documents = ["foo", "bar"]
|
||||||
|
output = llm.embed_documents(documents)
|
||||||
|
assert len(output) == 2
|
||||||
|
assert len(output[0]) == 384
|
||||||
|
assert len(output[1]) == 384
|
||||||
|
@ -33,3 +33,11 @@ async def test_qianfan_aio() -> None:
|
|||||||
|
|
||||||
async for token in llm.astream("hi qianfan."):
|
async for token in llm.astream("hi qianfan."):
|
||||||
assert isinstance(token, str)
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit() -> None:
|
||||||
|
llm = QianfanLLMEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
|
||||||
|
assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2
|
||||||
|
output = llm.generate(["write a joke"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert isinstance(output.generations, list)
|
||||||
|
Loading…
Reference in New Issue
Block a user