community[patch]: fix yuan2 errors in LLMs (#19004)

1. Fix errors raised when invoking Yuan2.
2. Update the tests accordingly.
wulixuan 2024-03-29 05:37:44 +08:00 committed by GitHub
parent aba4bd0d13
commit b7c8bc8268
2 changed files with 16 additions and 5 deletions

libs/community/langchain_community/llms/yuan2.py

@@ -41,7 +41,7 @@ class Yuan2(LLM):
     top_p: Optional[float] = 0.9
     """The top-p value to use for sampling."""
-    top_k: Optional[int] = 40
+    top_k: Optional[int] = 0
     """The top-k value to use for sampling."""
     do_sample: bool = False
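With this change, top-k filtering is off by default, so the top_p=0.9 default alone governs sampling and no longer conflicts with the constructor guard added below. A minimal sketch of the resulting defaults, assuming the class is importable from langchain_community and the remaining fields keep their defaults:

from langchain_community.llms.yuan2 import Yuan2

llm = Yuan2()
assert llm.top_k == 0    # top-k filtering disabled out of the box
assert llm.top_p == 0.9  # top-p remains the active sampling knob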
@@ -70,6 +70,17 @@ class Yuan2(LLM):
     use_history: bool = False
     """Whether to use history or not"""

+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize the Yuan2 class."""
+        super().__init__(**kwargs)
+
+        if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
+            logger.warning(
+                "top_p and top_k cannot be set simultaneously. "
+                "set top_k to 0 instead..."
+            )
+            self.top_k = 0
+
     @property
     def _llm_type(self) -> str:
         return "Yuan2.0"
@@ -86,12 +97,13 @@ class Yuan2(LLM):
     def _default_params(self) -> Dict[str, Any]:
         return {
+            "do_sample": self.do_sample,
+            "infer_api": self.infer_api,
             "max_tokens": self.max_tokens,
             "repeat_penalty": self.repeat_penalty,
             "temp": self.temp,
             "top_k": self.top_k,
             "top_p": self.top_p,
-            "do_sample": self.do_sample,
             "use_history": self.use_history,
         }
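Besides keeping the keys alphabetical, this hunk surfaces infer_api in the reported defaults and moves do_sample up from the bottom of the dict. A quick sketch of what the property now returns, assuming default field values:

llm = Yuan2()
params = llm._default_params
assert params["do_sample"] is False  # default from the field declaration
assert params["top_k"] == 0          # reflects the new default above
assert "infer_api" in params         # now included in the reported params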
@@ -135,6 +147,7 @@ class Yuan2(LLM):
             input = prompt

         headers = {"Content-Type": "application/json"}
+
         data = json.dumps(
             {
                 "ques_list": [{"id": "000", "ques": input}],
@@ -164,7 +177,7 @@ class Yuan2(LLM):
         if resp["errCode"] != "0":
             raise ValueError(
                 f"Failed with error code [{resp['errCode']}], "
-                f"error message: [{resp['errMessage']}]"
+                f"error message: [{resp['exceptionMsg']}]"
             )

         if "resData" in resp:

libs/community/tests/integration_tests/llms/test_yuan2.py

@@ -11,7 +11,6 @@ def test_yuan2_call_method() -> None:
         max_tokens=1024,
         temp=1.0,
         top_p=0.9,
-        top_k=40,
         use_history=False,
     )
     output = llm("写一段快速排序算法。")
@@ -25,7 +24,6 @@ def test_yuan2_generate_method() -> None:
         max_tokens=1024,
         temp=1.0,
         top_p=0.9,
-        top_k=40,
         use_history=False,
     )
     output = llm.generate(["who are you?"])
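Dropping top_k=40 keeps the fixtures from tripping the new constructor guard; with the old arguments the patched class would log a warning and zero the value anyway, as in this sketch (argument names as in the diff above):

llm = Yuan2(max_tokens=1024, temp=1.0, top_p=0.9, top_k=40, use_history=False)
assert llm.top_k == 0  # guard fired: warning logged, top_k reset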