mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
community[patch]: fix yuan2 errors in LLMs (#19004)
1. fix yuan2 errors while invoke Yuan2. 2. update tests.
This commit is contained in:
parent
aba4bd0d13
commit
b7c8bc8268
@ -41,7 +41,7 @@ class Yuan2(LLM):
|
||||
top_p: Optional[float] = 0.9
|
||||
"""The top-p value to use for sampling."""
|
||||
|
||||
top_k: Optional[int] = 40
|
||||
top_k: Optional[int] = 0
|
||||
"""The top-k value to use for sampling."""
|
||||
|
||||
do_sample: bool = False
|
||||
@ -70,6 +70,17 @@ class Yuan2(LLM):
|
||||
use_history: bool = False
|
||||
"""Whether to use history or not"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the Yuan2 class."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
|
||||
logger.warning(
|
||||
"top_p and top_k cannot be set simultaneously. "
|
||||
"set top_k to 0 instead..."
|
||||
)
|
||||
self.top_k = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "Yuan2.0"
|
||||
@ -86,12 +97,13 @@ class Yuan2(LLM):
|
||||
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"do_sample": self.do_sample,
|
||||
"infer_api": self.infer_api,
|
||||
"max_tokens": self.max_tokens,
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
"temp": self.temp,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"do_sample": self.do_sample,
|
||||
"use_history": self.use_history,
|
||||
}
|
||||
|
||||
@ -135,6 +147,7 @@ class Yuan2(LLM):
|
||||
input = prompt
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
data = json.dumps(
|
||||
{
|
||||
"ques_list": [{"id": "000", "ques": input}],
|
||||
@ -164,7 +177,7 @@ class Yuan2(LLM):
|
||||
if resp["errCode"] != "0":
|
||||
raise ValueError(
|
||||
f"Failed with error code [{resp['errCode']}], "
|
||||
f"error message: [{resp['errMessage']}]"
|
||||
f"error message: [{resp['exceptionMsg']}]"
|
||||
)
|
||||
|
||||
if "resData" in resp:
|
||||
|
@ -11,7 +11,6 @@ def test_yuan2_call_method() -> None:
|
||||
max_tokens=1024,
|
||||
temp=1.0,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
use_history=False,
|
||||
)
|
||||
output = llm("写一段快速排序算法。")
|
||||
@ -25,7 +24,6 @@ def test_yuan2_generate_method() -> None:
|
||||
max_tokens=1024,
|
||||
temp=1.0,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
use_history=False,
|
||||
)
|
||||
output = llm.generate(["who are you?"])
|
||||
|
Loading…
Reference in New Issue
Block a user