langchain_openai: Make sure the response from the async client in the astream method of ChatOpenAI is properly awaited in case of "include_response_headers=True" (#26031)

- **Description:** This is a **one line change**. the
`self.async_client.with_raw_response.create(**payload)` call is not
properly awaited within the `_astream` method. In `_agenerate` this is
done already, but likely forgotten in the other method.
  - **Issue:** Not applicable
  - **Dependencies:** No dependencies required.

(If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.)

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Friso H. Kingma 2024-09-04 15:26:48 +02:00 committed by GitHub
parent c812237217
commit af11fbfbf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 3 deletions

View File

@ -757,7 +757,7 @@ class BaseChatOpenAI(BaseChatModel):
) )
return return
if self.include_response_headers: if self.include_response_headers:
raw_response = self.async_client.with_raw_response.create(**payload) raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse() response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)} base_generation_info = {"headers": dict(raw_response.headers)}
else: else:

View File

@ -686,15 +686,47 @@ def test_openai_proxy() -> None:
assert proxy.port == 8080 assert proxy.port == 8080
def test_openai_response_headers_invoke() -> None: def test_openai_response_headers() -> None:
"""Test ChatOpenAI response headers.""" """Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True) chat_openai = ChatOpenAI(include_response_headers=True)
result = chat_openai.invoke("I'm Pickle Rick") query = "I'm Pickle Rick"
result = chat_openai.invoke(query, max_tokens=10)
headers = result.response_metadata["headers"] headers = result.response_metadata["headers"]
assert headers assert headers
assert isinstance(headers, dict) assert isinstance(headers, dict)
assert "content-type" in headers assert "content-type" in headers
# Stream
full: Optional[BaseMessageChunk] = None
for chunk in chat_openai.stream(query, max_tokens=10):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers
async def test_openai_response_headers_async() -> None:
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True)
query = "I'm Pickle Rick"
result = await chat_openai.ainvoke(query, max_tokens=10)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers
# Stream
full: Optional[BaseMessageChunk] = None
async for chunk in chat_openai.astream(query, max_tokens=10):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers
def test_image_token_counting_jpeg() -> None: def test_image_token_counting_jpeg() -> None:
model = ChatOpenAI(model="gpt-4o", temperature=0) model = ChatOpenAI(model="gpt-4o", temperature=0)