diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 4de8b2798a3..62ad918e037 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -34,7 +34,11 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -474,6 +478,8 @@ class ChatOpenAI(BaseChatModel): chunk = ChatGenerationChunk( message=chunk, generation_info=generation_info or None ) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs) yield chunk def _generate( @@ -483,12 +489,13 @@ class ChatOpenAI(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) - params = { - **params, - **({"stream": self.streaming} if self.streaming else {}), - **kwargs, - } + params = {**params, **kwargs} response = self.client.create(messages=message_dicts, **params) return self._create_chat_result(response) @@ -569,6 +576,10 @@ class ChatOpenAI(BaseChatModel): chunk = ChatGenerationChunk( message=chunk, generation_info=generation_info or None ) + if run_manager: + await run_manager.on_llm_new_token( + token=chunk.text, chunk=chunk, logprobs=logprobs + ) yield chunk async def _agenerate( @@ -578,12 +589,14 @@ class ChatOpenAI(BaseChatModel): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: + if self.streaming: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + message_dicts, params = self._create_message_dicts(messages, stop) - params = { - **params, - **({"stream": self.streaming} if self.streaming else {}), - **kwargs, - } + params = {**params, **kwargs} response = await self.async_client.create(messages=message_dicts, **params) return self._create_chat_result(response)