mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 20:09:01 +00:00
community[minor]: integrate with model Yuan2.0 (#15411)
1. integrate with [`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md) 2. update `langchain.llms` 3. add a new doc for [Yuan2.0 integration](docs/docs/integrations/llms/yuan2.ipynb) --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -570,6 +570,12 @@ def _import_yandex_gpt() -> Any:
|
||||
return YandexGPT
|
||||
|
||||
|
||||
def _import_yuan2() -> Any:
    """Lazily import and return the ``Yuan2`` LLM class.

    Deferred import keeps ``langchain_community.llms`` fast to import;
    the real module is only loaded when the attribute is first accessed
    (see the ``__getattr__`` dispatch in this file).
    """
    from langchain_community.llms.yuan2 import Yuan2

    return Yuan2
|
||||
|
||||
|
||||
def _import_volcengine_maas() -> Any:
|
||||
from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
|
||||
|
||||
@@ -753,6 +759,8 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_xinference()
|
||||
elif name == "YandexGPT":
|
||||
return _import_yandex_gpt()
|
||||
elif name == "Yuan2":
|
||||
return _import_yuan2()
|
||||
elif name == "VolcEngineMaasLLM":
|
||||
return _import_volcengine_maas()
|
||||
elif name == "type_to_cls_dict":
|
||||
@@ -851,6 +859,7 @@ __all__ = [
|
||||
"JavelinAIGateway",
|
||||
"QianfanLLMEndpoint",
|
||||
"YandexGPT",
|
||||
"Yuan2",
|
||||
"VolcEngineMaasLLM",
|
||||
]
|
||||
|
||||
@@ -939,5 +948,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
||||
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
||||
"yandex_gpt": _import_yandex_gpt,
|
||||
"yuan2": _import_yuan2,
|
||||
"VolcEngineMaasLLM": _import_volcengine_maas,
|
||||
}
|
||||
|
192
libs/community/langchain_community/llms/yuan2.py
Normal file
192
libs/community/langchain_community/llms/yuan2.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional, Set
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Yuan2(LLM):
    """Yuan2.0 language models.

    Thin HTTP client around a self-hosted Yuan2.0 inference endpoint:
    each call serializes the sampling parameters to JSON, PUTs them to
    ``infer_api``, and extracts the generated text from the response.

    Example:
        .. code-block:: python

            yuan_llm = Yuan2(
                infer_api="http://127.0.0.1:8000/yuan",
                max_tokens=1024,
                temp=1.0,
                top_p=0.9,
                top_k=40,
            )
            print(yuan_llm)
            print(yuan_llm("你是谁?"))
    """

    infer_api: str = "http://127.0.0.1:8000/yuan"
    """Yuan2.0 inference api"""

    max_tokens: int = Field(1024, alias="max_token")
    """Token context window."""

    temp: Optional[float] = 0.7
    """The temperature to use for sampling."""

    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    do_sample: bool = False
    """Whether to use the sampling method during text generation."""

    echo: Optional[bool] = False
    """Whether to echo the prompt."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize."""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    streaming: bool = False
    """Whether to stream the results or not."""

    history: List[str] = []
    """History of the conversation."""

    use_history: bool = False
    """Whether to use history or not."""

    @property
    def _llm_type(self) -> str:
        return "Yuan2.0"

    @staticmethod
    def _model_param_names() -> Set[str]:
        """Names of the tunable sampling parameters."""
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        """Current values of the configurable parameters."""
        return {
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "do_sample": self.do_sample,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
            },
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a Yuan2.0 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the request fails, the endpoint returns a
                non-200 status or an error code, or the response body
                cannot be decoded / contains no output.

        Example:
            .. code-block:: python

                response = yuan_llm("你能做什么?")
        """
        # In multi-turn mode the full history (joined with the model's
        # "<n>" turn separator) is sent as the input, not just the prompt.
        if self.use_history:
            self.history.append(prompt)
            input_text = "<n>".join(self.history)
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input_text}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )

        # Lazy %-style formatting: the original passed ``input`` as a
        # format arg to a message with no placeholder, which makes the
        # logging module report an internal formatting error.
        logger.debug("Yuan2.0 prompt: %s", input_text)

        # call api
        # NOTE(review): no timeout is set, so a hung server blocks
        # forever — consider exposing a timeout parameter.
        try:
            response = requests.put(self.infer_api, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}") from e

        logger.debug(f"Yuan2.0 response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")
        try:
            resp = response.json()

            if resp["errCode"] != "0":
                raise ValueError(
                    f"Failed with error code [{resp['errCode']}], "
                    f"error message: [{resp['errMessage']}]"
                )

            if "resData" in resp:
                # ``> 0`` (not ``>= 0``, which is always true): an empty
                # output list must raise the ValueError below rather than
                # an IndexError on ``[0]``.
                if len(resp["resData"]["output"]) > 0:
                    generate_text = resp["resData"]["output"][0]["ans"]
                else:
                    raise ValueError("No output found in response.")
            else:
                raise ValueError("No resData found in response.")

        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            ) from e

        if stop is not None:
            generate_text = enforce_stop_tokens(generate_text, stop)

        # support multi-turn chat
        if self.use_history:
            self.history.append(generate_text)

        logger.debug(f"history: {self.history}")
        return generate_text
|
33
libs/community/tests/integration_tests/llms/test_yuan2.py
Normal file
33
libs/community/tests/integration_tests/llms/test_yuan2.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Test Yuan2.0 API wrapper."""
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.llms import Yuan2
|
||||
|
||||
|
||||
def test_yuan2_call_method() -> None:
    """Test valid call to Yuan2.0."""
    # NOTE(review): integration test — requires a live Yuan2.0 inference
    # server at the URL below; it will fail without one.
    llm = Yuan2(
        infer_api="http://127.0.0.1:8000/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    # Prompt asks the model to write a quicksort; only the return type is
    # asserted, since the generated text is nondeterministic.
    output = llm("写一段快速排序算法。")
    assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_yuan2_generate_method() -> None:
    """Test valid call to Yuan2.0 inference api."""
    # NOTE(review): integration test — requires a live Yuan2.0 inference
    # server at the URL below; it will fail without one.
    llm = Yuan2(
        infer_api="http://127.0.0.1:8000/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    # Exercise the batch `generate` path; assert only the result shape,
    # since the generated text is nondeterministic.
    output = llm.generate(["who are you?"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)
|
@@ -87,6 +87,7 @@ EXPECT_ALL = [
|
||||
"JavelinAIGateway",
|
||||
"QianfanLLMEndpoint",
|
||||
"YandexGPT",
|
||||
"Yuan2",
|
||||
"VolcEngineMaasLLM",
|
||||
"WatsonxLLM",
|
||||
]
|
||||
|
Reference in New Issue
Block a user