mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
community[minor]: integrate with model Yuan2.0 (#15411)
1. integrate with [`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md) 2. update `langchain.llms` 3. add a new doc for [Yuan2.0 integration](docs/docs/integrations/llms/yuan2.ipynb) --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d07db457fc
commit
c776cfc599
1
.github/workflows/codespell.yml
vendored
1
.github/workflows/codespell.yml
vendored
@ -34,3 +34,4 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
skip: guide_imports.json
|
skip: guide_imports.json
|
||||||
ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }}
|
ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }}
|
||||||
|
exclude_file: libs/community/langchain_community/llms/yuan2.py
|
||||||
|
117
docs/docs/integrations/llms/yuan2.ipynb
Normal file
117
docs/docs/integrations/llms/yuan2.ipynb
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Yuan2.0\n",
|
||||||
|
"\n",
|
||||||
|
"[Yuan2.0](https://github.com/IEIT-Yuan/Yuan-2.0) is a new-generation fundamental large language model developed by IEIT System. We have published all three models: Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B. We also provide relevant scripts for pretraining, fine-tuning, and inference services for other developers. Yuan2.0 is based on Yuan1.0, utilizing a wider range of high-quality pre-training data and instruction fine-tuning datasets to enhance the model's understanding of semantics, mathematics, reasoning, code, knowledge, and other aspects.\n",
|
||||||
|
"\n",
|
||||||
|
"This example goes over how to use LangChain to interact with `Yuan2.0`(2B/51B/102B) Inference for text generation.\n",
|
||||||
|
"\n",
|
||||||
|
"Yuan2.0 provides an inference service, so users just need to request the inference API to get a result, as introduced in [Yuan2.0 Inference-Server](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"is_executing": true,
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains import LLMChain\n",
|
||||||
|
"from langchain_community.llms.yuan2 import Yuan2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"is_executing": true,
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# default infer_api for a local deployed Yuan2.0 inference server\n",
|
||||||
|
"infer_api = \"http://127.0.0.1:8000\"\n",
|
||||||
|
"\n",
|
||||||
|
"# direct access endpoint in a proxied environment\n",
|
||||||
|
"# import os\n",
|
||||||
|
"# os.environ[\"no_proxy\"]=\"localhost,127.0.0.1,::1\"\n",
|
||||||
|
"\n",
|
||||||
|
"yuan_llm = Yuan2(\n",
|
||||||
|
" infer_api=infer_api,\n",
|
||||||
|
" max_tokens=2048,\n",
|
||||||
|
" temp=1.0,\n",
|
||||||
|
" top_p=0.9,\n",
|
||||||
|
" top_k=40,\n",
|
||||||
|
" use_history=False,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# turn on use_history only when you want the Yuan2.0 to keep track of the conversation history\n",
|
||||||
|
"# and send the accumulated context to the backend model api, which make it stateful. By default it is stateless.\n",
|
||||||
|
"# llm.use_history = True"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"question = \"请介绍一下中国。\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"is_executing": true,
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(yuan_llm(question))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "langchain-dev",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -570,6 +570,12 @@ def _import_yandex_gpt() -> Any:
|
|||||||
return YandexGPT
|
return YandexGPT
|
||||||
|
|
||||||
|
|
||||||
|
def _import_yuan2() -> Any:
    # Lazily import the Yuan2 LLM so the community package does not pay the
    # import cost (requests, pydantic model construction) unless it is used.
    from langchain_community.llms.yuan2 import Yuan2

    return Yuan2
|
||||||
|
|
||||||
|
|
||||||
def _import_volcengine_maas() -> Any:
|
def _import_volcengine_maas() -> Any:
|
||||||
from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
|
from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
|
||||||
|
|
||||||
@ -753,6 +759,8 @@ def __getattr__(name: str) -> Any:
|
|||||||
return _import_xinference()
|
return _import_xinference()
|
||||||
elif name == "YandexGPT":
|
elif name == "YandexGPT":
|
||||||
return _import_yandex_gpt()
|
return _import_yandex_gpt()
|
||||||
|
elif name == "Yuan2":
|
||||||
|
return _import_yuan2()
|
||||||
elif name == "VolcEngineMaasLLM":
|
elif name == "VolcEngineMaasLLM":
|
||||||
return _import_volcengine_maas()
|
return _import_volcengine_maas()
|
||||||
elif name == "type_to_cls_dict":
|
elif name == "type_to_cls_dict":
|
||||||
@ -851,6 +859,7 @@ __all__ = [
|
|||||||
"JavelinAIGateway",
|
"JavelinAIGateway",
|
||||||
"QianfanLLMEndpoint",
|
"QianfanLLMEndpoint",
|
||||||
"YandexGPT",
|
"YandexGPT",
|
||||||
|
"Yuan2",
|
||||||
"VolcEngineMaasLLM",
|
"VolcEngineMaasLLM",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -939,5 +948,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
|||||||
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
||||||
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
||||||
"yandex_gpt": _import_yandex_gpt,
|
"yandex_gpt": _import_yandex_gpt,
|
||||||
|
"yuan2": _import_yuan2,
|
||||||
"VolcEngineMaasLLM": _import_volcengine_maas,
|
"VolcEngineMaasLLM": _import_volcengine_maas,
|
||||||
}
|
}
|
||||||
|
192
libs/community/langchain_community/llms/yuan2.py
Normal file
192
libs/community/langchain_community/llms/yuan2.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional, Set
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models.llms import LLM
|
||||||
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
|
||||||
|
from langchain_community.llms.utils import enforce_stop_tokens
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Yuan2(LLM):
    """Yuan2.0 language models.

    Wraps a (locally or remotely) deployed Yuan2.0 inference server and
    exposes it through the LangChain ``LLM`` interface.

    Example:
        .. code-block:: python

            yuan_llm = Yuan2(
                infer_api="http://127.0.0.1:8000/yuan",
                max_tokens=1024,
                temp=1.0,
                top_p=0.9,
                top_k=40,
            )
            print(yuan_llm)
            print(yuan_llm("你是谁?"))
    """

    infer_api: str = "http://127.0.0.1:8000/yuan"
    """Yuan2.0 inference api"""

    max_tokens: int = Field(1024, alias="max_token")
    """Token context window."""

    temp: Optional[float] = 0.7
    """The temperature to use for sampling."""

    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    do_sample: bool = False
    """The do_sample is a Boolean value that determines whether
    to use the sampling method during text generation.
    """

    echo: Optional[bool] = False
    """Whether to echo the prompt."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize"""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    streaming: bool = False
    """Whether to stream the results or not."""

    history: List[str] = []
    """History of the conversation"""

    use_history: bool = False
    """Whether to use history or not"""

    @property
    def _llm_type(self) -> str:
        return "Yuan2.0"

    @staticmethod
    def _model_param_names() -> Set[str]:
        # Parameters that identify a particular model configuration.
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        return {
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "do_sample": self.do_sample,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
            },
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a Yuan2.0 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the HTTP request fails, the server returns a
                non-200 status, the response is not valid JSON, or the
                response carries an error code / no output.

        Example:
            .. code-block:: python

                response = yuan_llm("你能做什么?")
        """
        if self.use_history:
            # Accumulate the conversation and join turns with the "<n>"
            # separator expected by the Yuan2.0 server.
            self.history.append(prompt)
            input_text = "<n>".join(self.history)
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input_text}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )

        # Fix: the original `logger.debug("Yuan2.0 prompt:", input)` passed a
        # positional argument with no %s placeholder, which the logging module
        # mis-formats. Use lazy %-style formatting instead.
        logger.debug("Yuan2.0 prompt: %s", input_text)

        # call api
        try:
            response = requests.put(self.infer_api, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}") from e

        logger.debug("Yuan2.0 response: %s", response)

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            resp = response.json()

            if resp["errCode"] != "0":
                raise ValueError(
                    f"Failed with error code [{resp['errCode']}], "
                    f"error message: [{resp['errMessage']}]"
                )

            if "resData" in resp:
                # Fix: the original checked `>= 0`, which is always true and
                # would raise IndexError on an empty output list instead of
                # the intended ValueError.
                if len(resp["resData"]["output"]) > 0:
                    generate_text = resp["resData"]["output"][0]["ans"]
                else:
                    raise ValueError("No output found in response.")
            else:
                raise ValueError("No resData found in response.")

        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            ) from e

        if stop is not None:
            generate_text = enforce_stop_tokens(generate_text, stop)

        # support multi-turn chat
        if self.use_history:
            self.history.append(generate_text)

        logger.debug("history: %s", self.history)
        return generate_text
|
33
libs/community/tests/integration_tests/llms/test_yuan2.py
Normal file
33
libs/community/tests/integration_tests/llms/test_yuan2.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""Test Yuan2.0 API wrapper."""
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
from langchain_community.llms import Yuan2
|
||||||
|
|
||||||
|
|
||||||
|
def test_yuan2_call_method() -> None:
    """Invoking a locally deployed Yuan2.0 endpoint returns a string."""
    yuan_llm = Yuan2(
        infer_api="http://127.0.0.1:8000/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    result = yuan_llm("写一段快速排序算法。")
    assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_yuan2_generate_method() -> None:
    """``generate`` against the Yuan2.0 inference api yields an LLMResult."""
    yuan_llm = Yuan2(
        infer_api="http://127.0.0.1:8000/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    result = yuan_llm.generate(["who are you?"])
    assert isinstance(result, LLMResult)
    assert isinstance(result.generations, list)
|
@ -87,6 +87,7 @@ EXPECT_ALL = [
|
|||||||
"JavelinAIGateway",
|
"JavelinAIGateway",
|
||||||
"QianfanLLMEndpoint",
|
"QianfanLLMEndpoint",
|
||||||
"YandexGPT",
|
"YandexGPT",
|
||||||
|
"Yuan2",
|
||||||
"VolcEngineMaasLLM",
|
"VolcEngineMaasLLM",
|
||||||
"WatsonxLLM",
|
"WatsonxLLM",
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user