community[minor]: integrate with model Yuan2.0 (#15411)
1. Integrate with [`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md)
2. Update `langchain.llms`
3. Add a new doc for [Yuan2.0 integration](docs/docs/integrations/llms/yuan2.ipynb)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
parent: d07db457fc
commit: c776cfc599
.github/workflows/codespell.yml (vendored, 1 addition)
@@ -34,3 +34,4 @@ jobs:
       with:
         skip: guide_imports.json
         ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }}
+        exclude_file: libs/community/langchain_community/llms/yuan2.py
docs/docs/integrations/llms/yuan2.ipynb (new file, 117 lines)

@@ -0,0 +1,117 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Yuan2.0\n",
    "\n",
    "[Yuan2.0](https://github.com/IEIT-Yuan/Yuan-2.0) is a new-generation fundamental large language model developed by IEIT System. Three models have been published: Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B, along with scripts for pretraining, fine-tuning, and inference services for other developers. Yuan2.0 is based on Yuan1.0 and uses a wider range of high-quality pre-training data and instruction fine-tuning datasets to improve the model's understanding of semantics, mathematics, reasoning, code, knowledge, and other domains.\n",
    "\n",
    "This example goes over how to use LangChain to interact with `Yuan2.0` (2B/51B/102B) inference for text generation.\n",
    "\n",
    "Yuan2.0 provides an inference service, so users only need to request the inference API to get a result, as described in [Yuan2.0 Inference-Server](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain.chains import LLMChain\n",
    "from langchain_community.llms.yuan2 import Yuan2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# default infer_api for a locally deployed Yuan2.0 inference server\n",
    "infer_api = \"http://127.0.0.1:8000\"\n",
    "\n",
    "# direct access to the endpoint in a proxied environment\n",
    "# import os\n",
    "# os.environ[\"no_proxy\"] = \"localhost,127.0.0.1,::1\"\n",
    "\n",
    "yuan_llm = Yuan2(\n",
    "    infer_api=infer_api,\n",
    "    max_tokens=2048,\n",
    "    temp=1.0,\n",
    "    top_p=0.9,\n",
    "    top_k=40,\n",
    "    use_history=False,\n",
    ")\n",
    "\n",
    "# turn on use_history only when you want Yuan2.0 to keep track of the conversation history\n",
    "# and send the accumulated context to the backend model API, which makes it stateful; by default it is stateless\n",
    "# yuan_llm.use_history = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "question = \"请介绍一下中国。\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "print(yuan_llm(question))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
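The notebook imports LLMChain but only demonstrates a direct call. A minimal sketch of composing Yuan2 with a prompt template follows; the template text and question are illustrative, not part of this commit:

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms.yuan2 import Yuan2

# assumes a Yuan2.0 inference server is reachable at this local address
yuan_llm = Yuan2(infer_api="http://127.0.0.1:8000/yuan", max_tokens=512)

prompt = PromptTemplate.from_template("Answer briefly: {question}")
chain = LLMChain(llm=yuan_llm, prompt=prompt)
print(chain.run(question="What is quicksort?"))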
libs/community/langchain_community/llms/__init__.py

@@ -570,6 +570,12 @@ def _import_yandex_gpt() -> Any:
     return YandexGPT


+def _import_yuan2() -> Any:
+    from langchain_community.llms.yuan2 import Yuan2
+
+    return Yuan2
+
+
 def _import_volcengine_maas() -> Any:
     from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
@@ -753,6 +759,8 @@ def __getattr__(name: str) -> Any:
         return _import_xinference()
     elif name == "YandexGPT":
         return _import_yandex_gpt()
+    elif name == "Yuan2":
+        return _import_yuan2()
     elif name == "VolcEngineMaasLLM":
         return _import_volcengine_maas()
     elif name == "type_to_cls_dict":
@@ -851,6 +859,7 @@ __all__ = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "Yuan2",
     "VolcEngineMaasLLM",
 ]
@@ -939,5 +948,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "javelin-ai-gateway": _import_javelin_ai_gateway,
         "qianfan_endpoint": _import_baidu_qianfan_endpoint,
         "yandex_gpt": _import_yandex_gpt,
+        "yuan2": _import_yuan2,
         "VolcEngineMaasLLM": _import_volcengine_maas,
     }
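The additions above follow the module's lazy-import registry: nothing heavy is imported until an attribute is first looked up, at which point the module-level __getattr__ hook (PEP 562) resolves it. A stripped-down sketch of the same mechanism, with a hypothetical provider name:

from typing import Any


def _import_fake_llm() -> Any:
    # heavy provider import deferred until first access
    from fake_provider import FakeLLM  # hypothetical package

    return FakeLLM


def __getattr__(name: str) -> Any:
    # Python calls this module-level hook only for names not already defined
    if name == "FakeLLM":
        return _import_fake_llm()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")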
libs/community/langchain_community/llms/yuan2.py (new file, 192 lines)

@@ -0,0 +1,192 @@
import json
import logging
from typing import Any, Dict, List, Mapping, Optional, Set

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)


class Yuan2(LLM):
    """Yuan2.0 language models.

    Example:
        .. code-block:: python

            yuan_llm = Yuan2(
                infer_api="http://127.0.0.1:8000/yuan",
                max_tokens=1024,
                temp=1.0,
                top_p=0.9,
                top_k=40,
            )
            print(yuan_llm)
            print(yuan_llm("你是谁?"))
    """

    infer_api: str = "http://127.0.0.1:8000/yuan"
    """Yuan2.0 inference api"""

    max_tokens: int = Field(1024, alias="max_token")
    """Token context window."""

    temp: Optional[float] = 0.7
    """The temperature to use for sampling."""

    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    do_sample: bool = False
    """Whether to use the sampling method during text generation."""

    echo: Optional[bool] = False
    """Whether to echo the prompt."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize."""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    streaming: bool = False
    """Whether to stream the results or not."""

    history: List[str] = []
    """History of the conversation."""

    use_history: bool = False
    """Whether to use history or not."""
    @property
    def _llm_type(self) -> str:
        return "Yuan2.0"

    @staticmethod
    def _model_param_names() -> Set[str]:
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        return {
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "do_sample": self.do_sample,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
            },
        }
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a Yuan2.0 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = yuan_llm("你能做什么?")
        """
        if self.use_history:
            self.history.append(prompt)
            input = "<n>".join(self.history)
        else:
            input = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )

        logger.debug(f"Yuan2.0 prompt: {input}")

        # call api
        try:
            response = requests.put(self.infer_api, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}")

        logger.debug(f"Yuan2.0 response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            resp = response.json()

            if resp["errCode"] != "0":
                raise ValueError(
                    f"Failed with error code [{resp['errCode']}], "
                    f"error message: [{resp['errMessage']}]"
                )

            if "resData" in resp:
                if len(resp["resData"]["output"]) > 0:
                    generate_text = resp["resData"]["output"][0]["ans"]
                else:
                    raise ValueError("No output found in response.")
            else:
                raise ValueError("No resData found in response.")

        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            generate_text = enforce_stop_tokens(generate_text, stop)

        # support multi-turn chat
        if self.use_history:
            self.history.append(generate_text)

        logger.debug(f"history: {self.history}")
        return generate_text
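For reference, _call above sends a single PUT request with a JSON body and reads the first answer from resData. The equivalent raw exchange, as a sketch assuming a locally deployed inference server at the default address:

import json

import requests

payload = {
    "ques_list": [{"id": "000", "ques": "你是谁?"}],
    "tokens_to_generate": 1024,
    "temperature": 1.0,
    "top_p": 0.9,
    "top_k": 40,
    "do_sample": False,
}
response = requests.put(
    "http://127.0.0.1:8000/yuan",
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload),
)
body = response.json()
if body["errCode"] == "0":
    # first (and only) answer for the single question submitted
    print(body["resData"]["output"][0]["ans"])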
libs/community/tests/integration_tests/llms/test_yuan2.py (new file, 33 lines)

@@ -0,0 +1,33 @@
"""Test Yuan2.0 API wrapper."""
from langchain_core.outputs import LLMResult

from langchain_community.llms import Yuan2


def test_yuan2_call_method() -> None:
    """Test valid call to Yuan2.0."""
    llm = Yuan2(
        infer_api="http://127.0.0.1:8000/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    output = llm("写一段快速排序算法。")
    assert isinstance(output, str)


def test_yuan2_generate_method() -> None:
    """Test valid call to Yuan2.0 inference api."""
    llm = Yuan2(
        infer_api="http://127.0.0.1:8000/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    output = llm.generate(["who are you?"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
libs/community/tests/unit_tests/llms/test_imports.py

@@ -87,6 +87,7 @@ EXPECT_ALL = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "Yuan2",
     "VolcEngineMaasLLM",
     "WatsonxLLM",
 ]
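With use_history=True, the wrapper appends each prompt and each generated answer to self.history and joins them with "<n>" before every request, which makes the conversation stateful. A short sketch of a two-turn exchange under that setting (server address assumed):

from langchain_community.llms import Yuan2

llm = Yuan2(infer_api="http://127.0.0.1:8000/yuan", use_history=True)
print(llm("Who are you?"))  # first turn; prompt and answer are stored in llm.history
print(llm("What can you do?"))  # second turn; sent with the accumulated "<n>"-joined context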