mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
fix: chat_models Qianfan not compatiable with SystemMessage (#10642)
- **Description:** QianfanEndpoint bugs for SystemMessages. When the `SystemMessage` is input as the messages to `chat_models.QianfanEndpoint`. A `TypeError` will be raised. - **Issue:** #10643 - **Dependencies:** - **Tag maintainer:** @baskaryan - **Twitter handle:** no
This commit is contained in:
@@ -26,6 +26,7 @@ from langchain.schema.messages import (
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
@@ -80,7 +81,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
|
||||
from langchain.chat_models import QianfanChatEndpoint
|
||||
qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
|
||||
endpoint="your_endpoint", ak="your_ak", sk="your_sk")
|
||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||
"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -174,9 +175,35 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts a list of messages into a dictionary containing the message content
|
||||
and default parameters.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): The list of messages.
|
||||
**kwargs (Any): Optional arguments to add additional parameters to the
|
||||
resulting dictionary.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing the message content and default
|
||||
parameters.
|
||||
|
||||
"""
|
||||
messages_dict: Dict[str, Any] = {
|
||||
"messages": [
|
||||
convert_message_to_dict(m)
|
||||
for m in messages
|
||||
if not isinstance(m, SystemMessage)
|
||||
]
|
||||
}
|
||||
for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
|
||||
if "system" not in messages_dict:
|
||||
messages_dict["system"] = ""
|
||||
messages_dict["system"] += messages[i].content + "\n"
|
||||
|
||||
return {
|
||||
**{"messages": [convert_message_to_dict(m) for m in messages]},
|
||||
**messages_dict,
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
@@ -206,7 +233,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
lc_msg = AIMessage(content=completion, additional_kwargs={})
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="finished"),
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
)
|
||||
return ChatResult(
|
||||
generations=[gen],
|
||||
@@ -217,7 +244,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="finished"),
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
)
|
||||
token_usage = response_payload.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
@@ -232,12 +259,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
token_usage = {}
|
||||
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
|
||||
completion += chunk.text
|
||||
|
||||
lc_msg = AIMessage(content=completion, additional_kwargs={})
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="finished"),
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
)
|
||||
return ChatResult(
|
||||
generations=[gen],
|
||||
@@ -249,7 +278,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
generations = []
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="finished"),
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response_payload.get("usage", {})
|
||||
@@ -269,11 +298,10 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
chunk = ChatGenerationChunk(
|
||||
text=res["result"],
|
||||
message=_convert_resp_to_message_chunk(res),
|
||||
generation_info={"finish_reason": "finished"},
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@@ -286,8 +314,9 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
async for res in await self.client.ado(**params):
|
||||
if res:
|
||||
chunk = ChatGenerationChunk(
|
||||
text=res["result"], message=_convert_resp_to_message_chunk(res)
|
||||
text=res["result"],
|
||||
message=_convert_resp_to_message_chunk(res),
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text)
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
@@ -37,7 +37,7 @@ class QianfanLLMEndpoint(LLM):
|
||||
|
||||
from langchain.llms import QianfanLLMEndpoint
|
||||
qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
|
||||
endpoint="your_endpoint", ak="your_ak", sk="your_sk")
|
||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||
"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -132,6 +132,8 @@ class QianfanLLMEndpoint(LLM):
|
||||
prompt: str,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
if "streaming" in kwargs:
|
||||
kwargs["stream"] = kwargs.pop("streaming")
|
||||
return {
|
||||
**{"prompt": prompt, "model": self.model},
|
||||
**self._default_params,
|
||||
@@ -191,8 +193,7 @@ class QianfanLLMEndpoint(LLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
||||
|
||||
params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
|
||||
for res in self.client.do(**params):
|
||||
if res:
|
||||
chunk = GenerationChunk(text=res["result"])
|
||||
@@ -207,7 +208,7 @@ class QianfanLLMEndpoint(LLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
||||
params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
|
||||
async for res in await self.client.ado(**params):
|
||||
if res:
|
||||
chunk = GenerationChunk(text=res["result"])
|
||||
|
Reference in New Issue
Block a user