Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-01 09:04:03 +00:00
community[minor]: integrate chat models with Yuan2.0 (#16575)
1. Integrate chat models with [`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md).
2. Add a new doc for the [Yuan2.0 integration](docs/docs/integrations/llms/yuan2.ipynb).

Yuan2.0 is a new-generation fundamental large language model developed by IEIT System. We have published all three models: Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B.

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 15baffc484, commit 5d06797905
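In brief, this change exposes Yuan2.0 through LangChain's chat-model interface over an OpenAI-compatible endpoint. A minimal usage sketch, condensed from the notebook added below, assuming a Yuan2.0 inference server is already serving an OpenAI-compatible API on local port 8001:

from langchain_community.chat_models import ChatYuan2
from langchain_core.messages import HumanMessage, SystemMessage

# Point the client at a locally deployed, OpenAI-compatible Yuan2.0 server.
chat = ChatYuan2(
    yuan2_api_base="http://127.0.0.1:8001/v1",
    model_name="yuan2",
)
messages = [
    SystemMessage(content="You are an AI assistant."),
    HumanMessage(content="Hello"),
]
print(chat.invoke(messages).content)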
docs/docs/integrations/chat/yuan2.ipynb (new file, 463 lines)
@@ -0,0 +1,463 @@
{
  "cells": [
    {
      "cell_type": "raw",
      "source": [
        "---\n",
        "sidebar_label: YUAN2\n",
        "---"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% raw\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# YUAN2.0\n",
        "\n",
        "This notebook shows how to use the [YUAN2 API](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md) in LangChain with `langchain_community.chat_models.ChatYuan2`.\n",
        "\n",
        "[*Yuan2.0*](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md) is a new-generation fundamental large language model developed by IEIT System. We have published all three models: Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B, and we provide scripts for pretraining, fine-tuning, and inference services for other developers. Yuan2.0 is based on Yuan1.0 and uses a wider range of high-quality pre-training data and instruction fine-tuning datasets to enhance the model's understanding of semantics, mathematics, reasoning, code, knowledge, and other aspects."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Getting started\n",
        "### Installation\n",
        "Yuan2.0 provides an OpenAI-compatible API, and ChatYuan2 is integrated into LangChain as a chat model by using the OpenAI client.\n",
        "Therefore, ensure the `openai` package is installed in your Python environment. Run the following command:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --quiet openai"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Importing the Required Modules\n",
        "After installation, import the necessary modules into your Python script:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "from langchain_community.chat_models import ChatYuan2\n",
        "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Setting Up Your API Server\n",
        "Set up your OpenAI-compatible API server by following [yuan2 openai api server](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md).\n",
        "If you deployed the API server locally, you can simply set `api_key=\"EMPTY\"` or any other value you want.\n",
        "Just make sure the `api_base` is set correctly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "yuan2_api_key = \"your_api_key\"\n",
        "yuan2_api_base = \"http://127.0.0.1:8001/v1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Initialize the ChatYuan2 Model\n",
        "Here's how to initialize the chat model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "chat = ChatYuan2(\n",
        "    yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
        "    temperature=1.0,\n",
        "    model_name=\"yuan2\",\n",
        "    max_retries=3,\n",
        "    streaming=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Basic Usage\n",
        "Invoke the model with system and human messages like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "messages = [\n",
        "    SystemMessage(content=\"你是一个人工智能助手。\"),  # \"You are an AI assistant.\"\n",
        "    HumanMessage(content=\"你好,你是谁?\"),  # \"Hello, who are you?\"\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "print(chat(messages))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Basic Usage with Streaming\n",
        "For continuous interaction, use the streaming feature:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
        "\n",
        "chat = ChatYuan2(\n",
        "    yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
        "    temperature=1.0,\n",
        "    model_name=\"yuan2\",\n",
        "    max_retries=3,\n",
        "    streaming=True,\n",
        "    callbacks=[StreamingStdOutCallbackHandler()],\n",
        ")\n",
        "messages = [\n",
        "    SystemMessage(content=\"你是个旅游小助手。\"),  # \"You are a travel assistant.\"\n",
        "    HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),  # \"Tell me about fun places in Beijing.\"\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "chat(messages)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Advanced Features\n",
        "### Usage with Async Calls\n",
        "\n",
        "Invoke the model with non-blocking calls, like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "async def basic_agenerate():\n",
        "    chat = ChatYuan2(\n",
        "        yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
        "        temperature=1.0,\n",
        "        model_name=\"yuan2\",\n",
        "        max_retries=3,\n",
        "    )\n",
        "    messages = [\n",
        "        [\n",
        "            SystemMessage(content=\"你是个旅游小助手。\"),\n",
        "            HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
        "        ]\n",
        "    ]\n",
        "\n",
        "    result = await chat.agenerate(messages)\n",
        "    print(result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "import asyncio\n",
        "\n",
        "asyncio.run(basic_agenerate())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "jupyter": {
          "outputs_hidden": false
        },
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Usage with Prompt Template\n",
        "\n",
        "Invoke the model with non-blocking calls and a chat prompt template, like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "async def ainvoke_with_prompt_template():\n",
        "    from langchain.prompts.chat import (\n",
        "        ChatPromptTemplate,\n",
        "    )\n",
        "\n",
        "    chat = ChatYuan2(\n",
        "        yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
        "        temperature=1.0,\n",
        "        model_name=\"yuan2\",\n",
        "        max_retries=3,\n",
        "    )\n",
        "    prompt = ChatPromptTemplate.from_messages(\n",
        "        [\n",
        "            (\"system\", \"你是一个诗人,擅长写诗。\"),  # \"You are a poet, skilled at writing poems.\"\n",
        "            (\"human\", \"给我写首诗,主题是{theme}。\"),  # \"Write me a poem on the theme {theme}.\"\n",
        "        ]\n",
        "    )\n",
        "    chain = prompt | chat\n",
        "    result = await chain.ainvoke({\"theme\": \"明月\"})  # theme: \"bright moon\"\n",
        "    print(f\"type(result): {type(result)}; {result}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "asyncio.run(ainvoke_with_prompt_template())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "### Usage with Async Calls in Streaming\n",
        "For non-blocking calls with streaming output, use the `astream` method:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "async def basic_astream():\n",
        "    chat = ChatYuan2(\n",
        "        yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
        "        temperature=1.0,\n",
        "        model_name=\"yuan2\",\n",
        "        max_retries=3,\n",
        "    )\n",
        "    messages = [\n",
        "        SystemMessage(content=\"你是个旅游小助手。\"),\n",
        "        HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
        "    ]\n",
        "    result = chat.astream(messages)\n",
        "    async for chunk in result:\n",
        "        print(chunk.content, end=\"\", flush=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "pycharm": {
          "is_executing": true,
          "name": "#%%\n"
        },
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "import asyncio\n",
        "\n",
        "asyncio.run(basic_astream())"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
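The notebook above invokes the model with the `chat(messages)` calling convention. On current LangChain releases the non-deprecated equivalent is `invoke`; a minimal sketch, reusing the `chat` and `messages` objects defined in the notebook:

# Equivalent to print(chat(messages)), without the deprecated __call__ style.
response = chat.invoke(messages)  # returns an AIMessage
print(response.content)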
libs/community/langchain_community/chat_models/__init__.py (changed)

@@ -54,6 +54,7 @@ from langchain_community.chat_models.tongyi import ChatTongyi
 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
 from langchain_community.chat_models.yandex import ChatYandexGPT
+from langchain_community.chat_models.yuan2 import ChatYuan2
 from langchain_community.chat_models.zhipuai import ChatZhipuAI

 __all__ = [
@@ -94,5 +95,6 @@ __all__ = [
     "ChatSparkLLM",
     "VolcEngineMaasChat",
     "GPTRouter",
+    "ChatYuan2",
     "ChatZhipuAI",
 ]
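With this registration in place, `ChatYuan2` is re-exported from the package root, so user code does not need the full module path. A quick sanity check, assuming the package is installed from this commit:

from langchain_community.chat_models import ChatYuan2

# The re-export resolves to the new module added below.
print(ChatYuan2.__module__)  # -> langchain_community.chat_models.yuan2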
libs/community/langchain_community/chat_models/yuan2.py (new file, 486 lines)
@@ -0,0 +1,486 @@
"""ChatYuan2 wrapper."""
from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import (
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

if TYPE_CHECKING:
    from openai.types.chat import ChatCompletion, ChatCompletionMessage

logger = logging.getLogger(__name__)


class ChatYuan2(BaseChatModel):
    """`Yuan2.0` Chat models API.

    To use, you should have the ``openai`` package installed; if it is not
    installed, install it with ``pip install openai``. Set the environment
    variable ``YUAN2_API_KEY`` to your API key; if it is not set, the APIs
    can be accessed without a key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatYuan2

            chat = ChatYuan2()
    """

    client: Any  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:

    model_name: str = Field(default="yuan2", alias="model")
    """Model name to use."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    yuan2_api_key: Optional[str] = Field(default="EMPTY", alias="api_key")
    """Automatically inferred from env var `YUAN2_API_KEY` if not provided."""

    yuan2_api_base: Optional[str] = Field(
        default="http://127.0.0.1:8000", alias="base_url"
    )
    """Base URL path for API requests, an OpenAI compatible API server."""

    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to yuan2 completion API. Default is 600 seconds."""

    max_retries: int = 6
    """Maximum number of retries to make when generating."""

    streaming: bool = False
    """Whether to stream the results or not."""

    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    temperature: float = 1.0
    """What sampling temperature to use."""

    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""

    stop: Optional[List[str]] = ["<eod>"]
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize."""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"yuan2_api_key": "YUAN2_API_KEY"}

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}

        if self.yuan2_api_base:
            attributes["yuan2_api_base"] = self.yuan2_api_base

        if self.yuan2_api_key:
            attributes["yuan2_api_key"] = self.yuan2_api_key

        return attributes

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["yuan2_api_key"] = get_from_dict_or_env(
            values, "yuan2_api_key", "YUAN2_API_KEY"
        )

        try:
            import openai

        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        client_params = {
            "api_key": values["yuan2_api_key"],
            "base_url": values["yuan2_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
        }

        # generate client and async_client
        if not values.get("client"):
            values["client"] = openai.OpenAI(**client_params).chat.completions
        if not values.get("async_client"):
            values["async_client"] = openai.AsyncOpenAI(
                **client_params
            ).chat.completions

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling yuan2 API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "temperature": self.temperature,
            "top_p": self.top_p,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        if self.request_timeout is not None:
            params["request_timeout"] = self.request_timeout
        return params

    def completion_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        logger.debug(
            f"type(llm_outputs): {type(llm_outputs)}; llm_outputs: {llm_outputs}"
        )
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            for k, v in token_usage.__dict__.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model_name}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.completion_with_retry(messages=message_dicts, **params):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(
                message=chunk,
                generation_info=generation_info,
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.content)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = dict(self._invocation_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: ChatCompletion) -> ChatResult:
        generations = []
        logger.debug(f"type(response): {type(response)}; response: {response}")
        for res in response.choices:
            message = _convert_dict_to_message(res.message)
            generation_info = dict(finish_reason=res.finish_reason)
            if "logprobs" in res:
                generation_info["logprobs"] = res.logprobs
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        llm_output = {
            "token_usage": response.usage,
            "model_name": self.model_name,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            yield ChatGenerationChunk(
                message=chunk,
                generation_info=generation_info,
            )
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = await acompletion_with_retry(self, messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model."""
        yuan2_creds: Dict[str, Any] = {
            "model": self.model_name,
        }
        return {**yuan2_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-yuan2"


def _create_retry_decorator(llm: ChatYuan2) -> Callable[[Any], Any]:
    import openai

    min_seconds = 1
    max_seconds = 60
    # Wait 2^x * 1 second between retries, starting at 1 second and capping
    # at 60 seconds, then 60 seconds between subsequent retries.
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.APITimeoutError)
            | retry_if_exception_type(openai.APIError)
            | retry_if_exception_type(openai.APIConnectionError)
            | retry_if_exception_type(openai.RateLimitError)
            | retry_if_exception_type(openai.InternalServerError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


async def acompletion_with_retry(llm: ChatYuan2, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async API: https://github.com/openai/openai-python#async-api
        return await llm.async_client.create(**kwargs)

    return await _completion_with_retry(**kwargs)


def _convert_delta_to_message_chunk(
    _dict: ChatCompletionMessage, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = _dict.get("role")
    content = _dict.get("content") or ""

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def _convert_dict_to_message(_dict: ChatCompletionMessage) -> BaseMessage:
    role = _dict.get("role")
    if role == "user":
        return HumanMessage(content=_dict.get("content"))
    elif role == "assistant":
        content = _dict.get("content") or ""
        return AIMessage(content=content)
    elif role == "system":
        return SystemMessage(content=_dict.get("content"))
    else:
        return ChatMessage(content=_dict.get("content"), role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "name": message.name,
            "content": message.content,
        }
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
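To see the `build_extra` validator above in action: any constructor argument that is not a declared field is moved into `model_kwargs` (with a warning), and `_default_params` later merges `model_kwargs` into every `create()` call. A minimal sketch, assuming the `openai` package is installed; `presence_penalty` is just an illustrative OpenAI-style parameter, not a field of this class:

chat = ChatYuan2(
    yuan2_api_base="http://127.0.0.1:8001/v1",
    presence_penalty=1.2,  # unknown field: warned about, moved into model_kwargs
)
assert chat.model_kwargs == {"presence_penalty": 1.2}

# _default_params then folds it into the request payload, roughly:
# {"model": "yuan2", "stream": False, "temperature": 1.0,
#  "top_p": 0.9, "presence_penalty": 1.2}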
libs/community/tests/integration_tests/chat_models/test_yuan2.py (new file, 152 lines)
@@ -0,0 +1,152 @@
"""Test ChatYuan2 wrapper."""
from typing import List

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import (
    ChatGeneration,
    LLMResult,
)

from langchain_community.chat_models.yuan2 import ChatYuan2
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@pytest.mark.scheduled
def test_chat_yuan2() -> None:
    """Test ChatYuan2 wrapper."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = chat(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_yuan2_system_message() -> None:
    """Test ChatYuan2 wrapper with system message."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        SystemMessage(content="You are an AI assistant."),
        HumanMessage(content="Hello"),
    ]
    response = chat(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_chat_yuan2_generate() -> None:
    """Test ChatYuan2 wrapper with generate."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages: List = [
        HumanMessage(content="Hello"),
    ]
    response = chat.generate([messages])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    assert response.llm_output
    generation = response.generations[0]
    for gen in generation:
        assert isinstance(gen, ChatGeneration)
        assert isinstance(gen.text, str)
        assert gen.text == gen.message.content


@pytest.mark.scheduled
def test_chat_yuan2_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=True,
        callback_manager=callback_manager,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = chat(messages)
    assert callback_handler.llm_streams > 0
    assert isinstance(response, BaseMessage)


@pytest.mark.asyncio
async def test_async_chat_yuan2() -> None:
    """Test async generation."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages: List = [
        HumanMessage(content="Hello"),
    ]
    response = await chat.agenerate([messages])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    generations = response.generations[0]
    for generation in generations:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)
        assert generation.text == generation.message.content


@pytest.mark.asyncio
async def test_async_chat_yuan2_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=True,
        callback_manager=callback_manager,
    )
    messages: List = [
        HumanMessage(content="Hello"),
    ]
    response = await chat.agenerate([messages])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    generations = response.generations[0]
    for generation in generations:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)
        assert generation.text == generation.message.content
libs/community/tests/unit_tests/chat_models/test_imports.py (changed)

@@ -38,6 +38,7 @@ EXPECTED_ALL = [
     "VolcEngineMaasChat",
     "LlamaEdgeChatService",
     "GPTRouter",
+    "ChatYuan2",
     "ChatZhipuAI",
 ]
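EXPECTED_ALL is the registry used by the import check in the unit tests; the usual pattern in this repo compares it against the package's `__all__`. A sketch of that check, an assumption since the surrounding test function is not shown in this diff:

from langchain_community.chat_models import __all__

def test_all_imports() -> None:
    # Fails when a model is registered in one list but not the other.
    assert sorted(EXPECTED_ALL) == sorted(__all__)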
libs/community/tests/unit_tests/chat_models/test_yuan2.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""Test ChatYuan2 wrapper."""

import pytest
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)

from langchain_community.chat_models.yuan2 import (
    ChatYuan2,
    _convert_dict_to_message,
    _convert_message_to_dict,
)


@pytest.mark.requires("openai")
def test_yuan2_model_param() -> None:
    chat = ChatYuan2(model="foo")
    assert chat.model_name == "foo"
    chat = ChatYuan2(model_name="foo")
    assert chat.model_name == "foo"


def test__convert_message_to_dict_human() -> None:
    message = HumanMessage(content="foo")
    result = _convert_message_to_dict(message)
    expected_output = {"role": "user", "content": "foo"}
    assert result == expected_output


def test__convert_message_to_dict_ai() -> None:
    message = AIMessage(content="foo")
    result = _convert_message_to_dict(message)
    expected_output = {"role": "assistant", "content": "foo"}
    assert result == expected_output


def test__convert_message_to_dict_system() -> None:
    message = SystemMessage(content="foo")
    result = _convert_message_to_dict(message)
    expected_output = {"role": "system", "content": "foo"}
    assert result == expected_output


def test__convert_dict_to_message_human() -> None:
    message = {"role": "user", "content": "hello"}
    result = _convert_dict_to_message(message)
    expected_output = HumanMessage(content="hello")
    assert result == expected_output


def test__convert_dict_to_message_ai() -> None:
    message = {"role": "assistant", "content": "hello"}
    result = _convert_dict_to_message(message)
    expected_output = AIMessage(content="hello")
    assert result == expected_output


def test__convert_dict_to_message_system() -> None:
    message = {"role": "system", "content": "hello"}
    result = _convert_dict_to_message(message)
    expected_output = SystemMessage(content="hello")
    assert result == expected_output