mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
community[patch]: fix yuan2 errors in LLMs (#19004)
1. fix yuan2 errors while invoke Yuan2. 2. update tests.
This commit is contained in:
parent
aba4bd0d13
commit
b7c8bc8268
@ -41,7 +41,7 @@ class Yuan2(LLM):
|
|||||||
top_p: Optional[float] = 0.9
|
top_p: Optional[float] = 0.9
|
||||||
"""The top-p value to use for sampling."""
|
"""The top-p value to use for sampling."""
|
||||||
|
|
||||||
top_k: Optional[int] = 40
|
top_k: Optional[int] = 0
|
||||||
"""The top-k value to use for sampling."""
|
"""The top-k value to use for sampling."""
|
||||||
|
|
||||||
do_sample: bool = False
|
do_sample: bool = False
|
||||||
@ -70,6 +70,17 @@ class Yuan2(LLM):
|
|||||||
use_history: bool = False
|
use_history: bool = False
|
||||||
"""Whether to use history or not"""
|
"""Whether to use history or not"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize the Yuan2 class."""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
|
||||||
|
logger.warning(
|
||||||
|
"top_p and top_k cannot be set simultaneously. "
|
||||||
|
"set top_k to 0 instead..."
|
||||||
|
)
|
||||||
|
self.top_k = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "Yuan2.0"
|
return "Yuan2.0"
|
||||||
@ -86,12 +97,13 @@ class Yuan2(LLM):
|
|||||||
|
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
"do_sample": self.do_sample,
|
||||||
"infer_api": self.infer_api,
|
"infer_api": self.infer_api,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
|
"repeat_penalty": self.repeat_penalty,
|
||||||
"temp": self.temp,
|
"temp": self.temp,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"do_sample": self.do_sample,
|
|
||||||
"use_history": self.use_history,
|
"use_history": self.use_history,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,6 +147,7 @@ class Yuan2(LLM):
|
|||||||
input = prompt
|
input = prompt
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
data = json.dumps(
|
data = json.dumps(
|
||||||
{
|
{
|
||||||
"ques_list": [{"id": "000", "ques": input}],
|
"ques_list": [{"id": "000", "ques": input}],
|
||||||
@ -164,7 +177,7 @@ class Yuan2(LLM):
|
|||||||
if resp["errCode"] != "0":
|
if resp["errCode"] != "0":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed with error code [{resp['errCode']}], "
|
f"Failed with error code [{resp['errCode']}], "
|
||||||
f"error message: [{resp['errMessage']}]"
|
f"error message: [{resp['exceptionMsg']}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
if "resData" in resp:
|
if "resData" in resp:
|
||||||
|
@ -11,7 +11,6 @@ def test_yuan2_call_method() -> None:
|
|||||||
max_tokens=1024,
|
max_tokens=1024,
|
||||||
temp=1.0,
|
temp=1.0,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
top_k=40,
|
|
||||||
use_history=False,
|
use_history=False,
|
||||||
)
|
)
|
||||||
output = llm("写一段快速排序算法。")
|
output = llm("写一段快速排序算法。")
|
||||||
@ -25,7 +24,6 @@ def test_yuan2_generate_method() -> None:
|
|||||||
max_tokens=1024,
|
max_tokens=1024,
|
||||||
temp=1.0,
|
temp=1.0,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
top_k=40,
|
|
||||||
use_history=False,
|
use_history=False,
|
||||||
)
|
)
|
||||||
output = llm.generate(["who are you?"])
|
output = llm.generate(["who are you?"])
|
||||||
|
Loading…
Reference in New Issue
Block a user