community[patch]: fix exception raised when calling Qianfan chat in stream mode (#13800)
- **Description:** `QianfanChatEndpoint` extends `BaseChatModel`, whose default `stream()` implementation may concatenate message chunks with `__add__`. Because every Qianfan chunk carried the full response payload in `additional_kwargs`, calling `stream()` raised a `ValueError` for a duplicated key. This patch keeps only `function_call` in each chunk's `additional_kwargs`, moves the raw payload into `generation_info`, and reports real token usage in `llm_output`.
- **Issues:**
  * #13546
  * #13548
  * also merges the two separate Qianfan test files
- **Dependencies:** none
- **Tag maintainer:**

---------

Co-authored-by: root <liujun45@baidu.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Commit: 70b6315b23 (parent: 656e87beb9)
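The failure mode, in miniature: `BaseChatModel.stream()` may fold successive chunks together with `__add__`, and chunk concatenation merges `additional_kwargs`. The helper and payloads below are illustrative assumptions that mimic the merge rule, not langchain's actual code:

```python
# Illustrative sketch only -- mimics the chunk-merge rule, not langchain's code.
def merge_kwargs(left: dict, right: dict) -> dict:
    merged = dict(left)
    for key, value in right.items():
        if key not in merged:
            merged[key] = value
        elif isinstance(merged[key], str) and isinstance(value, str):
            merged[key] += value  # streamed text fragments concatenate cleanly
        else:
            # duplicated non-mergeable keys fail; this is what #13546/#13548 hit
            raise ValueError(f"Additional kwargs key {key} already exists.")
    return merged

# Before this patch every Qianfan chunk put the full response payload into
# additional_kwargs, so any two chunks collided on scalar keys like "created":
merge_kwargs({"created": 1695021339}, {"created": 1695021340})  # raises ValueError
```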
@@ -55,17 +55,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "[INFO] [09-15 20:00:29] logging.py:55 [t:139698882193216]: requesting llm api endpoint: /chat/eb-instant\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "\"\"\"For basic init and call\"\"\"\n",
     "import os\n",
@@ -126,9 +118,7 @@
    "from langchain.schema import HumanMessage\n",
    "from langchain_community.chat_models import QianfanChatEndpoint\n",
    "\n",
-   "chatLLM = QianfanChatEndpoint(\n",
-   "    streaming=True,\n",
-   ")\n",
+   "chatLLM = QianfanChatEndpoint()\n",
    "res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
    "for r in res:\n",
    "    print(\"chat resp:\", r)\n",
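For reference, the corrected notebook cell corresponds to roughly this standalone script (the credentials are placeholders you must supply):

```python
import os

from langchain.schema import HumanMessage
from langchain_community.chat_models import QianfanChatEndpoint

# placeholder credentials -- supply your own Qianfan keys
os.environ["QIANFAN_AK"] = "your_ak"
os.environ["QIANFAN_SK"] = "your_sk"

chatLLM = QianfanChatEndpoint()
res = chatLLM.stream([HumanMessage(content="hi")], streaming=True)
for r in res:
    print("chat resp:", r)
```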
@@ -260,11 +250,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.11.5"
   },
   "vscode": {
    "interpreter": {
-    "hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
+    "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
    }
   }
  },
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import logging
 from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast
 
@@ -244,7 +242,14 @@ class QianfanChatEndpoint(BaseChatModel):
         """
         if self.streaming:
             completion = ""
+            token_usage = {}
+            chat_generation_info: Dict = {}
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                chat_generation_info = (
+                    chunk.generation_info
+                    if chunk.generation_info is not None
+                    else chat_generation_info
+                )
                 completion += chunk.text
             lc_msg = AIMessage(content=completion, additional_kwargs={})
             gen = ChatGeneration(
@@ -253,7 +258,10 @@ class QianfanChatEndpoint(BaseChatModel):
             )
             return ChatResult(
                 generations=[gen],
-                llm_output={"token_usage": {}, "model_name": self.model},
+                llm_output={
+                    "token_usage": chat_generation_info.get("usage", {}),
+                    "model_name": self.model,
+                },
             )
         params = self._convert_prompt_msg_params(messages, **kwargs)
         response_payload = self.client.do(**params)
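With that change, `llm_output` reports the usage block carried by the final chunk's `generation_info` instead of an empty dict. A small illustration; the payload shape and model name below are assumptions based on this diff, not a documented contract:

```python
# Assumed shape of the last chunk's generation_info (raw Qianfan payload):
chat_generation_info = {
    "created": 1695021339,
    "usage": {"prompt_tokens": 1, "completion_tokens": 9, "total_tokens": 10},
}

llm_output = {
    "token_usage": chat_generation_info.get("usage", {}),  # {} if never populated
    "model_name": "ERNIE-Bot-turbo",  # placeholder model name
}
assert llm_output["token_usage"]["total_tokens"] == 10
```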
@@ -279,7 +287,13 @@ class QianfanChatEndpoint(BaseChatModel):
         if self.streaming:
             completion = ""
             token_usage = {}
+            chat_generation_info: Dict = {}
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
+                chat_generation_info = (
+                    chunk.generation_info
+                    if chunk.generation_info is not None
+                    else chat_generation_info
+                )
                 completion += chunk.text
 
             lc_msg = AIMessage(content=completion, additional_kwargs={})
@@ -289,7 +303,10 @@ class QianfanChatEndpoint(BaseChatModel):
             )
             return ChatResult(
                 generations=[gen],
-                llm_output={"token_usage": {}, "model_name": self.model},
+                llm_output={
+                    "token_usage": chat_generation_info.get("usage", {}),
+                    "model_name": self.model,
+                },
             )
         params = self._convert_prompt_msg_params(messages, **kwargs)
         response_payload = await self.client.ado(**params)
@@ -315,16 +332,19 @@ class QianfanChatEndpoint(BaseChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stream"] = True
         for res in self.client.do(**params):
             if res:
                 msg = _convert_dict_to_message(res)
+                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                 chunk = ChatGenerationChunk(
                     text=res["result"],
                     message=AIMessageChunk(
                         content=msg.content,
                         role="assistant",
-                        additional_kwargs=msg.additional_kwargs,
+                        additional_kwargs=additional_kwargs,
                     ),
+                    generation_info=msg.additional_kwargs,
                 )
                 yield chunk
                 if run_manager:
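The effect of slimming `additional_kwargs` down to the `function_call` entry is that message chunks become safely addable again. A quick sanity check against `langchain_core`; with no duplicated scalar keys left in the kwargs, concatenation succeeds:

```python
from langchain_core.messages import AIMessageChunk

# Chunks whose additional_kwargs no longer duplicate scalar keys merge cleanly.
merged = AIMessageChunk(content="Hello,") + AIMessageChunk(content=" world")
assert merged.content == "Hello, world"
```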
@@ -338,16 +358,19 @@ class QianfanChatEndpoint(BaseChatModel):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         params = self._convert_prompt_msg_params(messages, **kwargs)
+        params["stream"] = True
         async for res in await self.client.ado(**params):
             if res:
                 msg = _convert_dict_to_message(res)
+                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                 chunk = ChatGenerationChunk(
                     text=res["result"],
                     message=AIMessageChunk(
                         content=msg.content,
                         role="assistant",
-                        additional_kwargs=msg.additional_kwargs,
+                        additional_kwargs=additional_kwargs,
                     ),
+                    generation_info=msg.additional_kwargs,
                )
                 yield chunk
                 if run_manager:
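The async path mirrors the sync one. A minimal end-to-end sketch of consuming `astream`, assuming valid QIANFAN_AK / QIANFAN_SK credentials in the environment:

```python
import asyncio

from langchain_core.messages import HumanMessage

from langchain_community.chat_models import QianfanChatEndpoint


async def main() -> None:
    chat = QianfanChatEndpoint(streaming=True)
    async for token in chat.astream([HumanMessage(content="hi")]):
        # each token is a message chunk carrying a fragment of the reply
        print(token.content, end="", flush=True)


asyncio.run(main())
```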
@@ -1,53 +0,0 @@
-from typing import cast
-
-from langchain_core.pydantic_v1 import SecretStr
-from pytest import CaptureFixture, MonkeyPatch
-
-from langchain_community.chat_models.baidu_qianfan_endpoint import (
-    QianfanChatEndpoint,
-)
-
-
-def test_qianfan_key_masked_when_passed_from_env(
-    monkeypatch: MonkeyPatch, capsys: CaptureFixture
-) -> None:
-    """Test initialization with an API key provided via an env variable"""
-    monkeypatch.setenv("QIANFAN_AK", "test-api-key")
-    monkeypatch.setenv("QIANFAN_SK", "test-secret-key")
-
-    chat = QianfanChatEndpoint()
-    print(chat.qianfan_ak, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-
-    print(chat.qianfan_sk, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-
-
-def test_qianfan_key_masked_when_passed_via_constructor(
-    capsys: CaptureFixture,
-) -> None:
-    """Test initialization with an API key provided via the initializer"""
-    chat = QianfanChatEndpoint(
-        qianfan_ak="test-api-key",
-        qianfan_sk="test-secret-key",
-    )
-    print(chat.qianfan_ak, end="")
-    captured = capsys.readouterr()
-    assert captured.out == "**********"
-
-    print(chat.qianfan_sk, end="")
-    captured = capsys.readouterr()
-
-    assert captured.out == "**********"
-
-
-def test_uses_actual_secret_value_from_secret_str() -> None:
-    """Test that actual secret is retrieved using `.get_secret_value()`."""
-    chat = QianfanChatEndpoint(
-        qianfan_ak="test-api-key",
-        qianfan_sk="test-secret-key",
-    )
-    assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
-    assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"
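These tests are not dropped; they move wholesale into the merged Qianfan test file below. For context, the masking behavior they assert comes from pydantic's `SecretStr`, roughly:

```python
from pydantic import SecretStr  # langchain_core.pydantic_v1 re-exports the v1 API

secret = SecretStr("test-api-key")
print(secret)                     # **********  (masked on display)
print(secret.get_secret_value())  # test-api-key (explicit opt-in to the raw value)
```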
@@ -1,18 +1,24 @@
 """Test Baidu Qianfan Chat Endpoint."""
 
-from typing import Any
+from typing import Any, cast
 
 import pytest
 from langchain_core.callbacks import CallbackManager
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
+    BaseMessageChunk,
     FunctionMessage,
     HumanMessage,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
+from langchain_core.pydantic_v1 import SecretStr
+from pytest import CaptureFixture, MonkeyPatch
+
-from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
+from langchain_community.chat_models.baidu_qianfan_endpoint import (
+    QianfanChatEndpoint,
+)
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 _FUNCTIONS: Any = [
@@ -139,6 +145,25 @@ def test_multiple_history() -> None:
     assert isinstance(response.content, str)
 
 
+def test_chat_generate() -> None:
+    """Tests chat generate works."""
+    chat = QianfanChatEndpoint()
+    response = chat.generate(
+        [
+            [
+                HumanMessage(content="Hello."),
+                AIMessage(content="Hello!"),
+                HumanMessage(content="How are you doing?"),
+            ]
+        ]
+    )
+    assert isinstance(response, LLMResult)
+    for generations in response.generations:
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+
+
 def test_stream() -> None:
     """Test that stream works."""
     chat = QianfanChatEndpoint(streaming=True)
@@ -156,6 +181,57 @@ def test_stream() -> None:
     assert callback_handler.llm_streams > 0
     assert isinstance(response.content, str)
+
+    res = chat.stream(
+        [
+            HumanMessage(content="Hello."),
+            AIMessage(content="Hello!"),
+            HumanMessage(content="Who are you?"),
+        ]
+    )
+
+    assert len(list(res)) >= 1
+
+
+@pytest.mark.asyncio
+async def test_async_invoke() -> None:
+    chat = QianfanChatEndpoint()
+    res = await chat.ainvoke([HumanMessage(content="Hello")])
+    assert isinstance(res, BaseMessage)
+    assert res.content != ""
+
+
+@pytest.mark.asyncio
+async def test_async_generate() -> None:
+    """Tests chat agenerate works."""
+    chat = QianfanChatEndpoint()
+    response = await chat.agenerate(
+        [
+            [
+                HumanMessage(content="Hello."),
+                AIMessage(content="Hello!"),
+                HumanMessage(content="How are you doing?"),
+            ]
+        ]
+    )
+    assert isinstance(response, LLMResult)
+    for generations in response.generations:
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+
+
+@pytest.mark.asyncio
+async def test_async_stream() -> None:
+    chat = QianfanChatEndpoint(streaming=True)
+    async for token in chat.astream(
+        [
+            HumanMessage(content="Hello."),
+            AIMessage(content="Hello!"),
+            HumanMessage(content="Who are you?"),
+        ]
+    ):
+        assert isinstance(token, BaseMessageChunk)
 
 
 def test_multiple_messages() -> None:
     """Tests multiple messages works."""
@@ -232,3 +308,48 @@ def test_rate_limit() -> None:
     for res in responses:
         assert isinstance(res, BaseMessage)
         assert isinstance(res.content, str)
+
+
+def test_qianfan_key_masked_when_passed_from_env(
+    monkeypatch: MonkeyPatch, capsys: CaptureFixture
+) -> None:
+    """Test initialization with an API key provided via an env variable"""
+    monkeypatch.setenv("QIANFAN_AK", "test-api-key")
+    monkeypatch.setenv("QIANFAN_SK", "test-secret-key")
+
+    chat = QianfanChatEndpoint()
+    print(chat.qianfan_ak, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+
+    print(chat.qianfan_sk, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+
+
+def test_qianfan_key_masked_when_passed_via_constructor(
+    capsys: CaptureFixture,
+) -> None:
+    """Test initialization with an API key provided via the initializer"""
+    chat = QianfanChatEndpoint(
+        qianfan_ak="test-api-key",
+        qianfan_sk="test-secret-key",
+    )
+    print(chat.qianfan_ak, end="")
+    captured = capsys.readouterr()
+    assert captured.out == "**********"
+
+    print(chat.qianfan_sk, end="")
+    captured = capsys.readouterr()
+
+    assert captured.out == "**********"
+
+
+def test_uses_actual_secret_value_from_secret_str() -> None:
+    """Test that actual secret is retrieved using `.get_secret_value()`."""
+    chat = QianfanChatEndpoint(
+        qianfan_ak="test-api-key",
+        qianfan_sk="test-secret-key",
+    )
+    assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key"
+    assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"