Add Baidu Qianfan endpoint for LLM (#10496)

- Description:
* Baidu AI Cloud's [Qianfan
Platform](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) is an
all-in-one platform for large model development and service deployment,
catering to enterprise developers in China. Qianfan Platform offers a
wide range of resources, including the Wenxin Yiyan model (ERNIE-Bot)
and various third-party open-source models.
- Issue: none
- Dependencies: 
    * qianfan
- Tag maintainer: @baskaryan
- Twitter handle:

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
stonekim
2023-09-14 07:23:49 +08:00
committed by GitHub
parent 0a0276bcdb
commit adabdfdfc7
12 changed files with 1284 additions and 0 deletions

View File

@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel
@@ -51,4 +52,5 @@ __all__ = [
"ChatLiteLLM",
"ErnieBotChat",
"ChatKonko",
"QianfanChatEndpoint",
]

View File

@@ -0,0 +1,293 @@
from __future__ import annotations
import logging
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Mapping,
Optional,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
return AIMessageChunk(
content=resp["result"],
role="assistant",
)
def convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["functions"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
class QianfanChatEndpoint(BaseChatModel):
"""Baidu Qianfan chat models.
To use, you should have the ``qianfan`` python package installed, and
the environment variable ``qianfan_ak`` and ``qianfan_sk`` set with your
API key and Secret Key.
ak, sk are required parameters
which you could get from https://cloud.baidu.com/product/wenxinworkshop
Example:
.. code-block:: python
from langchain.chat_models import QianfanChatEndpoint
qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
endpoint="your_endpoint", ak="your_ak", sk="your_sk")
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
client: Any
qianfan_ak: Optional[str] = None
qianfan_sk: Optional[str] = None
streaming: Optional[bool] = False
"""Whether to stream the results or not."""
request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""
top_p: Optional[float] = 0.8
temperature: Optional[float] = 0.95
penalty_score: Optional[float] = 1
"""Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
In the case of other model, passing these params will not affect the result.
"""
model: str = "ERNIE-Bot-turbo"
"""Model name.
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
preset models are mapping to an endpoint.
`model` will be ignored if `endpoint` is set
"""
endpoint: Optional[str] = None
"""Endpoint of the Qianfan LLM, required if custom model used."""
@root_validator()
def validate_enviroment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = get_from_dict_or_env(
values,
"qianfan_ak",
"QIANFAN_AK",
)
values["qianfan_sk"] = get_from_dict_or_env(
values,
"qianfan_sk",
"QIANFAN_SK",
)
params = {
"ak": values["qianfan_ak"],
"sk": values["qianfan_sk"],
"model": values["model"],
"stream": values["streaming"],
}
if values["endpoint"] is not None and values["endpoint"] != "":
params["endpoint"] = values["endpoint"]
try:
import qianfan
values["client"] = qianfan.ChatCompletion(**params)
except ImportError:
raise ValueError(
"qianfan package not found, please install it with "
"`pip install qianfan`"
)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
return {
**{"endpoint": self.endpoint, "model": self.model},
**super()._identifying_params,
}
@property
def _llm_type(self) -> str:
"""Return type of chat_model."""
return "baidu-qianfan-chat"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
normal_params = {
"stream": self.streaming,
"request_timeout": self.request_timeout,
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
}
return {**normal_params, **self.model_kwargs}
def _convert_prompt_msg_params(
self,
messages: List[BaseMessage],
**kwargs: Any,
) -> dict:
return {
**{"messages": [convert_message_to_dict(m) for m in messages]},
**self._default_params,
**kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to an qianfan models endpoint for each generation with a prompt.
Args:
messages: The messages to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = qianfan_model("Tell me a joke.")
"""
if self.streaming:
completion = ""
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
lc_msg = AIMessage(content=completion, additional_kwargs={})
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="finished"),
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = self.client.do(**params)
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="finished"),
)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=[gen], llm_output=llm_output)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
completion = ""
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
completion += chunk.text
lc_msg = AIMessage(content=completion, additional_kwargs={})
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="finished"),
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = await self.client.ado(**params)
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
generations = []
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="finished"),
)
generations.append(gen)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
for res in self.client.do(**params):
if res:
chunk = ChatGenerationChunk(
text=res["result"],
message=_convert_resp_to_message_chunk(res),
generation_info={"finish_reason": "finished"},
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
async for res in await self.client.ado(**params):
if res:
chunk = ChatGenerationChunk(
text=res["result"], message=_convert_resp_to_message_chunk(res)
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text)

View File

@@ -19,6 +19,7 @@ from langchain.embeddings.aleph_alpha import (
AlephAlphaSymmetricSemanticEmbedding,
)
from langchain.embeddings.awa import AwaEmbeddings
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.clarifai import ClarifaiEmbeddings
@@ -105,6 +106,7 @@ __all__ = [
"AwaEmbeddings",
"HuggingFaceBgeEmbeddings",
"ErnieEmbeddings",
"QianfanEmbeddingsEndpoint",
]

View File

@@ -0,0 +1,138 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
"""`Baidu Qianfan Embeddings` embedding models."""
qianfan_ak: Optional[str] = None
"""Qianfan application apikey"""
qianfan_sk: Optional[str] = None
"""Qianfan application secretkey"""
chunk_size: int = 16
"""Chunk size when multiple texts are input"""
model: str = "Embedding-V1"
"""Model name
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
for now, we support Embedding-V1 and
- Embedding-V1 (默认模型)
- bge-large-en
- bge-large-zh
preset models are mapping to an endpoint.
`model` will be ignored if `endpoint` is set
"""
endpoint: str = ""
"""Endpoint of the Qianfan Embedding, required if custom model used."""
client: Any
"""Qianfan client"""
max_retries: int = 5
"""Max reties times"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
Validate whether qianfan_ak and qianfan_sk in the environment variables or
configuration file are available or not.
init qianfan embedding client with `ak`, `sk`, `model`, `endpoint`
Args:
values: a dictionary containing configuration information, must include the
fields of qianfan_ak and qianfan_sk
Returns:
a dictionary containing configuration information. If qianfan_ak and
qianfan_sk are not provided in the environment variables or configuration
file,the original values will be returned; otherwise, values containing
qianfan_ak and qianfan_sk will be returned.
Raises:
ValueError: qianfan package not found, please install it with `pip install
qianfan`
"""
values["qianfan_ak"] = get_from_dict_or_env(
values,
"qianfan_ak",
"QIANFAN_AK",
)
values["qianfan_sk"] = get_from_dict_or_env(
values,
"qianfan_sk",
"QIANFAN_SK",
)
try:
import qianfan
params = {
"ak": values["qianfan_ak"],
"sk": values["qianfan_sk"],
"model": values["model"],
}
if values["endpoint"] is not None and values["endpoint"] != "":
params["endpoint"] = values["endpoint"]
values["client"] = qianfan.Embedding(**params)
except ImportError:
raise ValueError(
"qianfan package not found, please install it with "
"`pip install qianfan`"
)
return values
def embed_query(self, text: str) -> List[float]:
resp = self.embed_documents([text])
return resp[0]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Embeds a list of text documents using the AutoVOT algorithm.
Args:
texts (List[str]): A list of text documents to embed.
Returns:
List[List[float]]: A list of embeddings for each document in the input list.
Each embedding is represented as a list of float values.
"""
text_in_chunks = [
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []
for chunk in text_in_chunks:
resp = self.client.do(texts=chunk)
lst.extend([res["embedding"] for res in resp["data"]])
return lst
async def aembed_query(self, text: str) -> List[float]:
embeddings = await self.aembed_documents([text])
return embeddings[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
text_in_chunks = [
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []
for chunk in text_in_chunks:
resp = await self.client.ado(texts=chunk)
for res in resp["data"]:
lst.extend([res["embedding"]])
return lst

View File

@@ -26,6 +26,7 @@ from langchain.llms.anthropic import Anthropic
from langchain.llms.anyscale import Anyscale
from langchain.llms.aviary import Aviary
from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint
from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
from langchain.llms.baseten import Baseten
@@ -160,6 +161,7 @@ __all__ = [
"Writer",
"OctoAIEndpoint",
"Xinference",
"QianfanLLMEndpoint",
]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -228,4 +230,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"vllm_openai": VLLMOpenAI,
"writer": Writer,
"xinference": Xinference,
"qianfan_endpoint": QianfanLLMEndpoint,
}

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
import logging
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class QianfanLLMEndpoint(LLM):
"""Baidu Qianfan hosted open source or customized models.
To use, you should have the ``qianfan`` python package installed, and
the environment variable ``qianfan_ak`` and ``qianfan_sk`` set with
your API key and Secret Key.
ak, sk are required parameters which you could get from
https://cloud.baidu.com/product/wenxinworkshop
Example:
.. code-block:: python
from langchain.llms import QianfanLLMEndpoint
qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
endpoint="your_endpoint", ak="your_ak", sk="your_sk")
"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
client: Any
qianfan_ak: Optional[str] = None
qianfan_sk: Optional[str] = None
streaming: Optional[bool] = False
"""Whether to stream the results or not."""
model: str = "ERNIE-Bot-turbo"
"""Model name.
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
preset models are mapping to an endpoint.
`model` will be ignored if `endpoint` is set
"""
endpoint: Optional[str] = None
"""Endpoint of the Qianfan LLM, required if custom model used."""
request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""
top_p: Optional[float] = 0.8
temperature: Optional[float] = 0.95
penalty_score: Optional[float] = 1
"""Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
In the case of other model, passing these params will not affect the result.
"""
@root_validator()
def validate_enviroment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = get_from_dict_or_env(
values,
"qianfan_ak",
"QIANFAN_AK",
)
values["qianfan_sk"] = get_from_dict_or_env(
values,
"qianfan_sk",
"QIANFAN_SK",
)
params = {
"ak": values["qianfan_ak"],
"sk": values["qianfan_sk"],
"model": values["model"],
}
if values["endpoint"] is not None and values["endpoint"] != "":
params["endpoint"] = values["endpoint"]
try:
import qianfan
values["client"] = qianfan.Completion(**params)
except ImportError:
raise ValueError(
"qianfan package not found, please install it with "
"`pip install qianfan`"
)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
return {
**{"endpoint": self.endpoint, "model": self.model},
**super()._identifying_params,
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "baidu-qianfan-endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
normal_params = {
"stream": self.streaming,
"request_timeout": self.request_timeout,
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
}
return {**normal_params, **self.model_kwargs}
def _convert_prompt_msg_params(
self,
prompt: str,
**kwargs: Any,
) -> dict:
return {
**{"prompt": prompt, "model": self.model},
**self._default_params,
**kwargs,
}
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to an qianfan models endpoint for each generation with a prompt.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = qianfan_model("Tell me a joke.")
"""
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
params = self._convert_prompt_msg_params(prompt, **kwargs)
response_payload = self.client.do(**params)
return response_payload["result"]
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
completion = ""
async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
params = self._convert_prompt_msg_params(prompt, **kwargs)
response_payload = await self.client.ado(**params)
return response_payload["result"]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._convert_prompt_msg_params(prompt, **kwargs)
for res in self.client.do(**params):
if res:
chunk = GenerationChunk(text=res["result"])
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
params = self._convert_prompt_msg_params(prompt, **kwargs)
async for res in await self.client.ado(**params):
if res:
chunk = GenerationChunk(text=res["result"])
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text)

View File

@@ -0,0 +1,85 @@
"""Test Baidu Qianfan Chat Endpoint."""
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_default_call() -> None:
"""Test default model(`ERNIE-Bot`) call."""
chat = QianfanChatEndpoint()
response = chat(messages=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_model() -> None:
"""Test model kwarg works."""
chat = QianfanChatEndpoint(model="BLOOMZ-7B")
response = chat(messages=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_endpoint() -> None:
"""Test user custom model deployments like some open source models."""
chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
response = chat(messages=[HumanMessage(content="Hello")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_history() -> None:
"""Tests multiple history works."""
chat = QianfanChatEndpoint()
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_stream() -> None:
"""Test that stream works."""
chat = QianfanChatEndpoint(streaming=True)
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="Who are you?"),
],
stream=True,
callbacks=callback_manager,
)
assert callback_handler.llm_streams > 0
assert isinstance(response.content, str)
def test_multiple_messages() -> None:
"""Tests multiple messages works."""
chat = QianfanChatEndpoint()
message = HumanMessage(content="Hi, how are you.")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

View File

@@ -0,0 +1,25 @@
"""Test Baidu Qianfan Embedding Endpoint."""
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
def test_embedding_multiple_documents() -> None:
documents = ["foo", "bar"]
embedding = QianfanEmbeddingsEndpoint()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 384
assert len(output[1]) == 384
def test_embedding_query() -> None:
query = "foo"
embedding = QianfanEmbeddingsEndpoint()
output = embedding.embed_query(query)
assert len(output) == 384
def test_model() -> None:
documents = ["hi", "qianfan"]
embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
output = embedding.embed_documents(documents)
assert len(output) == 2

View File

@@ -0,0 +1,37 @@
"""Test Baidu Qianfan LLM Endpoint."""
from typing import Generator
import pytest
from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
from langchain.schema import LLMResult
def test_call() -> None:
"""Test valid call to qianfan."""
llm = QianfanLLMEndpoint()
output = llm("write a joke")
assert isinstance(output, str)
def test_generate() -> None:
"""Test valid call to qianfan."""
llm = QianfanLLMEndpoint()
output = llm.generate(["write a joke"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
def test_generate_stream() -> None:
"""Test valid call to qianfan."""
llm = QianfanLLMEndpoint()
output = llm.stream("write a joke")
assert isinstance(output, Generator)
@pytest.mark.asyncio
async def test_qianfan_aio() -> None:
llm = QianfanLLMEndpoint(streaming=True)
async for token in llm.astream("hi qianfan."):
assert isinstance(token, str)