diff --git a/libs/community/langchain_community/llms/yuan2.py b/libs/community/langchain_community/llms/yuan2.py
index 360418a0a0f..0c345f92992 100644
--- a/libs/community/langchain_community/llms/yuan2.py
+++ b/libs/community/langchain_community/llms/yuan2.py
@@ -41,7 +41,7 @@ class Yuan2(LLM):
     top_p: Optional[float] = 0.9
     """The top-p value to use for sampling."""
 
-    top_k: Optional[int] = 40
+    top_k: Optional[int] = 0
     """The top-k value to use for sampling."""
 
     do_sample: bool = False
@@ -70,6 +70,17 @@ class Yuan2(LLM):
     use_history: bool = False
     """Whether to use history or not"""
 
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize the Yuan2 class."""
+        super().__init__(**kwargs)
+
+        if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
+            logger.warning(
+                "top_p and top_k cannot be set simultaneously. "
+                "set top_k to 0 instead..."
+            )
+            self.top_k = 0
+
     @property
     def _llm_type(self) -> str:
         return "Yuan2.0"
@@ -86,12 +97,13 @@ class Yuan2(LLM):
 
     def _default_params(self) -> Dict[str, Any]:
         return {
+            "do_sample": self.do_sample,
             "infer_api": self.infer_api,
             "max_tokens": self.max_tokens,
+            "repeat_penalty": self.repeat_penalty,
             "temp": self.temp,
             "top_k": self.top_k,
             "top_p": self.top_p,
-            "do_sample": self.do_sample,
             "use_history": self.use_history,
         }
 
@@ -135,6 +147,7 @@ class Yuan2(LLM):
         input = prompt
 
         headers = {"Content-Type": "application/json"}
+
         data = json.dumps(
             {
                 "ques_list": [{"id": "000", "ques": input}],
@@ -164,7 +177,7 @@ class Yuan2(LLM):
         if resp["errCode"] != "0":
             raise ValueError(
                 f"Failed with error code [{resp['errCode']}], "
-                f"error message: [{resp['errMessage']}]"
+                f"error message: [{resp['exceptionMsg']}]"
             )
 
         if "resData" in resp:
diff --git a/libs/community/tests/integration_tests/llms/test_yuan2.py b/libs/community/tests/integration_tests/llms/test_yuan2.py
index 94667819e66..2660a2af58d 100644
--- a/libs/community/tests/integration_tests/llms/test_yuan2.py
+++ b/libs/community/tests/integration_tests/llms/test_yuan2.py
@@ -11,7 +11,6 @@ def test_yuan2_call_method() -> None:
         max_tokens=1024,
         temp=1.0,
         top_p=0.9,
-        top_k=40,
         use_history=False,
     )
     output = llm("写一段快速排序算法。")
@@ -25,7 +24,6 @@ def test_yuan2_generate_method() -> None:
         max_tokens=1024,
         temp=1.0,
         top_p=0.9,
-        top_k=40,
         use_history=False,
     )
     output = llm.generate(["who are you?"])