Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-10 21:35:08 +00:00
openai[patch]: wrap stream code in context manager blocks (#18013)
**Description:** Use the `Stream` context managers in the `ChatOpenAI` `stream` and `astream` methods. Using the context manager returned by the OpenAI client makes it possible to terminate the stream early, since the response connection is closed when the context manager exits. **Issue:** #5340 **Twitter handle:** @snopoke --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Parent: 6c11c8dac6
Commit: a682f0d12b
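For context, here is a minimal sketch of the behavior this change enables from the caller's side: breaking out of a `stream()` loop should now release the underlying HTTP connection promptly, because closing the generator unwinds the `with` block added in this commit. The model name and prompt are illustrative placeholders.

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")  # model name is illustrative

collected = []
for chunk in llm.stream("Write a very long story about a lighthouse."):
    collected.append(chunk.content)
    if len(collected) >= 5:
        # Stop early. When the generator is closed (e.g. garbage collected after
        # the break), the `with` block inside `_stream` exits and the response
        # connection is closed instead of lingering until the stream is exhausted.
        break

print("".join(collected))
```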
```diff
@@ -457,30 +457,33 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
-        for chunk in self.client.create(messages=message_dicts, **params):
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
+        with self.client.create(messages=message_dicts, **params) as response:
+            for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text, chunk=chunk, logprobs=logprobs
+                    )
+                yield chunk
 
     def _generate(
         self,
```
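The sync hunk relies on the `Stream` object returned by the openai SDK (with `stream=True`) supporting the context manager protocol, which recent 1.x releases do. A standalone sketch of that pattern, with an illustrative model name and prompt:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Exiting the with-block closes the streaming response, even if we stop early.
with client.chat.completions.create(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Count to one million."}],
    stream=True,
) as stream:
    for i, chunk in enumerate(stream):
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
        if i >= 10:
            break  # connection is released when the with-block exits
```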
```diff
@@ -553,34 +556,34 @@ class ChatOpenAI(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
-        async for chunk in await self.async_client.create(
-            messages=message_dicts, **params
-        ):
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
-                )
-            yield chunk
+        response = await self.async_client.create(messages=message_dicts, **params)
+        async with response:
+            async for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        token=chunk.text, chunk=chunk, logprobs=logprobs
+                    )
+                yield chunk
 
     async def _agenerate(
         self,
```
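The async hunk awaits `create()` first and then enters the resulting `AsyncStream` with `async with`, mirroring the sync case. A minimal standalone sketch of that pattern, assuming a recent openai SDK and an illustrative model name:

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment

    # Awaiting create() with stream=True returns an AsyncStream; `async with`
    # closes the connection on exit, as in the _astream change above.
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": "Stream a haiku."}],
        stream=True,
    )
    async with response:
        async for chunk in response:
            if chunk.choices:
                print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```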