community[minor]: integrate chat models with Yuan2.0 (#16575)

1. integrate chat models with
[`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md)
2. add a new doc for [Yuan2.0
integration](docs/docs/integrations/llms/yuan2.ipynb)
 
Yuan2.0 is a new generation Fundamental Large Language Model developed
by IEIT System. We have published all three models, Yuan 2.0-102B, Yuan
2.0-51B, and Yuan 2.0-2B.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
wulixuan 2024-02-14 02:55:14 +08:00 committed by GitHub
parent 15baffc484
commit 5d06797905
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1168 additions and 0 deletions

View File

@ -0,0 +1,463 @@
{
"cells": [
{
"cell_type": "raw",
"source": [
"---\n",
"sidebar_label: YUAN2\n",
"---"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% raw\n"
}
}
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# YUAN2.0\n",
"\n",
"This notebook shows how to use [YUAN2 API](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md) in LangChain with the langchain.chat_models.ChatYuan2.\n",
"\n",
"[*Yuan2.0*](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md) is a new generation Fundamental Large Language Model developed by IEIT System. We have published all three models, Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B. And we provide relevant scripts for pretraining, fine-tuning, and inference services for other developers. Yuan2.0 is based on Yuan1.0, utilizing a wider range of high-quality pre training data and instruction fine-tuning datasets to enhance the model's understanding of semantics, mathematics, reasoning, code, knowledge, and other aspects."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Getting started\n",
"### Installation\n",
"First, Yuan2.0 provided an OpenAI compatible API, and we integrate ChatYuan2 into langchain chat model by using OpenAI client.\n",
"Therefore, ensure the openai package is installed in your Python environment. Run the following command:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"%pip install --upgrade --quiet openai"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Importing the Required Modules\n",
"After installation, import the necessary modules to your Python script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from langchain_community.chat_models import ChatYuan2\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Setting Up Your API server\n",
"Setting up your OpenAI compatible API server following [yuan2 openai api server](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md).\n",
"If you deployed api server locally, you can simply set `api_key=\"EMPTY\"` or anything you want.\n",
"Just make sure, the `api_base` is set correctly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"yuan2_api_key = \"your_api_key\"\n",
"yuan2_api_base = \"http://127.0.0.1:8001/v1\""
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Initialize the ChatYuan2 Model\n",
"Here's how to initialize the chat model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"chat = ChatYuan2(\n",
" yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
" temperature=1.0,\n",
" model_name=\"yuan2\",\n",
" max_retries=3,\n",
" streaming=False,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Basic Usage\n",
"Invoke the model with system and human messages like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": true
},
"outputs": [],
"source": [
"messages = [\n",
" SystemMessage(content=\"你是一个人工智能助手。\"),\n",
" HumanMessage(content=\"你好,你是谁?\"),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"print(chat(messages))"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Basic Usage with streaming\n",
"For continuous interaction, use the streaming feature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"chat = ChatYuan2(\n",
" yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
" temperature=1.0,\n",
" model_name=\"yuan2\",\n",
" max_retries=3,\n",
" streaming=True,\n",
" callbacks=[StreamingStdOutCallbackHandler()],\n",
")\n",
"messages = [\n",
" SystemMessage(content=\"你是个旅游小助手。\"),\n",
" HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"chat(messages)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Advanced Features\n",
"### Usage with async calls\n",
"\n",
"Invoke the model with non-blocking calls, like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"async def basic_agenerate():\n",
" chat = ChatYuan2(\n",
" yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
" temperature=1.0,\n",
" model_name=\"yuan2\",\n",
" max_retries=3,\n",
" )\n",
" messages = [\n",
" [\n",
" SystemMessage(content=\"你是个旅游小助手。\"),\n",
" HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
" ]\n",
" ]\n",
"\n",
" result = await chat.agenerate(messages)\n",
" print(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"asyncio.run(basic_agenerate())"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Usage with prompt template\n",
"\n",
"Invoke the model with non-blocking calls and used chat template like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"async def ainvoke_with_prompt_template():\n",
" from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" )\n",
"\n",
" chat = ChatYuan2(\n",
" yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
" temperature=1.0,\n",
" model_name=\"yuan2\",\n",
" max_retries=3,\n",
" )\n",
" prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"你是一个诗人,擅长写诗。\"),\n",
" (\"human\", \"给我写首诗,主题是{theme}。\"),\n",
" ]\n",
" )\n",
" chain = prompt | chat\n",
" result = await chain.ainvoke({\"theme\": \"明月\"})\n",
" print(f\"type(result): {type(result)}; {result}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"asyncio.run(ainvoke_with_prompt_template())"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Usage with async calls in streaming\n",
"For non-blocking calls with streaming output, use the astream method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"async def basic_astream():\n",
" chat = ChatYuan2(\n",
" yuan2_api_base=\"http://127.0.0.1:8001/v1\",\n",
" temperature=1.0,\n",
" model_name=\"yuan2\",\n",
" max_retries=3,\n",
" )\n",
" messages = [\n",
" SystemMessage(content=\"你是个旅游小助手。\"),\n",
" HumanMessage(content=\"给我介绍一下北京有哪些好玩的。\"),\n",
" ]\n",
" result = chat.astream(messages)\n",
" async for chunk in result:\n",
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
},
"scrolled": true
},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"asyncio.run(basic_astream())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -54,6 +54,7 @@ from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain_community.chat_models.yandex import ChatYandexGPT
from langchain_community.chat_models.yuan2 import ChatYuan2
from langchain_community.chat_models.zhipuai import ChatZhipuAI
__all__ = [
@ -94,5 +95,6 @@ __all__ = [
"ChatSparkLLM",
"VolcEngineMaasChat",
"GPTRouter",
"ChatYuan2",
"ChatZhipuAI",
]

View File

@ -0,0 +1,486 @@
"""ChatYuan2 wrapper."""
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
if TYPE_CHECKING:
from openai.types.chat import ChatCompletion, ChatCompletionMessage
logger = logging.getLogger(__name__)
class ChatYuan2(BaseChatModel):
"""`Yuan2.0` Chat models API.
To use, you should have the ``openai-python`` package installed, if package
not installed, using ```pip install openai``` to install it. The
environment variable ``YUAN2_API_KEY`` set to your API key, if not set,
everyone can access apis.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatYuan2
chat = ChatYuan2()
"""
client: Any #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="yuan2", alias="model")
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
yuan2_api_key: Optional[str] = Field(default="EMPTY", alias="api_key")
"""Automatically inferred from env var `YUAN2_API_KEY` if not provided."""
yuan2_api_base: Optional[str] = Field(
default="http://127.0.0.1:8000", alias="base_url"
)
"""Base URL path for API requests, an OpenAI compatible API server."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to yuan2 completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
temperature: float = 1.0
"""What sampling temperature to use."""
top_p: Optional[float] = 0.9
"""The top-p value to use for sampling."""
stop: Optional[List[str]] = ["<eod>"]
"""A list of strings to stop generation when encountered."""
repeat_last_n: Optional[int] = 64
"Last n tokens to penalize"
repeat_penalty: Optional[float] = 1.18
"""The penalty to apply to repeated tokens."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"yuan2_api_key": "YUAN2_API_KEY"}
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.yuan2_api_base:
attributes["yuan2_api_base"] = self.yuan2_api_base
if self.yuan2_api_key:
attributes["yuan2_api_key"] = self.yuan2_api_key
return attributes
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["yuan2_api_key"] = get_from_dict_or_env(
values, "yuan2_api_key", "YUAN2_API_KEY"
)
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
client_params = {
"api_key": values["yuan2_api_key"],
"base_url": values["yuan2_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
}
# generate client and async_client
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(
**client_params
).chat.completions
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling yuan2 API."""
params = {
"model": self.model_name,
"stream": self.streaming,
"temperature": self.temperature,
"top_p": self.top_p,
**self.model_kwargs,
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.request_timeout is not None:
params["request_timeout"] = self.request_timeout
return params
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
logger.debug(
f"type(llm_outputs): {type(llm_outputs)}; llm_outputs: {llm_outputs}"
)
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.__dict__.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model_name}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(
message=chunk,
generation_info=generation_info,
)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = dict(self._invocation_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: ChatCompletion) -> ChatResult:
generations = []
logger.debug(f"type(response): {type(response)}; response: {response}")
for res in response.choices:
message = _convert_dict_to_message(res.message)
generation_info = dict(finish_reason=res.finish_reason)
if "logprobs" in res:
generation_info["logprobs"] = res.logprobs
gen = ChatGeneration(
message=message,
generation_info=generation_info,
)
generations.append(gen)
llm_output = {
"token_usage": response.usage,
"model_name": self.model_name,
}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(
message=chunk,
generation_info=generation_info,
)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(self, messages=message_dicts, **params)
return self._create_chat_result(response)
@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
yuan2_creds: Dict[str, Any] = {
"model": self.model_name,
}
return {**yuan2_creds, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-yuan2"
def _create_retry_decorator(llm: ChatYuan2) -> Callable[[Any], Any]:
import openai
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def acompletion_with_retry(llm: ChatYuan2, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.async_client.create(**kwargs)
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: ChatCompletionMessage, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def _convert_dict_to_message(_dict: ChatCompletionMessage) -> BaseMessage:
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content"))
elif role == "assistant":
content = _dict.get("content") or ""
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict.get("content"))
else:
return ChatMessage(content=_dict.get("content"), role=role)
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"name": message.name,
"content": message.content,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict

View File

@ -0,0 +1,152 @@
"""Test ChatYuan2 wrapper."""
from typing import List
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import (
ChatGeneration,
LLMResult,
)
from langchain_community.chat_models.yuan2 import ChatYuan2
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.scheduled
def test_chat_yuan2() -> None:
"""Test ChatYuan2 wrapper."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages = [
HumanMessage(content="Hello"),
]
response = chat(messages)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_yuan2_system_message() -> None:
"""Test ChatYuan2 wrapper with system message."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages = [
SystemMessage(content="You are an AI assistant."),
HumanMessage(content="Hello"),
]
response = chat(messages)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
@pytest.mark.scheduled
def test_chat_yuan2_generate() -> None:
"""Test ChatYuan2 wrapper with generate."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = chat.generate([messages])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
assert response.llm_output
generation = response.generations[0]
for gen in generation:
assert isinstance(gen, ChatGeneration)
assert isinstance(gen.text, str)
assert gen.text == gen.message.content
@pytest.mark.scheduled
def test_chat_yuan2_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=True,
callback_manager=callback_manager,
)
messages = [
HumanMessage(content="Hello"),
]
response = chat(messages)
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)
@pytest.mark.asyncio
async def test_async_chat_yuan2() -> None:
"""Test async generation."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = await chat.agenerate([messages])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
generations = response.generations[0]
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_async_chat_yuan2_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=True,
callback_manager=callback_manager,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = await chat.agenerate([messages])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
generations = response.generations[0]
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

View File

@ -38,6 +38,7 @@ EXPECTED_ALL = [
"VolcEngineMaasChat",
"LlamaEdgeChatService",
"GPTRouter",
"ChatYuan2",
"ChatZhipuAI",
]

View File

@ -0,0 +1,64 @@
"""Test ChatYuan2 wrapper."""
import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_community.chat_models.yuan2 import (
ChatYuan2,
_convert_dict_to_message,
_convert_message_to_dict,
)
@pytest.mark.requires("openai")
def test_yuan2_model_param() -> None:
chat = ChatYuan2(model="foo")
assert chat.model_name == "foo"
chat = ChatYuan2(model_name="foo")
assert chat.model_name == "foo"
def test__convert_message_to_dict_human() -> None:
message = HumanMessage(content="foo")
result = _convert_message_to_dict(message)
expected_output = {"role": "user", "content": "foo"}
assert result == expected_output
def test__convert_message_to_dict_ai() -> None:
message = AIMessage(content="foo")
result = _convert_message_to_dict(message)
expected_output = {"role": "assistant", "content": "foo"}
assert result == expected_output
def test__convert_message_to_dict_system() -> None:
message = SystemMessage(content="foo")
result = _convert_message_to_dict(message)
expected_output = {"role": "system", "content": "foo"}
assert result == expected_output
def test__convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "hello"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="hello")
assert result == expected_output
def test__convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "hello"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="hello")
assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "hello"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="hello")
assert result == expected_output