community[minor]: integrate with model Yuan2.0 (#15411)

1. Integrate with
[`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md)
2. Register the new LLM in `langchain_community.llms`
3. Add a new doc for the [Yuan2.0
integration](docs/docs/integrations/llms/yuan2.ipynb)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
wulixuan 2024-02-15 03:46:20 +08:00 committed by GitHub
parent d07db457fc
commit c776cfc599
6 changed files with 354 additions and 0 deletions


@@ -34,3 +34,4 @@ jobs:
       with:
         skip: guide_imports.json
         ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }}
+        exclude_file: libs/community/langchain_community/llms/yuan2.py

docs/docs/integrations/llms/yuan2.ipynb

@@ -0,0 +1,117 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Yuan2.0\n",
"\n",
"[Yuan2.0](https://github.com/IEIT-Yuan/Yuan-2.0) is a new generation Fundamental Large Language Model developed by IEIT System. We have published all three models, Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B. And we provide relevant scripts for pretraining, fine-tuning, and inference services for other developers. Yuan2.0 is based on Yuan1.0, utilizing a wider range of high-quality pre training data and instruction fine-tuning datasets to enhance the model's understanding of semantics, mathematics, reasoning, code, knowledge, and other aspects.\n",
"\n",
"This example goes over how to use LangChain to interact with `Yuan2.0`(2B/51B/102B) Inference for text generation.\n",
"\n",
"Yuan2.0 set up an inference service so user just need request the inference api to get result, which is introduced in [Yuan2.0 Inference-Server](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain_community.llms.yuan2 import Yuan2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# default infer_api for a local deployed Yuan2.0 inference server\n",
"infer_api = \"http://127.0.0.1:8000\"\n",
"\n",
"# direct access endpoint in a proxied environment\n",
"# import os\n",
"# os.environ[\"no_proxy\"]=\"localhost,127.0.0.1,::1\"\n",
"\n",
"yuan_llm = Yuan2(\n",
" infer_api=infer_api,\n",
" max_tokens=2048,\n",
" temp=1.0,\n",
" top_p=0.9,\n",
" top_k=40,\n",
" use_history=False,\n",
")\n",
"\n",
"# turn on use_history only when you want the Yuan2.0 to keep track of the conversation history\n",
"# and send the accumulated context to the backend model api, which make it stateful. By default it is stateless.\n",
"# llm.use_history = True"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"question = \"请介绍一下中国。\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"print(yuan_llm(question))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
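
For multi-turn chat, the notebook only hints at `use_history`. A minimal sketch of the stateful mode, assuming the same locally deployed inference server as above (the follow-up question is illustrative):

```python
from langchain_community.llms.yuan2 import Yuan2

yuan_llm = Yuan2(
    infer_api="http://127.0.0.1:8000",
    max_tokens=2048,
    temp=1.0,
    top_p=0.9,
    top_k=40,
    use_history=True,  # each prompt and answer is appended to yuan_llm.history
)

first = yuan_llm("请介绍一下中国。")  # "Please introduce China."
# the follow-up is answered with accumulated context: all history entries
# are joined with "<n>" before being sent to the inference API
second = yuan_llm("它的首都是哪里?")  # "What is its capital?"
print(second)
```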

libs/community/langchain_community/llms/__init__.py

@@ -570,6 +570,12 @@ def _import_yandex_gpt() -> Any:
     return YandexGPT
 
 
+def _import_yuan2() -> Any:
+    from langchain_community.llms.yuan2 import Yuan2
+
+    return Yuan2
+
+
 def _import_volcengine_maas() -> Any:
     from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
@@ -753,6 +759,8 @@ def __getattr__(name: str) -> Any:
         return _import_xinference()
     elif name == "YandexGPT":
         return _import_yandex_gpt()
+    elif name == "Yuan2":
+        return _import_yuan2()
     elif name == "VolcEngineMaasLLM":
         return _import_volcengine_maas()
     elif name == "type_to_cls_dict":
@@ -851,6 +859,7 @@ __all__ = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "Yuan2",
     "VolcEngineMaasLLM",
 ]
@@ -939,5 +948,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
     "javelin-ai-gateway": _import_javelin_ai_gateway,
     "qianfan_endpoint": _import_baidu_qianfan_endpoint,
     "yandex_gpt": _import_yandex_gpt,
+    "yuan2": _import_yuan2,
     "VolcEngineMaasLLM": _import_volcengine_maas,
 }
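
All three hunks above extend the same lazy-import pattern: the Yuan2 class is only imported when first accessed, keeping `import langchain_community.llms` cheap. A stripped-down sketch of the mechanism (module-level `__getattr__`, PEP 562), not the actual file:

```python
# lazy_llms.py -- sketch of the lazy-import registry the diff extends
from typing import Any, Callable, Dict


def _import_yuan2() -> Any:
    # the heavy import runs only on first access to the attribute
    from langchain_community.llms.yuan2 import Yuan2

    return Yuan2


_importers: Dict[str, Callable[[], Any]] = {
    "Yuan2": _import_yuan2,
}


def __getattr__(name: str) -> Any:
    # PEP 562: invoked when the module lacks the requested attribute
    if name in _importers:
        return _importers[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```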

libs/community/langchain_community/llms/yuan2.py

@@ -0,0 +1,192 @@
import json
import logging
from typing import Any, Dict, List, Mapping, Optional, Set
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field
from langchain_community.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__)
class Yuan2(LLM):
"""Yuan2.0 language models.
Example:
.. code-block:: python
yuan_llm = Yuan2(
infer_api="http://127.0.0.1:8000/yuan",
max_tokens=1024,
temp=1.0,
top_p=0.9,
top_k=40,
)
print(yuan_llm)
print(yuan_llm("你是谁?"))
"""
infer_api: str = "http://127.0.0.1:8000/yuan"
"""Yuan2.0 inference api"""
max_tokens: int = Field(1024, alias="max_token")
"""Token context window."""
temp: Optional[float] = 0.7
"""The temperature to use for sampling."""
top_p: Optional[float] = 0.9
"""The top-p value to use for sampling."""
top_k: Optional[int] = 40
"""The top-k value to use for sampling."""
do_sample: bool = False
"""The do_sample is a Boolean value that determines whether
to use the sampling method during text generation.
"""
echo: Optional[bool] = False
"""Whether to echo the prompt."""
stop: Optional[List[str]] = []
"""A list of strings to stop generation when encountered."""
repeat_last_n: Optional[int] = 64
"Last n tokens to penalize"
repeat_penalty: Optional[float] = 1.18
"""The penalty to apply to repeated tokens."""
streaming: bool = False
"""Whether to stream the results or not."""
history: List[str] = []
"""History of the conversation"""
use_history: bool = False
"""Whether to use history or not"""
@property
def _llm_type(self) -> str:
return "Yuan2.0"
@staticmethod
def _model_param_names() -> Set[str]:
return {
"max_tokens",
"temp",
"top_k",
"top_p",
"do_sample",
}
def _default_params(self) -> Dict[str, Any]:
return {
"infer_api": self.infer_api,
"max_tokens": self.max_tokens,
"temp": self.temp,
"top_k": self.top_k,
"top_p": self.top_p,
"do_sample": self.do_sample,
"use_history": self.use_history,
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"model": self._llm_type,
**self._default_params(),
**{
k: v for k, v in self.__dict__.items() if k in self._model_param_names()
},
}
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to a Yuan2.0 LLM inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = yuan_llm("你能做什么?")
"""
        if self.use_history:
            self.history.append(prompt)
            input_text = "<n>".join(self.history)
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}

        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input_text}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )

        logger.debug(f"Yuan2.0 prompt: {input_text}")
# call api
try:
response = requests.put(self.infer_api, headers=headers, data=data)
except requests.exceptions.RequestException as e:
raise ValueError(f"Error raised by inference api: {e}")
logger.debug(f"Yuan2.0 response: {response}")
if response.status_code != 200:
raise ValueError(f"Failed with response: {response}")
try:
resp = response.json()
if resp["errCode"] != "0":
raise ValueError(
f"Failed with error code [{resp['errCode']}], "
f"error message: [{resp['errMessage']}]"
)
if "resData" in resp:
if len(resp["resData"]["output"]) >= 0:
generate_text = resp["resData"]["output"][0]["ans"]
else:
raise ValueError("No output found in response.")
else:
raise ValueError("No resData found in response.")
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised during decoding response from inference api: {e}."
f"\nResponse: {response.text}"
)
if stop is not None:
generate_text = enforce_stop_tokens(generate_text, stop)
# support multi-turn chat
if self.use_history:
self.history.append(generate_text)
logger.debug(f"history: {self.history}")
return generate_text
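
The request/response contract implemented by `_call` can be exercised without a real model: a PUT with a `ques_list` payload, answered with `errCode` and `resData`. A minimal stub server sketch using only the standard library, reproducing just the fields `_call` reads (the canned answer is a placeholder, not the real Yuan2.0 server):

```python
# stub_yuan2_server.py -- stand-in for the inference API, for local testing only
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class StubHandler(BaseHTTPRequestHandler):
    def do_PUT(self) -> None:  # Yuan2._call issues a PUT request
        length = int(self.headers["Content-Length"])
        request = json.loads(self.rfile.read(length))
        ques = request["ques_list"][0]["ques"]

        # reply in the shape the client parses: resp["resData"]["output"][0]["ans"]
        body = json.dumps(
            {
                "errCode": "0",
                "errMessage": "ok",
                "resData": {"output": [{"ans": f"stub answer to: {ques}"}]},
            }
        ).encode("utf-8")

        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    # point Yuan2(infer_api="http://127.0.0.1:8000") at this stub
    HTTPServer(("127.0.0.1", 8000), StubHandler).serve_forever()
```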


@@ -0,0 +1,33 @@
"""Test Yuan2.0 API wrapper."""
from langchain_core.outputs import LLMResult
from langchain_community.llms import Yuan2
def test_yuan2_call_method() -> None:
"""Test valid call to Yuan2.0."""
llm = Yuan2(
infer_api="http://127.0.0.1:8000/yuan",
max_tokens=1024,
temp=1.0,
top_p=0.9,
top_k=40,
use_history=False,
)
output = llm("写一段快速排序算法。")
assert isinstance(output, str)
def test_yuan2_generate_method() -> None:
"""Test valid call to Yuan2.0 inference api."""
llm = Yuan2(
infer_api="http://127.0.0.1:8000/yuan",
max_tokens=1024,
temp=1.0,
top_p=0.9,
top_k=40,
use_history=False,
)
output = llm.generate(["who are you?"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)


@@ -87,6 +87,7 @@ EXPECT_ALL = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "Yuan2",
     "VolcEngineMaasLLM",
     "WatsonxLLM",
 ]