diff --git a/libs/langchain/langchain/llms/tongyi.py b/libs/langchain/langchain/llms/tongyi.py
index 74d904c4453..1ecd238cd74 100644
--- a/libs/langchain/langchain/llms/tongyi.py
+++ b/libs/langchain/langchain/llms/tongyi.py
@@ -218,9 +218,23 @@ class Tongyi(LLM):
             if len(prompts) > 1:
                 raise ValueError("Cannot stream results with multiple prompts.")
             params["stream"] = True
+            temp = ""
             for stream_resp in stream_generate_with_retry(
                 self, prompt=prompts[0], **params
             ):
+                if run_manager:
+                    # Alibaba Cloud's streaming interface returns cumulative
+                    # output: each chunk repeats the full text of the previous
+                    # round (as of September 20, 2023; future updates to the
+                    # Alibaba Cloud API may change this), so strip the text
+                    # that has already been emitted.
+                    stream_resp_text = stream_resp["output"]["text"]
+                    stream_resp_text = stream_resp_text.replace(temp, "")
+                    # Streaming relies primarily on the "on_llm_new_token"
+                    # method of the streaming callback.
+                    run_manager.on_llm_new_token(stream_resp_text)
+                temp = stream_resp["output"]["text"]
+
                 generations.append(
                     [
                         Generation(
@@ -231,6 +245,18 @@ class Tongyi(LLM):
                         )
                     ]
                 )
+        generations.reverse()
+        # In the official OpenAI implementation, the "generations" value
+        # passed to LLMResult appears to be a nested list with a single
+        # element at each level, even in non-streaming mode.
+        # Because Alibaba Cloud's streaming transmission is cumulative
+        # (as of September 20, 2023; future updates to the Alibaba Cloud
+        # API may change this), each chunk includes the output of the
+        # previous round, so the last element appended to "generations"
+        # holds the complete text and reversing the list should suffice.
+        # (This is the solution with the fewest changes to the source
+        # code, and it still allows convenient modification in the
+        # future, although it may consume slightly more memory.)
     else:
         for prompt in prompts:
             completion = generate_with_retry(
diff --git a/tests/integration_tests/llms/test_tongyi.py b/tests/integration_tests/llms/test_tongyi.py
new file mode 100644
index 00000000000..e5858cc02a6
--- /dev/null
+++ b/tests/integration_tests/llms/test_tongyi.py
@@ -0,0 +1,94 @@
+import os
+import sys
+from queue import Queue
+from typing import Any, Dict, List, Union
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.llms.tongyi import Tongyi
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+
+# Tongyi reads its credentials from the DASHSCOPE_API_KEY environment
+# variable; fill in a real key before running this test.
+os.environ["DASHSCOPE_API_KEY"] = ""
+
+STOP_ITEM = "###finish###"
+
+
+class StreamingHandler(BaseCallbackHandler):
+    """Callback handler for streaming.
+
+    Only works with LLMs that support streaming.
+    """
+
+    def __init__(self, q: Queue) -> None:
+        super().__init__()
+        self.q = q
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+        with self.q.mutex:
+            self.q.queue.clear()
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        sys.stdout.write(token)
+        sys.stdout.flush()
+        self.q.put(token)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running."""
+        self.q.put(STOP_ITEM)
+
+    def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when LLM errors."""
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Run when chain starts running."""
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Run when chain ends running."""
+
+    def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when chain errors."""
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+    ) -> None:
+        """Run when tool starts running."""
+
+    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+        """Run on agent action."""
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Run when tool ends running."""
+
+    def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when tool errors."""
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Run on arbitrary text."""
+
+    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
+        """Run on agent end."""
+
+
+streaming_callback_fn = StreamingHandler(q=Queue())
+
+llm = Tongyi(
+    streaming=True,
+    callbacks=[streaming_callback_fn],
+    temperature=0.1,
+    model_name="qwen-plus",
+)
+
+llm("write a poem about spring")
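
Note on the delta computation in the first hunk: because each streamed chunk
carries the cumulative output, the newly generated text is the suffix that
follows what has already been emitted. The hunk strips the previous round with
str.replace; below is a minimal standalone sketch of the slightly more robust
length-based slicing, where the chunk strings are hypothetical stand-ins for
stream_resp["output"]["text"]:

    # Sketch: extract per-round deltas from cumulative streaming chunks.
    chunks = ["Spring", "Spring rain", "Spring rain falls"]  # hypothetical data
    seen = ""
    for cumulative in chunks:
        delta = cumulative[len(seen):]  # suffix not yet emitted
        print(delta, end="")  # stand-in for run_manager.on_llm_new_token(delta)
        seen = cumulative
    # Prints "Spring rain falls" exactly once. str.replace(temp, "") behaves
    # the same here, but would also drop repeated occurrences mid-string.

After generations.reverse() in the second hunk, the final (and therefore
complete) chunk sits at index 0, so callers reading
LLMResult.generations[0][0].text still receive the full output.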