openai[patch]: propagate service_tier to response metadata (#31089)

Author: ccurme
Date: 2025-05-01 13:50:48 -04:00
Committed by: GitHub
Parent: 6110c3ffc5
Commit: c51eadd54f
3 changed files with 10 additions and 2 deletions
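OpenAI chat responses carry a top-level service_tier field reporting which processing tier actually served the request (for example "default"). This commit propagates that field into LangChain's response_metadata on the streaming and non-streaming Chat Completions paths and on the Responses API path. A minimal usage sketch of what the change enables (the model name is illustrative, and a configured OPENAI_API_KEY is assumed):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")
result = llm.invoke("Hello!")

# After this change, the tier that served the request is surfaced
# alongside model_name, system_fingerprint, and finish_reason.
print(result.response_metadata.get("service_tier"))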


@@ -736,6 +736,8 @@ class BaseChatOpenAI(BaseChatModel):
                 generation_info["model_name"] = model_name
             if system_fingerprint := chunk.get("system_fingerprint"):
                 generation_info["system_fingerprint"] = system_fingerprint
+            if service_tier := chunk.get("service_tier"):
+                generation_info["service_tier"] = service_tier
         logprobs = choice.get("logprobs")
         if logprobs:
@@ -1020,6 +1022,8 @@ class BaseChatOpenAI(BaseChatModel):
         }
         if "id" in response_dict:
             llm_output["id"] = response_dict["id"]
+        if "service_tier" in response_dict:
+            llm_output["service_tier"] = response_dict["service_tier"]
         if isinstance(response, openai.BaseModel) and getattr(
             response, "choices", None
@@ -3243,6 +3247,7 @@ def _construct_lc_result_from_responses_api(
             "status",
             "user",
             "model",
+            "service_tier",
         )
     }
     if metadata:
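The three hunks above cover, in order, the streaming path (per-chunk generation_info), the non-streaming path (llm_output), and the Responses API metadata allowlist. A standalone sketch of the guard pattern from the first hunk, with a plain dict standing in for the raw chunk payload:

from typing import Any

def build_generation_info(chunk: dict[str, Any]) -> dict[str, Any]:
    # Mirrors the streaming hunk: copy optional top-level fields from a
    # raw chunk payload into generation_info only when present and truthy.
    generation_info: dict[str, Any] = {}
    if model_name := chunk.get("model"):
        generation_info["model_name"] = model_name
    if system_fingerprint := chunk.get("system_fingerprint"):
        generation_info["system_fingerprint"] = system_fingerprint
    if service_tier := chunk.get("service_tier"):
        generation_info["service_tier"] = service_tier
    return generation_info

# The walrus guards skip absent or empty fields, so payloads from
# providers that omit service_tier produce no key at all.
assert build_generation_info({"service_tier": "default"}) == {"service_tier": "default"}
assert build_generation_info({}) == {}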


@@ -350,6 +350,7 @@ def test_response_metadata() -> None:
             "logprobs",
             "system_fingerprint",
             "finish_reason",
+            "service_tier",
         )
     )
     assert "content" in result.response_metadata["logprobs"]
@@ -367,6 +368,7 @@ async def test_async_response_metadata() -> None:
             "logprobs",
             "system_fingerprint",
             "finish_reason",
+            "service_tier",
         )
     )
     assert "content" in result.response_metadata["logprobs"]
@@ -380,7 +382,7 @@ def test_response_metadata_streaming() -> None:
         full = chunk if full is None else full + chunk
     assert all(
         k in cast(BaseMessageChunk, full).response_metadata
-        for k in ("logprobs", "finish_reason")
+        for k in ("logprobs", "finish_reason", "service_tier")
     )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
@@ -393,7 +395,7 @@ async def test_async_response_metadata_streaming() -> None:
         full = chunk if full is None else full + chunk
     assert all(
         k in cast(BaseMessageChunk, full).response_metadata
-        for k in ("logprobs", "finish_reason")
+        for k in ("logprobs", "finish_reason", "service_tier")
     )
     assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]


@@ -47,6 +47,7 @@ def _check_response(response: Optional[BaseMessage]) -> None:
     assert response.usage_metadata["output_tokens"] > 0
     assert response.usage_metadata["total_tokens"] > 0
     assert response.response_metadata["model_name"]
+    assert response.response_metadata["service_tier"]
     for tool_output in response.additional_kwargs["tool_outputs"]:
         assert tool_output["id"]
         assert tool_output["status"]