fix(fireworks): parse standard blocks in input (#33426)

This commit is contained in:
ccurme
2025-10-10 16:18:37 -04:00
committed by GitHub
parent 0559558715
commit c1b816cb7e
4 changed files with 41 additions and 4 deletions

View File

@@ -0,0 +1,26 @@
"""Converts between AIMessage output formats, governed by `output_version`."""
from __future__ import annotations
from langchain_core.messages import AIMessage
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
"""Convert a v1 message to the Chat Completions format."""
if isinstance(message.content, list):
new_content: list = []
for block in message.content:
if isinstance(block, dict):
block_type = block.get("type")
if block_type == "text":
# Strip annotations
new_content.append({"type": "text", "text": block["text"]})
elif block_type in ("reasoning", "tool_call"):
pass
else:
new_content.append(block)
else:
new_content.append(block)
return message.model_copy(update={"content": new_content})
return message

View File

@@ -78,6 +78,8 @@ from pydantic import (
)
from typing_extensions import Self
from langchain_fireworks._compat import _convert_from_v1_to_chat_completions
logger = logging.getLogger(__name__)
@@ -152,6 +154,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
# Translate v1 content
if message.response_metadata.get("output_version") == "v1":
message = _convert_from_v1_to_chat_completions(message)
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
@@ -238,6 +243,7 @@ def _convert_chunk_to_message_chunk(
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "fireworks"},
)
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
@@ -515,6 +521,8 @@ class ChatFireworks(BaseChatModel):
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
message.response_metadata["model_provider"] = "fireworks"
message.response_metadata["model_name"] = self.model_name
generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
@@ -525,7 +533,6 @@ class ChatFireworks(BaseChatModel):
generations.append(gen)
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)