community[patch]: fix yuan2 errors in LLMs (#19004)

1. Fix yuan2 errors raised while invoking Yuan2.
2. update tests.
This commit is contained in:
wulixuan 2024-03-29 05:37:44 +08:00 committed by GitHub
parent aba4bd0d13
commit b7c8bc8268
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 5 deletions

View File

@ -41,7 +41,7 @@ class Yuan2(LLM):
top_p: Optional[float] = 0.9 top_p: Optional[float] = 0.9
"""The top-p value to use for sampling.""" """The top-p value to use for sampling."""
top_k: Optional[int] = 40 top_k: Optional[int] = 0
"""The top-k value to use for sampling.""" """The top-k value to use for sampling."""
do_sample: bool = False do_sample: bool = False
@ -70,6 +70,17 @@ class Yuan2(LLM):
use_history: bool = False use_history: bool = False
"""Whether to use history or not""" """Whether to use history or not"""
def __init__(self, **kwargs: Any) -> None:
    """Initialize the Yuan2 class.

    After the normal field initialization, reconcile the sampling
    parameters: ``top_p`` and ``top_k`` are mutually exclusive, so when
    both are enabled the model keeps ``top_p`` and disables ``top_k``.
    """
    super().__init__(**kwargs)

    # Treat None as "disabled" (0) when checking whether each knob is set.
    top_p_set = (self.top_p or 0) > 0
    top_k_set = (self.top_k or 0) > 0
    if top_p_set and top_k_set:
        logger.warning(
            "top_p and top_k cannot be set simultaneously. "
            "set top_k to 0 instead..."
        )
        self.top_k = 0
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "Yuan2.0" return "Yuan2.0"
@ -86,12 +97,13 @@ class Yuan2(LLM):
def _default_params(self) -> Dict[str, Any]: def _default_params(self) -> Dict[str, Any]:
return { return {
"do_sample": self.do_sample,
"infer_api": self.infer_api, "infer_api": self.infer_api,
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
"repeat_penalty": self.repeat_penalty,
"temp": self.temp, "temp": self.temp,
"top_k": self.top_k, "top_k": self.top_k,
"top_p": self.top_p, "top_p": self.top_p,
"do_sample": self.do_sample,
"use_history": self.use_history, "use_history": self.use_history,
} }
@ -135,6 +147,7 @@ class Yuan2(LLM):
input = prompt input = prompt
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
data = json.dumps( data = json.dumps(
{ {
"ques_list": [{"id": "000", "ques": input}], "ques_list": [{"id": "000", "ques": input}],
@ -164,7 +177,7 @@ class Yuan2(LLM):
if resp["errCode"] != "0": if resp["errCode"] != "0":
raise ValueError( raise ValueError(
f"Failed with error code [{resp['errCode']}], " f"Failed with error code [{resp['errCode']}], "
f"error message: [{resp['errMessage']}]" f"error message: [{resp['exceptionMsg']}]"
) )
if "resData" in resp: if "resData" in resp:

View File

@ -11,7 +11,6 @@ def test_yuan2_call_method() -> None:
max_tokens=1024, max_tokens=1024,
temp=1.0, temp=1.0,
top_p=0.9, top_p=0.9,
top_k=40,
use_history=False, use_history=False,
) )
output = llm("写一段快速排序算法。") output = llm("写一段快速排序算法。")
@ -25,7 +24,6 @@ def test_yuan2_generate_method() -> None:
max_tokens=1024, max_tokens=1024,
temp=1.0, temp=1.0,
top_p=0.9, top_p=0.9,
top_k=40,
use_history=False, use_history=False,
) )
output = llm.generate(["who are you?"]) output = llm.generate(["who are you?"])