mirror of https://github.com/hwchase17/langchain.git, synced 2025-05-15 12:02:11 +00:00
Add Baidu Qianfan endpoint for LLM (#10496)
- Description: Baidu AI Cloud's [Qianfan Platform](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) is an all-in-one platform for large model development and service deployment, catering to enterprise developers in China. The Qianfan Platform offers a wide range of resources, including the Wenxin Yiyan model (ERNIE-Bot) and various third-party open-source models.
- Issue: none
- Dependencies: qianfan
- Tag maintainer: @baskaryan
- Twitter handle:

Co-authored-by: Bagatur <baskaryan@gmail.com>
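For reference, a minimal usage sketch of the new integration (not part of the diff itself; it is distilled from the notebooks added below, with placeholder credentials and `pip install qianfan` assumed):

```python
# Minimal sketch: basic call to the new Qianfan chat endpoint.
# The AK/SK values are placeholders; real keys come from the Qianfan console.
import os

from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.schema import HumanMessage

os.environ["QIANFAN_AK"] = "your_ak"  # placeholder
os.environ["QIANFAN_SK"] = "your_sk"  # placeholder

chat = QianfanChatEndpoint()  # defaults to the ERNIE-Bot-turbo model
res = chat([HumanMessage(content="write a funny joke")])
print(res.content)
```

The same AK/SK initialization applies to the `QianfanLLMEndpoint` and `QianfanEmbeddingsEndpoint` classes introduced in this commit.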
This commit is contained in:
parent 0a0276bcdb
commit adabdfdfc7

181 docs/extras/integrations/chat/baidu_qianfan_endpoint.ipynb Normal file
@ -0,0 +1,181 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Baidu Qianfan\n",
    "\n",
    "Baidu AI Cloud Qianfan Platform is a one-stop large model development and service operation platform for enterprise developers. Qianfan provides not only the Wenxin Yiyan (ERNIE-Bot) model and third-party open-source models, but also various AI development tools and a complete development environment, which makes it easy for customers to use and develop large model applications.\n",
    "\n",
    "Basically, those models are split into the following types:\n",
    "\n",
    "- Embedding\n",
    "- Chat\n",
    "- Completion\n",
    "\n",
    "In this notebook, we will introduce how to use langchain with [Qianfan](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html), mainly for `Chat`, corresponding to the package `langchain/chat_models` in langchain:\n",
    "\n",
    "## API Initialization\n",
    "\n",
    "To use the LLM services based on Baidu Qianfan, you have to initialize these parameters. You can either set the AK and SK in environment variables or pass them as init params:\n",
    "\n",
    "```bash\n",
    "export QIANFAN_AK=XXX\n",
    "export QIANFAN_SK=XXX\n",
    "```\n",
    "\n",
    "## Currently supported models:\n",
    "\n",
    "- ERNIE-Bot-turbo (default model)\n",
    "- ERNIE-Bot\n",
    "- BLOOMZ-7B\n",
    "- Llama-2-7b-chat\n",
    "- Llama-2-13b-chat\n",
    "- Llama-2-70b-chat\n",
    "- Qianfan-BLOOMZ-7B-compressed\n",
    "- Qianfan-Chinese-Llama-2-7B\n",
    "- ChatGLM2-6B-32K\n",
    "- AquilaChat-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"For basic init and call\"\"\"\n",
    "import os\n",
    "\n",
    "from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint\n",
    "from langchain.schema import HumanMessage\n",
    "\n",
    "os.environ[\"QIANFAN_AK\"] = \"xxx\"\n",
    "os.environ[\"QIANFAN_SK\"] = \"xxx\"\n",
    "\n",
    "chat = QianfanChatEndpoint(\n",
    "    qianfan_ak=\"xxx\",\n",
    "    qianfan_sk=\"xxx\",\n",
    "    streaming=True,\n",
    ")\n",
    "res = chat([HumanMessage(content=\"write a funny joke\")])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint\n",
    "from langchain.schema import HumanMessage\n",
    "\n",
    "chatLLM = QianfanChatEndpoint(\n",
    "    streaming=True,\n",
    ")\n",
    "res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
    "for r in res:\n",
    "    print(\"chat resp1:\", r)\n",
    "\n",
    "\n",
    "async def run_aio_generate():\n",
    "    resp = await chatLLM.agenerate(\n",
    "        messages=[[HumanMessage(content=\"write a 20 words sentence about sea.\")]]\n",
    "    )\n",
    "    print(resp)\n",
    "\n",
    "\n",
    "await run_aio_generate()\n",
    "\n",
    "\n",
    "async def run_aio_stream():\n",
    "    async for res in chatLLM.astream(\n",
    "        [HumanMessage(content=\"write a 20 words sentence about sea.\")]\n",
    "    ):\n",
    "        print(\"astream\", res)\n",
    "\n",
    "\n",
    "await run_aio_stream()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use different models in Qianfan\n",
    "\n",
    "If you want to deploy your own model based on ERNIE-Bot or a third-party open-source model, you can follow these steps:\n",
    "\n",
    "1. (Optional: skip this if the model is included in the default models.) Deploy your model in the Qianfan Console and get your own customized deploy endpoint.\n",
    "2. Set up the field called `endpoint` in the initialization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chatBloom = QianfanChatEndpoint(\n",
    "    streaming=True,\n",
    "    model=\"BLOOMZ-7B\",\n",
    ")\n",
    "res = chatBloom([HumanMessage(content=\"hi\")])\n",
    "print(res)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Params:\n",
    "\n",
    "For now, only `ERNIE-Bot` and `ERNIE-Bot-turbo` support the model params below; we may support more models in the future.\n",
    "\n",
    "- temperature\n",
    "- top_p\n",
    "- penalty_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = chat.stream(\n",
    "    [HumanMessage(content=\"hi\")],\n",
    "    **{\"top_p\": 0.4, \"temperature\": 0.1, \"penalty_score\": 1},\n",
    ")\n",
    "\n",
    "for r in res:\n",
    "    print(r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "2d8226dd90b7dc6e8932aea372a8bf9fc71abac4be3cdd5a63a36c2a19e3700f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
177 docs/extras/integrations/llms/baidu_qianfan_endpoint.ipynb Normal file

@ -0,0 +1,177 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Baidu Qianfan\n",
    "\n",
    "Baidu AI Cloud Qianfan Platform is a one-stop large model development and service operation platform for enterprise developers. Qianfan provides not only the Wenxin Yiyan (ERNIE-Bot) model and third-party open-source models, but also various AI development tools and a complete development environment, which makes it easy for customers to use and develop large model applications.\n",
    "\n",
    "Basically, those models are split into the following types:\n",
    "\n",
    "- Embedding\n",
    "- Chat\n",
    "- Completion\n",
    "\n",
    "In this notebook, we will introduce how to use langchain with [Qianfan](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html), mainly for `Completion`, corresponding to the package `langchain/llms` in langchain:\n",
    "\n",
    "## API Initialization\n",
    "\n",
    "To use the LLM services based on Baidu Qianfan, you have to initialize these parameters. You can either set the AK and SK in environment variables or pass them as init params:\n",
    "\n",
    "```bash\n",
    "export QIANFAN_AK=XXX\n",
    "export QIANFAN_SK=XXX\n",
    "```\n",
    "\n",
    "## Currently supported models:\n",
    "\n",
    "- ERNIE-Bot-turbo (default model)\n",
    "- ERNIE-Bot\n",
    "- BLOOMZ-7B\n",
    "- Llama-2-7b-chat\n",
    "- Llama-2-13b-chat\n",
    "- Llama-2-70b-chat\n",
    "- Qianfan-BLOOMZ-7B-compressed\n",
    "- Qianfan-Chinese-Llama-2-7B\n",
    "- ChatGLM2-6B-32K\n",
    "- AquilaChat-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"For basic init and call\"\"\"\n",
    "import os\n",
    "\n",
    "from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint\n",
    "\n",
    "os.environ[\"QIANFAN_AK\"] = \"xx\"\n",
    "os.environ[\"QIANFAN_SK\"] = \"xx\"\n",
    "\n",
    "llm = QianfanLLMEndpoint(streaming=True, ak=\"xx\", sk=\"xx\")\n",
    "res = llm(\"hi\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Test for llm generate\"\"\"\n",
    "res = llm.generate(prompts=[\"hello?\"])\n",
    "\n",
    "\"\"\"Test for llm aio generate\"\"\"\n",
    "\n",
    "\n",
    "async def run_aio_generate():\n",
    "    resp = await llm.agenerate(prompts=[\"Write a 20-word article about rivers.\"])\n",
    "    print(resp)\n",
    "\n",
    "\n",
    "await run_aio_generate()\n",
    "\n",
    "\"\"\"Test for llm stream\"\"\"\n",
    "for res in llm.stream(\"write a joke.\"):\n",
    "    print(res)\n",
    "\n",
    "\"\"\"Test for llm aio stream\"\"\"\n",
    "\n",
    "\n",
    "async def run_aio_stream():\n",
    "    async for res in llm.astream(\"Write a 20-word article about mountains\"):\n",
    "        print(res)\n",
    "\n",
    "\n",
    "await run_aio_stream()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use different models in Qianfan\n",
    "\n",
    "If you want to deploy your own model based on ERNIE-Bot or several open-source models, you can follow these steps:\n",
    "\n",
    "1. (Optional: skip this if the model is included in the default models.) Deploy your model in the Qianfan Console and get your own customized deploy endpoint.\n",
    "2. Set up the field called `endpoint` in the initialization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = QianfanLLMEndpoint(\n",
    "    qianfan_ak=\"xxx\",\n",
    "    qianfan_sk=\"xxx\",\n",
    "    streaming=True,\n",
    "    model=\"ERNIE-Bot-turbo\",\n",
    "    endpoint=\"eb-instant\",\n",
    ")\n",
    "res = llm(\"hi\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Params:\n",
    "\n",
    "For now, only `ERNIE-Bot` and `ERNIE-Bot-turbo` support the model params below; we may support more models in the future.\n",
    "\n",
    "- temperature\n",
    "- top_p\n",
    "- penalty_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = llm.generate(\n",
    "    prompts=[\"hi\"],\n",
    "    streaming=True,\n",
    "    **{\"top_p\": 0.4, \"temperature\": 0.1, \"penalty_score\": 1},\n",
    ")\n",
    "\n",
    "for r in res:\n",
    "    print(r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -0,0 +1,124 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Baidu Qianfan\n",
    "\n",
    "Baidu AI Cloud Qianfan Platform is a one-stop large model development and service operation platform for enterprise developers. Qianfan provides not only the Wenxin Yiyan (ERNIE-Bot) model and third-party open-source models, but also various AI development tools and a complete development environment, which makes it easy for customers to use and develop large model applications.\n",
    "\n",
    "Basically, those models are split into the following types:\n",
    "\n",
    "- Embedding\n",
    "- Chat\n",
    "- Completion\n",
    "\n",
    "In this notebook, we will introduce how to use langchain with [Qianfan](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html), mainly for `Embedding`, corresponding to the package `langchain/embeddings` in langchain:\n",
    "\n",
    "## API Initialization\n",
    "\n",
    "To use the LLM services based on Baidu Qianfan, you have to initialize these parameters. You can either set the AK and SK in environment variables or pass them as init params:\n",
    "\n",
    "```bash\n",
    "export QIANFAN_AK=XXX\n",
    "export QIANFAN_SK=XXX\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"For basic init and call\"\"\"\n",
    "import os\n",
    "\n",
    "from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint\n",
    "\n",
    "os.environ[\"QIANFAN_AK\"] = \"xx\"\n",
    "os.environ[\"QIANFAN_SK\"] = \"xx\"\n",
    "\n",
    "embed = QianfanEmbeddingsEndpoint(qianfan_ak=\"xxx\", qianfan_sk=\"xxx\")\n",
    "res = embed.embed_documents([\"hi\", \"world\"])\n",
    "\n",
    "\n",
    "async def aioEmbed():\n",
    "    res = await embed.aembed_query(\"qianfan\")\n",
    "    print(res)\n",
    "\n",
    "\n",
    "await aioEmbed()\n",
    "\n",
    "\n",
    "async def aioEmbedDocs():\n",
    "    res = await embed.aembed_documents([\"hi\", \"world\"])\n",
    "    for r in res:\n",
    "        print(\"\", r[:8])\n",
    "\n",
    "\n",
    "await aioEmbedDocs()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use different models in Qianfan\n",
    "\n",
    "If you want to deploy your own model based on ERNIE-Bot or a third-party open-source model, you can follow these steps:\n",
    "\n",
    "1. (Optional: skip this if the model is included in the default models.) Deploy your model in the Qianfan Console and get your own customized deploy endpoint.\n",
    "2. Set up the field called `endpoint` in the initialization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed = QianfanEmbeddingsEndpoint(\n",
    "    qianfan_ak=\"xxx\",\n",
    "    qianfan_sk=\"xxx\",\n",
    "    model=\"bge_large_zh\",\n",
    "    endpoint=\"bge_large_zh\",\n",
    ")\n",
    "\n",
    "res = embed.embed_documents([\"hi\", \"world\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.anyscale import ChatAnyscale
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.fake import FakeListChatModel

@ -51,4 +52,5 @@ __all__ = [
    "ChatLiteLLM",
    "ErnieBotChat",
    "ChatKonko",
    "QianfanChatEndpoint",
]
293 libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py Normal file

@ -0,0 +1,293 @@
from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
    return AIMessageChunk(
        content=resp["result"],
        role="assistant",
    )


def convert_message_to_dict(message: BaseMessage) -> dict:
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["functions"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise TypeError(f"Got unknown type {message}")

    return message_dict


class QianfanChatEndpoint(BaseChatModel):
    """Baidu Qianfan chat models.

    To use, you should have the ``qianfan`` python package installed, and
    the environment variables ``QIANFAN_AK`` and ``QIANFAN_SK`` set with your
    API Key and Secret Key.

    ak and sk are required parameters,
    which you could get from https://cloud.baidu.com/product/wenxinworkshop

    Example:
        .. code-block:: python

            from langchain.chat_models import QianfanChatEndpoint
            qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
                endpoint="your_endpoint", ak="your_ak", sk="your_sk")
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    qianfan_ak: Optional[str] = None
    qianfan_sk: Optional[str] = None

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""

    request_timeout: Optional[int] = 60
    """request timeout for chat http requests"""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    penalty_score: Optional[float] = 1
    """Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
    In the case of other models, passing these params will not affect the result.
    """

    model: str = "ERNIE-Bot-turbo"
    """Model name.
    You could get it from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    Preset models map to an endpoint.
    `model` will be ignored if `endpoint` is set.
    """

    endpoint: Optional[str] = None
    """Endpoint of the Qianfan LLM, required if custom model used."""

    @root_validator()
    def validate_enviroment(cls, values: Dict) -> Dict:
        values["qianfan_ak"] = get_from_dict_or_env(
            values,
            "qianfan_ak",
            "QIANFAN_AK",
        )
        values["qianfan_sk"] = get_from_dict_or_env(
            values,
            "qianfan_sk",
            "QIANFAN_SK",
        )
        params = {
            "ak": values["qianfan_ak"],
            "sk": values["qianfan_sk"],
            "model": values["model"],
            "stream": values["streaming"],
        }
        if values["endpoint"] is not None and values["endpoint"] != "":
            params["endpoint"] = values["endpoint"]
        try:
            import qianfan

            values["client"] = qianfan.ChatCompletion(**params)
        except ImportError:
            raise ValueError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat_model."""
        return "baidu-qianfan-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Qianfan API."""
        normal_params = {
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
        }

        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        messages: List[BaseMessage],
        **kwargs: Any,
    ) -> dict:
        return {
            **{"messages": [convert_message_to_dict(m) for m in messages]},
            **self._default_params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to a Qianfan models endpoint for each generation with a prompt.
        Args:
            messages: The messages to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python
                response = qianfan_model("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
            lc_msg = AIMessage(content=completion, additional_kwargs={})
            gen = ChatGeneration(
                message=lc_msg,
                generation_info=dict(finish_reason="finished"),
            )
            return ChatResult(
                generations=[gen],
                llm_output={"token_usage": {}, "model_name": self.model},
            )
        params = self._convert_prompt_msg_params(messages, **kwargs)
        response_payload = self.client.do(**params)
        lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
        gen = ChatGeneration(
            message=lc_msg,
            generation_info=dict(finish_reason="finished"),
        )
        token_usage = response_payload.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=[gen], llm_output=llm_output)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
            lc_msg = AIMessage(content=completion, additional_kwargs={})
            gen = ChatGeneration(
                message=lc_msg,
                generation_info=dict(finish_reason="finished"),
            )
            return ChatResult(
                generations=[gen],
                llm_output={"token_usage": {}, "model_name": self.model},
            )
        params = self._convert_prompt_msg_params(messages, **kwargs)
        response_payload = await self.client.ado(**params)
        lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
        generations = []
        gen = ChatGeneration(
            message=lc_msg,
            generation_info=dict(finish_reason="finished"),
        )
        generations.append(gen)
        token_usage = response_payload.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        for res in self.client.do(**params):
            if res:
                chunk = ChatGenerationChunk(
                    text=res["result"],
                    message=_convert_resp_to_message_chunk(res),
                    generation_info={"finish_reason": "finished"},
                )
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        async for res in await self.client.ado(**params):
            if res:
                chunk = ChatGenerationChunk(
                    text=res["result"], message=_convert_resp_to_message_chunk(res)
                )
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
@ -19,6 +19,7 @@ from langchain.embeddings.aleph_alpha import (
    AlephAlphaSymmetricSemanticEmbedding,
)
from langchain.embeddings.awa import AwaEmbeddings
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.clarifai import ClarifaiEmbeddings

@ -105,6 +106,7 @@ __all__ = [
    "AwaEmbeddings",
    "HuggingFaceBgeEmbeddings",
    "ErnieEmbeddings",
    "QianfanEmbeddingsEndpoint",
]
138 libs/langchain/langchain/embeddings/baidu_qianfan_endpoint.py Normal file

@ -0,0 +1,138 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
    """`Baidu Qianfan Embeddings` embedding models."""

    qianfan_ak: Optional[str] = None
    """Qianfan application apikey"""

    qianfan_sk: Optional[str] = None
    """Qianfan application secretkey"""

    chunk_size: int = 16
    """Chunk size when multiple texts are input"""

    model: str = "Embedding-V1"
    """Model name.
    You could get it from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    For now, we support:
    - Embedding-V1 (default model)
    - bge-large-en
    - bge-large-zh

    Preset models map to an endpoint.
    `model` will be ignored if `endpoint` is set.
    """

    endpoint: str = ""
    """Endpoint of the Qianfan Embedding, required if custom model used."""

    client: Any
    """Qianfan client"""

    max_retries: int = 5
    """Max retry times"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        Validate whether qianfan_ak and qianfan_sk in the environment variables or
        configuration file are available or not.

        Init the qianfan embedding client with `ak`, `sk`, `model`, `endpoint`.

        Args:
            values: a dictionary containing configuration information, must include
                the fields qianfan_ak and qianfan_sk
        Returns:
            a dictionary containing configuration information. If qianfan_ak and
            qianfan_sk are not provided in the environment variables or configuration
            file, the original values will be returned; otherwise, values containing
            qianfan_ak and qianfan_sk will be returned.
        Raises:
            ValueError: qianfan package not found, please install it with `pip install
                qianfan`
        """
        values["qianfan_ak"] = get_from_dict_or_env(
            values,
            "qianfan_ak",
            "QIANFAN_AK",
        )
        values["qianfan_sk"] = get_from_dict_or_env(
            values,
            "qianfan_sk",
            "QIANFAN_SK",
        )

        try:
            import qianfan

            params = {
                "ak": values["qianfan_ak"],
                "sk": values["qianfan_sk"],
                "model": values["model"],
            }
            if values["endpoint"] is not None and values["endpoint"] != "":
                params["endpoint"] = values["endpoint"]
            values["client"] = qianfan.Embedding(**params)
        except ImportError:
            raise ValueError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            )
        return values

    def embed_query(self, text: str) -> List[float]:
        resp = self.embed_documents([text])
        return resp[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of text documents using the Qianfan Embedding endpoint.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the input
                list. Each embedding is represented as a list of float values.
        """
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = self.client.do(texts=chunk)
            lst.extend([res["embedding"] for res in resp["data"]])
        return lst

    async def aembed_query(self, text: str) -> List[float]:
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = await self.client.ado(texts=chunk)
            for res in resp["data"]:
                lst.extend([res["embedding"]])
        return lst
@ -26,6 +26,7 @@ from langchain.llms.anthropic import Anthropic
from langchain.llms.anyscale import Anyscale
from langchain.llms.aviary import Aviary
from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint
from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
from langchain.llms.baseten import Baseten

@ -160,6 +161,7 @@ __all__ = [
    "Writer",
    "OctoAIEndpoint",
    "Xinference",
    "QianfanLLMEndpoint",
]

type_to_cls_dict: Dict[str, Type[BaseLLM]] = {

@ -228,4 +230,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "vllm_openai": VLLMOpenAI,
    "writer": Writer,
    "xinference": Xinference,
    "qianfan_endpoint": QianfanLLMEndpoint,
}
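The `"qianfan_endpoint"` entry registered in `type_to_cls_dict` above also makes the new LLM loadable from a serialized config. A hedged sketch, assuming langchain's `load_llm_from_config` helper (which resolves the `_type` key against this registry):

```python
# Sketch: rebuild the LLM from a config dict via the "qianfan_endpoint" registry key.
from langchain.llms.loading import load_llm_from_config

llm = load_llm_from_config({"_type": "qianfan_endpoint", "model": "ERNIE-Bot"})
```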
217 libs/langchain/langchain/llms/baidu_qianfan_endpoint.py Normal file

@ -0,0 +1,217 @@
from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class QianfanLLMEndpoint(LLM):
    """Baidu Qianfan hosted open source or customized models.

    To use, you should have the ``qianfan`` python package installed, and
    the environment variables ``QIANFAN_AK`` and ``QIANFAN_SK`` set with
    your API Key and Secret Key.

    ak and sk are required parameters, which you could get from
    https://cloud.baidu.com/product/wenxinworkshop

    Example:
        .. code-block:: python

            from langchain.llms import QianfanLLMEndpoint
            qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
                endpoint="your_endpoint", ak="your_ak", sk="your_sk")
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    qianfan_ak: Optional[str] = None
    qianfan_sk: Optional[str] = None

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""

    model: str = "ERNIE-Bot-turbo"
    """Model name.
    You could get it from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    Preset models map to an endpoint.
    `model` will be ignored if `endpoint` is set.
    """

    endpoint: Optional[str] = None
    """Endpoint of the Qianfan LLM, required if custom model used."""

    request_timeout: Optional[int] = 60
    """request timeout for chat http requests"""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    penalty_score: Optional[float] = 1
    """Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
    In the case of other models, passing these params will not affect the result.
    """

    @root_validator()
    def validate_enviroment(cls, values: Dict) -> Dict:
        values["qianfan_ak"] = get_from_dict_or_env(
            values,
            "qianfan_ak",
            "QIANFAN_AK",
        )
        values["qianfan_sk"] = get_from_dict_or_env(
            values,
            "qianfan_sk",
            "QIANFAN_SK",
        )

        params = {
            "ak": values["qianfan_ak"],
            "sk": values["qianfan_sk"],
            "model": values["model"],
        }
        if values["endpoint"] is not None and values["endpoint"] != "":
            params["endpoint"] = values["endpoint"]
        try:
            import qianfan

            values["client"] = qianfan.Completion(**params)
        except ImportError:
            raise ValueError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baidu-qianfan-endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Qianfan API."""
        normal_params = {
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
        }

        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        return {
            **{"prompt": prompt, "model": self.model},
            **self._default_params,
            **kwargs,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a Qianfan models endpoint for each generation with a prompt.
        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python
                response = qianfan_model("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response_payload = self.client.do(**params)

        return response_payload["result"]

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response_payload = await self.client.ado(**params)

        return response_payload["result"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        for res in self.client.do(**params):
            if res:
                chunk = GenerationChunk(text=res["result"])
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        async for res in await self.client.ado(**params):
            if res:
                chunk = GenerationChunk(text=res["result"])
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
@ -0,0 +1,85 @@
"""Test Baidu Qianfan Chat Endpoint."""

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    HumanMessage,
    LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_default_call() -> None:
    """Test default model (`ERNIE-Bot-turbo`) call."""
    chat = QianfanChatEndpoint()
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_model() -> None:
    """Test model kwarg works."""
    chat = QianfanChatEndpoint(model="BLOOMZ-7B")
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_endpoint() -> None:
    """Test user custom model deployments like some open source models."""
    chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_multiple_history() -> None:
    """Tests multiple history works."""
    chat = QianfanChatEndpoint()

    response = chat(
        messages=[
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you doing?"),
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_stream() -> None:
    """Test that stream works."""
    chat = QianfanChatEndpoint(streaming=True)
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    response = chat(
        messages=[
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="Who are you?"),
        ],
        stream=True,
        callbacks=callback_manager,
    )
    assert callback_handler.llm_streams > 0
    assert isinstance(response.content, str)


def test_multiple_messages() -> None:
    """Tests multiple messages works."""
    chat = QianfanChatEndpoint()
    message = HumanMessage(content="Hi, how are you.")
    response = chat.generate([[message], [message]])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
@ -0,0 +1,25 @@
"""Test Baidu Qianfan Embedding Endpoint."""
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint


def test_embedding_multiple_documents() -> None:
    documents = ["foo", "bar"]
    embedding = QianfanEmbeddingsEndpoint()
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) == 384
    assert len(output[1]) == 384


def test_embedding_query() -> None:
    query = "foo"
    embedding = QianfanEmbeddingsEndpoint()
    output = embedding.embed_query(query)
    assert len(output) == 384


def test_model() -> None:
    documents = ["hi", "qianfan"]
    embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
    output = embedding.embed_documents(documents)
    assert len(output) == 2
@ -0,0 +1,37 @@
"""Test Baidu Qianfan LLM Endpoint."""
from typing import Generator

import pytest

from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
from langchain.schema import LLMResult


def test_call() -> None:
    """Test valid call to qianfan."""
    llm = QianfanLLMEndpoint()
    output = llm("write a joke")
    assert isinstance(output, str)


def test_generate() -> None:
    """Test valid call to qianfan."""
    llm = QianfanLLMEndpoint()
    output = llm.generate(["write a joke"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)


def test_generate_stream() -> None:
    """Test valid call to qianfan."""
    llm = QianfanLLMEndpoint()
    output = llm.stream("write a joke")
    assert isinstance(output, Generator)


@pytest.mark.asyncio
async def test_qianfan_aio() -> None:
    llm = QianfanLLMEndpoint(streaming=True)

    async for token in llm.astream("hi qianfan."):
        assert isinstance(token, str)