From c1b816cb7e70286bbb870c333d28f3bdd4d62363 Mon Sep 17 00:00:00 2001 From: ccurme Date: Fri, 10 Oct 2025 16:18:37 -0400 Subject: [PATCH] fix(fireworks): parse standard blocks in input (#33426) --- .../fireworks/langchain_fireworks/_compat.py | 26 +++++++++++++++++++ .../langchain_fireworks/chat_models.py | 9 ++++++- .../integration_tests/test_chat_models.py | 6 ++++- .../integration_tests/chat_models.py | 4 +-- 4 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 libs/partners/fireworks/langchain_fireworks/_compat.py diff --git a/libs/partners/fireworks/langchain_fireworks/_compat.py b/libs/partners/fireworks/langchain_fireworks/_compat.py new file mode 100644 index 00000000000..277762a78f5 --- /dev/null +++ b/libs/partners/fireworks/langchain_fireworks/_compat.py @@ -0,0 +1,26 @@ +"""Converts between AIMessage output formats, governed by `output_version`.""" + +from __future__ import annotations + +from langchain_core.messages import AIMessage + + +def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage: + """Convert a v1 message to the Chat Completions format.""" + if isinstance(message.content, list): + new_content: list = [] + for block in message.content: + if isinstance(block, dict): + block_type = block.get("type") + if block_type == "text": + # Strip annotations + new_content.append({"type": "text", "text": block["text"]}) + elif block_type in ("reasoning", "tool_call"): + pass + else: + new_content.append(block) + else: + new_content.append(block) + return message.model_copy(update={"content": new_content}) + + return message diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 9c242b1eaf5..0af3bc4feb5 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -78,6 +78,8 @@ from pydantic import ( ) from typing_extensions import Self +from langchain_fireworks._compat import _convert_from_v1_to_chat_completions + logger = logging.getLogger(__name__) @@ -152,6 +154,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: elif isinstance(message, HumanMessage): message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): + # Translate v1 content + if message.response_metadata.get("output_version") == "v1": + message = _convert_from_v1_to_chat_completions(message) message_dict = {"role": "assistant", "content": message.content} if "function_call" in message.additional_kwargs: message_dict["function_call"] = message.additional_kwargs["function_call"] @@ -238,6 +243,7 @@ def _convert_chunk_to_message_chunk( additional_kwargs=additional_kwargs, tool_call_chunks=tool_call_chunks, usage_metadata=usage_metadata, # type: ignore[arg-type] + response_metadata={"model_provider": "fireworks"}, ) if role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) @@ -515,6 +521,8 @@ class ChatFireworks(BaseChatModel): "output_tokens": token_usage.get("completion_tokens", 0), "total_tokens": token_usage.get("total_tokens", 0), } + message.response_metadata["model_provider"] = "fireworks" + message.response_metadata["model_name"] = self.model_name generation_info = {"finish_reason": res.get("finish_reason")} if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] @@ -525,7 +533,6 @@ class ChatFireworks(BaseChatModel): generations.append(gen) llm_output = { "token_usage": token_usage, - "model_name": self.model_name, "system_fingerprint": response.get("system_fingerprint", ""), } return ChatResult(generations=generations, llm_output=llm_output) diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py index c1662d3a356..2eb48409ee9 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -57,7 +57,9 @@ async def test_astream() -> None: full = token if full is None else full + token if token.usage_metadata is not None: chunks_with_token_counts += 1 - if token.response_metadata: + if token.response_metadata and not set(token.response_metadata.keys()).issubset( + {"model_provider", "output_version"} + ): chunks_with_response_metadata += 1 if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: msg = ( @@ -76,6 +78,7 @@ async def test_astream() -> None: ) assert isinstance(full.response_metadata["model_name"], str) assert full.response_metadata["model_name"] + assert full.response_metadata["model_provider"] == "fireworks" async def test_abatch_tags() -> None: @@ -103,6 +106,7 @@ def test_invoke() -> None: result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result.content, str) + assert result.response_metadata["model_provider"] == "fireworks" def _get_joke_class( diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index 07303e5a61f..e20085e1306 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -1498,8 +1498,8 @@ class ChatModelIntegrationTests(ChatModelTests): prompt = ChatPromptTemplate.from_messages( [("human", "Hello. Please respond in the style of {answer_style}.")] ) - model = GenericFakeChatModel(messages=iter(["hello matey"])) - chain = prompt | model | StrOutputParser() + llm = GenericFakeChatModel(messages=iter(["hello matey"])) + chain = prompt | llm | StrOutputParser() tool_ = chain.as_tool( name="greeting_generator", description="Generate a greeting in a particular style of speaking.",