fix: chat_models Qianfan not compatible with SystemMessage (#10642)

- **Description:** `QianfanChatEndpoint` is not compatible with `SystemMessage`: when a
  `SystemMessage` is included in the messages passed to
  `chat_models.QianfanChatEndpoint`, a `TypeError` is raised.
  - **Issue:** #10643
  - **Dependencies:** 
  - **Tag maintainer:** @baskaryan
  - **Twitter handle:** no
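
For context, a minimal sketch of the failing call (model name and credentials below are placeholders):

```python
# Before this fix, including a SystemMessage raised a TypeError; after it,
# the system content is forwarded via the request's dedicated "system"
# field instead of the message list.
from langchain.chat_models import QianfanChatEndpoint
from langchain.schema.messages import HumanMessage, SystemMessage

chat = QianfanChatEndpoint(
    model="ERNIE-Bot",
    qianfan_ak="your_ak",  # placeholder credential
    qianfan_sk="your_sk",  # placeholder credential
)
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hello!"),
]
result = chat(messages)  # raised TypeError before this commit
```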
Author: DanielZzz
Date: 2023-09-20 13:35:51 +08:00
Committed by: GitHub
Parent: f0198354d9
Commit: ebe08412ad
5 changed files with 293 additions and 72 deletions

libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py

```diff
@@ -26,6 +26,7 @@ from langchain.schema.messages import (
     ChatMessage,
     FunctionMessage,
     HumanMessage,
+    SystemMessage,
 )
 from langchain.schema.output import ChatGenerationChunk
 from langchain.utils import get_from_dict_or_env
@@ -80,7 +81,7 @@ class QianfanChatEndpoint(BaseChatModel):
             from langchain.chat_models import QianfanChatEndpoint

             qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
-                endpoint="your_endpoint", ak="your_ak", sk="your_sk")
+                endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
     """

     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -174,9 +175,35 @@ class QianfanChatEndpoint(BaseChatModel):
         self,
         messages: List[BaseMessage],
         **kwargs: Any,
-    ) -> dict:
+    ) -> Dict[str, Any]:
+        """
+        Converts a list of messages into a dictionary containing the message content
+        and default parameters.
+
+        Args:
+            messages (List[BaseMessage]): The list of messages.
+            **kwargs (Any): Optional arguments to add additional parameters to the
+                resulting dictionary.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the message content and default
+                parameters.
+        """
+        messages_dict: Dict[str, Any] = {
+            "messages": [
+                convert_message_to_dict(m)
+                for m in messages
+                if not isinstance(m, SystemMessage)
+            ]
+        }
+        for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
+            if "system" not in messages_dict:
+                messages_dict["system"] = ""
+            messages_dict["system"] += messages[i].content + "\n"
+
         return {
-            **{"messages": [convert_message_to_dict(m) for m in messages]},
+            **messages_dict,
             **self._default_params,
             **kwargs,
         }
```
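
The hunk above is the core of the fix: `_convert_prompt_msg_params` now filters `SystemMessage`s out of the `messages` payload and concatenates their content into a single `system` parameter. A standalone sketch of that folding logic (plain Python, not the library code; `fold_system_messages` is a hypothetical helper):

```python
from typing import Any, Dict, List

from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage


def fold_system_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the new conversion behavior."""
    params: Dict[str, Any] = {
        # System messages are excluded from the ordinary message list...
        "messages": [m for m in messages if not isinstance(m, SystemMessage)]
    }
    # ...and concatenated, newline-terminated, into one "system" string.
    system = "".join(
        m.content + "\n" for m in messages if isinstance(m, SystemMessage)
    )
    if system:
        params["system"] = system
    return params


params = fold_system_messages(
    [SystemMessage(content="Answer in French."), HumanMessage(content="Hi!")]
)
assert params["system"] == "Answer in French.\n"
assert len(params["messages"]) == 1  # only the HumanMessage remains
```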
```diff
@@ -206,7 +233,7 @@ class QianfanChatEndpoint(BaseChatModel):
             lc_msg = AIMessage(content=completion, additional_kwargs={})
             gen = ChatGeneration(
                 message=lc_msg,
-                generation_info=dict(finish_reason="finished"),
+                generation_info=dict(finish_reason="stop"),
             )
             return ChatResult(
                 generations=[gen],
@@ -217,7 +244,7 @@ class QianfanChatEndpoint(BaseChatModel):
         lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
         gen = ChatGeneration(
             message=lc_msg,
-            generation_info=dict(finish_reason="finished"),
+            generation_info=dict(finish_reason="stop"),
         )
         token_usage = response_payload.get("usage", {})
         llm_output = {"token_usage": token_usage, "model_name": self.model}
@@ -232,12 +259,14 @@ class QianfanChatEndpoint(BaseChatModel):
     ) -> ChatResult:
         if self.streaming:
             completion = ""
+            token_usage = {}
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
             lc_msg = AIMessage(content=completion, additional_kwargs={})
             gen = ChatGeneration(
                 message=lc_msg,
-                generation_info=dict(finish_reason="finished"),
+                generation_info=dict(finish_reason="stop"),
             )
             return ChatResult(
                 generations=[gen],
@@ -249,7 +278,7 @@ class QianfanChatEndpoint(BaseChatModel):
         generations = []
         gen = ChatGeneration(
             message=lc_msg,
-            generation_info=dict(finish_reason="finished"),
+            generation_info=dict(finish_reason="stop"),
         )
         generations.append(gen)
         token_usage = response_payload.get("usage", {})
@@ -269,11 +298,10 @@ class QianfanChatEndpoint(BaseChatModel):
                 chunk = ChatGenerationChunk(
                     text=res["result"],
                     message=_convert_resp_to_message_chunk(res),
-                    generation_info={"finish_reason": "finished"},
                 )
                 yield chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(chunk.text)
+                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)

     async def _astream(
         self,
@@ -286,8 +314,9 @@ class QianfanChatEndpoint(BaseChatModel):
         async for res in await self.client.ado(**params):
             if res:
                 chunk = ChatGenerationChunk(
-                    text=res["result"], message=_convert_resp_to_message_chunk(res)
+                    text=res["result"],
+                    message=_convert_resp_to_message_chunk(res),
                 )
                 yield chunk
                 if run_manager:
-                    await run_manager.on_llm_new_token(chunk.text)
+                    await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
```
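
A second change threaded through both streaming hunks: the callback now receives the whole chunk (`on_llm_new_token(chunk.text, chunk=chunk)`), not just its text. A sketch of a handler that can take advantage of this (the handler itself is illustrative, not part of this commit):

```python
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler


class PrintChunkHandler(BaseCallbackHandler):
    """Illustrative handler: with chunk= forwarded, the full
    ChatGenerationChunk is available, not only the token text."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        chunk = kwargs.get("chunk")  # ChatGenerationChunk when provided
        print(token, end="", flush=True)
```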

libs/langchain/langchain/llms/baidu_qianfan_endpoint.py

```diff
@@ -37,7 +37,7 @@ class QianfanLLMEndpoint(LLM):
             from langchain.llms import QianfanLLMEndpoint

             qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
-                endpoint="your_endpoint", ak="your_ak", sk="your_sk")
+                endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
     """

     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -132,6 +132,8 @@ class QianfanLLMEndpoint(LLM):
         prompt: str,
         **kwargs: Any,
     ) -> dict:
+        if "streaming" in kwargs:
+            kwargs["stream"] = kwargs.pop("streaming")
         return {
             **{"prompt": prompt, "model": self.model},
             **self._default_params,
```
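
The two added lines normalize LangChain's `streaming` kwarg to the `stream` key the Qianfan SDK expects. The same logic in isolation:

```python
# The normalization in isolation: a LangChain-style "streaming" kwarg
# is renamed to the SDK-style "stream" key before the request is built.
kwargs = {"streaming": True, "top_p": 0.8}
if "streaming" in kwargs:
    kwargs["stream"] = kwargs.pop("streaming")
assert kwargs == {"top_p": 0.8, "stream": True}
```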
```diff
@@ -191,8 +193,7 @@ class QianfanLLMEndpoint(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        params = self._convert_prompt_msg_params(prompt, **kwargs)
-
+        params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
         for res in self.client.do(**params):
             if res:
                 chunk = GenerationChunk(text=res["result"])
@@ -207,7 +208,7 @@ class QianfanLLMEndpoint(LLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
-        params = self._convert_prompt_msg_params(prompt, **kwargs)
+        params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
         async for res in await self.client.ado(**params):
             if res:
                 chunk = GenerationChunk(text=res["result"])
```
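
With `stream=True` now forced inside `_stream` and `_astream`, streaming no longer depends on the caller remembering to set the flag. A hedged usage sketch (credentials are placeholders):

```python
# Usage sketch: _stream forces stream=True on the request, so .stream()
# yields incremental results without extra configuration.
from langchain.llms import QianfanLLMEndpoint

llm = QianfanLLMEndpoint(
    model="ERNIE-Bot",
    qianfan_ak="your_ak",  # placeholder credential
    qianfan_sk="your_sk",  # placeholder credential
)
for piece in llm.stream("Write a short poem about rivers."):
    print(piece, end="", flush=True)
```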