community[minor]: integrate chat models with Yuan2.0 (#16575)

1. Integrate chat models with
[`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md)
2. Add a new doc for the [Yuan2.0
integration](docs/docs/integrations/llms/yuan2.ipynb)

Yuan2.0 is a new-generation fundamental large language model developed
by IEIT System. All three models have been published: Yuan 2.0-102B,
Yuan 2.0-51B, and Yuan 2.0-2B.
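For quick reference, here is a minimal sketch of using the new `ChatYuan2` chat model, mirroring the parameters exercised in the integration tests added below. The `yuan2_api_base` URL and the `"EMPTY"` API key are placeholders that assume a locally hosted Yuan2.0 inference server exposing an OpenAI-style API:

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.chat_models.yuan2 import ChatYuan2

# Placeholder endpoint/key: assumes a locally hosted Yuan2.0 inference
# server exposing an OpenAI-style API at this address.
chat = ChatYuan2(
    yuan2_api_key="EMPTY",
    yuan2_api_base="http://127.0.0.1:8001/v1",
    model_name="yuan2",
    temperature=1.0,
)

messages = [
    SystemMessage(content="You are an AI assistant."),
    HumanMessage(content="Hello"),
]

# Returns a chat message; its `content` attribute holds the model's reply.
response = chat(messages)
print(response.content)
```

Streaming and async generation use the same constructor (with `streaming=True` and `agenerate`, respectively), as exercised in the tests below.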

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: wulixuan
Date: 2024-02-14 02:55:14 +08:00
Committed by: GitHub
Parent: 15baffc484
Commit: 5d06797905
6 changed files with 1168 additions and 0 deletions


@@ -0,0 +1,152 @@
"""Test ChatYuan2 wrapper."""
from typing import List
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import (
ChatGeneration,
LLMResult,
)
from langchain_community.chat_models.yuan2 import ChatYuan2
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.scheduled
def test_chat_yuan2() -> None:
"""Test ChatYuan2 wrapper."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages = [
HumanMessage(content="Hello"),
]
response = chat(messages)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_chat_yuan2_system_message() -> None:
"""Test ChatYuan2 wrapper with system message."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages = [
SystemMessage(content="You are an AI assistant."),
HumanMessage(content="Hello"),
]
response = chat(messages)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
@pytest.mark.scheduled
def test_chat_yuan2_generate() -> None:
"""Test ChatYuan2 wrapper with generate."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = chat.generate([messages])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
assert response.llm_output
generation = response.generations[0]
for gen in generation:
assert isinstance(gen, ChatGeneration)
assert isinstance(gen.text, str)
assert gen.text == gen.message.content
@pytest.mark.scheduled
def test_chat_yuan2_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=True,
callback_manager=callback_manager,
)
messages = [
HumanMessage(content="Hello"),
]
response = chat(messages)
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)
@pytest.mark.asyncio
async def test_async_chat_yuan2() -> None:
"""Test async generation."""
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=False,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = await chat.agenerate([messages])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
generations = response.generations[0]
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_async_chat_yuan2_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatYuan2(
yuan2_api_key="EMPTY",
yuan2_api_base="http://127.0.0.1:8001/v1",
temperature=1.0,
model_name="yuan2",
max_retries=3,
streaming=True,
callback_manager=callback_manager,
)
messages: List = [
HumanMessage(content="Hello"),
]
response = await chat.agenerate([messages])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
generations = response.generations[0]
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content