Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 05:43:55 +00:00)
openai[patch]: default to invoke on o1 stream() (#27983)
Commit 33dbfba08b (parent 503f2487a5)
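In short: o1 models do not yet support streaming, so this patch makes stream()/astream() on o1 fall back to a single non-streaming call instead of failing, and it moves the existing Pydantic response_format fallback into the same mechanism (a new _should_stream override, shown further down). A minimal usage sketch of the intended behavior; the single-chunk expectation is an inference from this diff, and it assumes a valid OPENAI_API_KEY:

# Hedged sketch: stream() on an o1 model is expected to yield one chunk
# produced via the non-streaming invoke() path rather than raising.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="o1-mini")
chunks = list(llm.stream("how are you"))
print(len(chunks))  # expected: 1 under the fallback-to-invoke behavior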
@@ -632,24 +632,6 @@ class BaseChatOpenAI(BaseChatModel):
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
 
-        if "response_format" in payload and is_basemodel_subclass(
-            payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
-            )
-            return
         if self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
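The inline Pydantic response_format fallback deleted here is not lost: the check moves into the _should_stream override added further down, and the fall-back-to-invoke now happens in the shared streaming entry point rather than inside _stream itself. A hedged, user-facing sketch of the behavior this preserves (Foo is an illustrative schema, and an API key is required):

# Hedged sketch: streaming with a Pydantic response_format should still warn and
# yield a single chunk whose additional_kwargs carry the parsed object (per the
# "preserve the parsed Pydantic object" comment in the code removed above).
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Foo(BaseModel):  # illustrative schema, not part of the library
    response: str


chunks = list(
    ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo)
)
print(len(chunks), chunks[0].additional_kwargs.keys())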
@@ -783,24 +765,6 @@ class BaseChatOpenAI(BaseChatModel):
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if "response_format" in payload and is_basemodel_subclass(
-            payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = await self._agenerate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
-            )
-            return
         if self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
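_astream gets the identical cleanup. For completeness, a hedged sketch of the async equivalent from the caller's side (same assumptions as above):

# Hedged sketch of the async path: astream() with a Pydantic response_format is
# likewise expected to warn and fall back to a single non-streamed chunk.
import asyncio

from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Foo(BaseModel):  # illustrative schema
    response: str


async def main() -> None:
    async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
        "how are ya", response_format=Foo
    ):
        print(type(chunk).__name__)


asyncio.run(main())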
@@ -999,6 +963,28 @@ class BaseChatOpenAI(BaseChatModel):
             num_tokens += 3
         return num_tokens
 
+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        response_format: Optional[Union[dict, type]] = None,
+        **kwargs: Any,
+    ) -> bool:
+        if isinstance(response_format, type) and is_basemodel_subclass(response_format):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            return False
+        if self.model_name.startswith("o1"):
+            # TODO: Add support for streaming with o1 once supported.
+            return False
+
+        return super()._should_stream(
+            async_api=async_api, run_manager=run_manager, **kwargs
+        )
+
     @deprecated(
         since="0.2.1",
         alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools",
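This override is the core of the patch: the per-call decision now lives in one place. As I read the langchain-core BaseChatModel contract (treat this as an assumption rather than something shown in this diff), stream() and astream() consult _should_stream and, when it returns False, route the request through invoke()/ainvoke() and yield the result as a single chunk. The hook also composes naturally in subclasses; a hypothetical sketch:

# Hypothetical subclass (not part of this PR): the same hook can gate streaming
# for other model families, mirroring the o1 check introduced above.
from typing import Any, Optional

from langchain_openai import ChatOpenAI


class PatchedChatOpenAI(ChatOpenAI):
    def _should_stream(
        self, *, async_api: bool, run_manager: Optional[Any] = None, **kwargs: Any
    ) -> bool:
        # "my-batch-only-model" is a made-up prefix for illustration.
        if self.model_name.startswith("my-batch-only-model"):
            return False
        return super()._should_stream(
            async_api=async_api, run_manager=run_manager, **kwargs
        )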
@@ -1061,3 +1061,27 @@ def test_prediction_tokens() -> None:
     ]
     assert output_token_details["accepted_prediction_tokens"] > 0
     assert output_token_details["rejected_prediction_tokens"] > 0
+
+
+def test_stream_o1() -> None:
+    list(ChatOpenAI(model="o1-mini").stream("how are you"))
+
+
+async def test_astream_o1() -> None:
+    async for _ in ChatOpenAI(model="o1-mini").astream("how are you"):
+        pass
+
+
+class Foo(BaseModel):
+    response: str
+
+
+def test_stream_response_format() -> None:
+    list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
+
+
+async def test_astream_response_format() -> None:
+    async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
+        "how are ya", response_format=Foo
+    ):
+        pass
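These new tests are integration tests that exercise both code paths against the live API, so they need credentials (and, presumably, the repo's usual pytest-asyncio setup for the async ones). A hedged standalone version of the async o1 case, runnable outside pytest:

# Hedged standalone version of test_astream_o1: astream() on an o1 model should
# complete without raising, yielding the invoke() fallback result. Needs
# OPENAI_API_KEY in the environment.
import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    async for chunk in ChatOpenAI(model="o1-mini").astream("how are you"):
        print(chunk.content, end="")


asyncio.run(main())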