mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
community[patch]: fix qianfan chat stream calling caused exception (#13800)
- **Description:** `QianfanChatEndpoint` extends `BaseChatModel` as a super class, which has a default stream implement might concat the MessageChunk with `__add__`. When call stream(), a ValueError for duplicated key will be raise. - **Issues:** * #13546 * #13548 * merge two single test file related to qianfan. - **Dependencies:** no - **Tag maintainer:** --------- Co-authored-by: root <liujun45@baidu.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
656e87beb9
commit
70b6315b23
@ -55,17 +55,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[INFO] [09-15 20:00:29] logging.py:55 [t:139698882193216]: requesting llm api endpoint: /chat/eb-instant\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"\"\"\"For basic init and call\"\"\"\n",
|
"\"\"\"For basic init and call\"\"\"\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
@ -126,9 +118,7 @@
|
|||||||
"from langchain.schema import HumanMessage\n",
|
"from langchain.schema import HumanMessage\n",
|
||||||
"from langchain_community.chat_models import QianfanChatEndpoint\n",
|
"from langchain_community.chat_models import QianfanChatEndpoint\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chatLLM = QianfanChatEndpoint(\n",
|
"chatLLM = QianfanChatEndpoint()\n",
|
||||||
" streaming=True,\n",
|
|
||||||
")\n",
|
|
||||||
"res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
|
"res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
|
||||||
"for r in res:\n",
|
"for r in res:\n",
|
||||||
" print(\"chat resp:\", r)\n",
|
" print(\"chat resp:\", r)\n",
|
||||||
@ -260,11 +250,11 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.4"
|
"version": "3.11.5"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
|
"hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast
|
||||||
|
|
||||||
@ -244,7 +242,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
completion = ""
|
completion = ""
|
||||||
|
token_usage = {}
|
||||||
|
chat_generation_info: Dict = {}
|
||||||
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||||
|
chat_generation_info = (
|
||||||
|
chunk.generation_info
|
||||||
|
if chunk.generation_info is not None
|
||||||
|
else chat_generation_info
|
||||||
|
)
|
||||||
completion += chunk.text
|
completion += chunk.text
|
||||||
lc_msg = AIMessage(content=completion, additional_kwargs={})
|
lc_msg = AIMessage(content=completion, additional_kwargs={})
|
||||||
gen = ChatGeneration(
|
gen = ChatGeneration(
|
||||||
@ -253,7 +258,10 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
)
|
)
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
generations=[gen],
|
generations=[gen],
|
||||||
llm_output={"token_usage": {}, "model_name": self.model},
|
llm_output={
|
||||||
|
"token_usage": chat_generation_info.get("usage", {}),
|
||||||
|
"model_name": self.model,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
response_payload = self.client.do(**params)
|
response_payload = self.client.do(**params)
|
||||||
@ -279,7 +287,13 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
if self.streaming:
|
if self.streaming:
|
||||||
completion = ""
|
completion = ""
|
||||||
token_usage = {}
|
token_usage = {}
|
||||||
|
chat_generation_info: Dict = {}
|
||||||
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
|
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
|
||||||
|
chat_generation_info = (
|
||||||
|
chunk.generation_info
|
||||||
|
if chunk.generation_info is not None
|
||||||
|
else chat_generation_info
|
||||||
|
)
|
||||||
completion += chunk.text
|
completion += chunk.text
|
||||||
|
|
||||||
lc_msg = AIMessage(content=completion, additional_kwargs={})
|
lc_msg = AIMessage(content=completion, additional_kwargs={})
|
||||||
@ -289,7 +303,10 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
)
|
)
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
generations=[gen],
|
generations=[gen],
|
||||||
llm_output={"token_usage": {}, "model_name": self.model},
|
llm_output={
|
||||||
|
"token_usage": chat_generation_info.get("usage", {}),
|
||||||
|
"model_name": self.model,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
response_payload = await self.client.ado(**params)
|
response_payload = await self.client.ado(**params)
|
||||||
@ -315,16 +332,19 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
|
params["stream"] = True
|
||||||
for res in self.client.do(**params):
|
for res in self.client.do(**params):
|
||||||
if res:
|
if res:
|
||||||
msg = _convert_dict_to_message(res)
|
msg = _convert_dict_to_message(res)
|
||||||
|
additional_kwargs = msg.additional_kwargs.get("function_call", {})
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
text=res["result"],
|
text=res["result"],
|
||||||
message=AIMessageChunk(
|
message=AIMessageChunk(
|
||||||
content=msg.content,
|
content=msg.content,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
additional_kwargs=msg.additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
),
|
),
|
||||||
|
generation_info=msg.additional_kwargs,
|
||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
if run_manager:
|
if run_manager:
|
||||||
@ -338,16 +358,19 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||||
|
params["stream"] = True
|
||||||
async for res in await self.client.ado(**params):
|
async for res in await self.client.ado(**params):
|
||||||
if res:
|
if res:
|
||||||
msg = _convert_dict_to_message(res)
|
msg = _convert_dict_to_message(res)
|
||||||
|
additional_kwargs = msg.additional_kwargs.get("function_call", {})
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
text=res["result"],
|
text=res["result"],
|
||||||
message=AIMessageChunk(
|
message=AIMessageChunk(
|
||||||
content=msg.content,
|
content=msg.content,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
additional_kwargs=msg.additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
),
|
),
|
||||||
|
generation_info=msg.additional_kwargs,
|
||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
if run_manager:
|
if run_manager:
|
||||||
|
@ -1,53 +0,0 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
|
||||||
from pytest import CaptureFixture, MonkeyPatch
|
|
||||||
|
|
||||||
from langchain_community.chat_models.baidu_qianfan_endpoint import (
|
|
||||||
QianfanChatEndpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_qianfan_key_masked_when_passed_from_env(
|
|
||||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
|
||||||
) -> None:
|
|
||||||
"""Test initialization with an API key provided via an env variable"""
|
|
||||||
monkeypatch.setenv("QIANFAN_AK", "test-api-key")
|
|
||||||
monkeypatch.setenv("QIANFAN_SK", "test-secret-key")
|
|
||||||
|
|
||||||
chat = QianfanChatEndpoint()
|
|
||||||
print(chat.qianfan_ak, end="")
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert captured.out == "**********"
|
|
||||||
|
|
||||||
print(chat.qianfan_sk, end="")
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert captured.out == "**********"
|
|
||||||
|
|
||||||
|
|
||||||
def test_qianfan_key_masked_when_passed_via_constructor(
|
|
||||||
capsys: CaptureFixture,
|
|
||||||
) -> None:
|
|
||||||
"""Test initialization with an API key provided via the initializer"""
|
|
||||||
chat = QianfanChatEndpoint(
|
|
||||||
qianfan_ak="test-api-key",
|
|
||||||
qianfan_sk="test-secret-key",
|
|
||||||
)
|
|
||||||
print(chat.qianfan_ak, end="")
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert captured.out == "**********"
|
|
||||||
|
|
||||||
print(chat.qianfan_sk, end="")
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
|
|
||||||
assert captured.out == "**********"
|
|
||||||
|
|
||||||
|
|
||||||
def test_uses_actual_secret_value_from_secret_str() -> None:
|
|
||||||
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
|
||||||
chat = QianfanChatEndpoint(
|
|
||||||
qianfan_ak="test-api-key",
|
|
||||||
qianfan_sk="test-secret-key",
|
|
||||||
)
|
|
||||||
assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
|
|
||||||
assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"
|
|
@ -1,18 +1,24 @@
|
|||||||
"""Test Baidu Qianfan Chat Endpoint."""
|
"""Test Baidu Qianfan Chat Endpoint."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
from langchain_core.callbacks import CallbackManager
|
from langchain_core.callbacks import CallbackManager
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
|
BaseMessageChunk,
|
||||||
FunctionMessage,
|
FunctionMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
|
from pytest import CaptureFixture, MonkeyPatch
|
||||||
|
|
||||||
from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
from langchain_community.chat_models.baidu_qianfan_endpoint import (
|
||||||
|
QianfanChatEndpoint,
|
||||||
|
)
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
_FUNCTIONS: Any = [
|
_FUNCTIONS: Any = [
|
||||||
@ -139,6 +145,25 @@ def test_multiple_history() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_generate() -> None:
|
||||||
|
"""Tests chat generate works."""
|
||||||
|
chat = QianfanChatEndpoint()
|
||||||
|
response = chat.generate(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
HumanMessage(content="Hello."),
|
||||||
|
AIMessage(content="Hello!"),
|
||||||
|
HumanMessage(content="How are you doing?"),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert isinstance(response, LLMResult)
|
||||||
|
for generations in response.generations:
|
||||||
|
for generation in generations:
|
||||||
|
assert isinstance(generation, ChatGeneration)
|
||||||
|
assert isinstance(generation.text, str)
|
||||||
|
|
||||||
|
|
||||||
def test_stream() -> None:
|
def test_stream() -> None:
|
||||||
"""Test that stream works."""
|
"""Test that stream works."""
|
||||||
chat = QianfanChatEndpoint(streaming=True)
|
chat = QianfanChatEndpoint(streaming=True)
|
||||||
@ -156,6 +181,57 @@ def test_stream() -> None:
|
|||||||
assert callback_handler.llm_streams > 0
|
assert callback_handler.llm_streams > 0
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
res = chat.stream(
|
||||||
|
[
|
||||||
|
HumanMessage(content="Hello."),
|
||||||
|
AIMessage(content="Hello!"),
|
||||||
|
HumanMessage(content="Who are you?"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(list(res)) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_invoke() -> None:
|
||||||
|
chat = QianfanChatEndpoint()
|
||||||
|
res = await chat.ainvoke([HumanMessage(content="Hello")])
|
||||||
|
assert isinstance(res, BaseMessage)
|
||||||
|
assert res.content != ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_generate() -> None:
|
||||||
|
"""Tests chat agenerate works."""
|
||||||
|
chat = QianfanChatEndpoint()
|
||||||
|
response = await chat.agenerate(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
HumanMessage(content="Hello."),
|
||||||
|
AIMessage(content="Hello!"),
|
||||||
|
HumanMessage(content="How are you doing?"),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert isinstance(response, LLMResult)
|
||||||
|
for generations in response.generations:
|
||||||
|
for generation in generations:
|
||||||
|
assert isinstance(generation, ChatGeneration)
|
||||||
|
assert isinstance(generation.text, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_stream() -> None:
|
||||||
|
chat = QianfanChatEndpoint(streaming=True)
|
||||||
|
async for token in chat.astream(
|
||||||
|
[
|
||||||
|
HumanMessage(content="Hello."),
|
||||||
|
AIMessage(content="Hello!"),
|
||||||
|
HumanMessage(content="Who are you?"),
|
||||||
|
]
|
||||||
|
):
|
||||||
|
assert isinstance(token, BaseMessageChunk)
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_messages() -> None:
|
def test_multiple_messages() -> None:
|
||||||
"""Tests multiple messages works."""
|
"""Tests multiple messages works."""
|
||||||
@ -232,3 +308,48 @@ def test_rate_limit() -> None:
|
|||||||
for res in responses:
|
for res in responses:
|
||||||
assert isinstance(res, BaseMessage)
|
assert isinstance(res, BaseMessage)
|
||||||
assert isinstance(res.content, str)
|
assert isinstance(res.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_qianfan_key_masked_when_passed_from_env(
|
||||||
|
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test initialization with an API key provided via an env variable"""
|
||||||
|
monkeypatch.setenv("QIANFAN_AK", "test-api-key")
|
||||||
|
monkeypatch.setenv("QIANFAN_SK", "test-secret-key")
|
||||||
|
|
||||||
|
chat = QianfanChatEndpoint()
|
||||||
|
print(chat.qianfan_ak, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
print(chat.qianfan_sk, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
|
||||||
|
def test_qianfan_key_masked_when_passed_via_constructor(
|
||||||
|
capsys: CaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test initialization with an API key provided via the initializer"""
|
||||||
|
chat = QianfanChatEndpoint(
|
||||||
|
qianfan_ak="test-api-key",
|
||||||
|
qianfan_sk="test-secret-key",
|
||||||
|
)
|
||||||
|
print(chat.qianfan_ak, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
print(chat.qianfan_sk, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
|
||||||
|
def test_uses_actual_secret_value_from_secret_str() -> None:
|
||||||
|
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
||||||
|
chat = QianfanChatEndpoint(
|
||||||
|
qianfan_ak="test-api-key",
|
||||||
|
qianfan_sk="test-secret-key",
|
||||||
|
)
|
||||||
|
assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
|
||||||
|
assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"
|
||||||
|
Loading…
Reference in New Issue
Block a user