mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(fireworks): parse standard blocks in input (#33426)
This commit is contained in:
26
libs/partners/fireworks/langchain_fireworks/_compat.py
Normal file
26
libs/partners/fireworks/langchain_fireworks/_compat.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Converts between AIMessage output formats, governed by `output_version`."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_from_v1_to_chat_completions(message: AIMessage) -> AIMessage:
|
||||||
|
"""Convert a v1 message to the Chat Completions format."""
|
||||||
|
if isinstance(message.content, list):
|
||||||
|
new_content: list = []
|
||||||
|
for block in message.content:
|
||||||
|
if isinstance(block, dict):
|
||||||
|
block_type = block.get("type")
|
||||||
|
if block_type == "text":
|
||||||
|
# Strip annotations
|
||||||
|
new_content.append({"type": "text", "text": block["text"]})
|
||||||
|
elif block_type in ("reasoning", "tool_call"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
new_content.append(block)
|
||||||
|
else:
|
||||||
|
new_content.append(block)
|
||||||
|
return message.model_copy(update={"content": new_content})
|
||||||
|
|
||||||
|
return message
|
||||||
@@ -78,6 +78,8 @@ from pydantic import (
|
|||||||
)
|
)
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from langchain_fireworks._compat import _convert_from_v1_to_chat_completions
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +154,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
elif isinstance(message, HumanMessage):
|
elif isinstance(message, HumanMessage):
|
||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict = {"role": "user", "content": message.content}
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
|
# Translate v1 content
|
||||||
|
if message.response_metadata.get("output_version") == "v1":
|
||||||
|
message = _convert_from_v1_to_chat_completions(message)
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
if "function_call" in message.additional_kwargs:
|
if "function_call" in message.additional_kwargs:
|
||||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
@@ -238,6 +243,7 @@ def _convert_chunk_to_message_chunk(
|
|||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
tool_call_chunks=tool_call_chunks,
|
tool_call_chunks=tool_call_chunks,
|
||||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||||
|
response_metadata={"model_provider": "fireworks"},
|
||||||
)
|
)
|
||||||
if role == "system" or default_class == SystemMessageChunk:
|
if role == "system" or default_class == SystemMessageChunk:
|
||||||
return SystemMessageChunk(content=content)
|
return SystemMessageChunk(content=content)
|
||||||
@@ -515,6 +521,8 @@ class ChatFireworks(BaseChatModel):
|
|||||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||||
"total_tokens": token_usage.get("total_tokens", 0),
|
"total_tokens": token_usage.get("total_tokens", 0),
|
||||||
}
|
}
|
||||||
|
message.response_metadata["model_provider"] = "fireworks"
|
||||||
|
message.response_metadata["model_name"] = self.model_name
|
||||||
generation_info = {"finish_reason": res.get("finish_reason")}
|
generation_info = {"finish_reason": res.get("finish_reason")}
|
||||||
if "logprobs" in res:
|
if "logprobs" in res:
|
||||||
generation_info["logprobs"] = res["logprobs"]
|
generation_info["logprobs"] = res["logprobs"]
|
||||||
@@ -525,7 +533,6 @@ class ChatFireworks(BaseChatModel):
|
|||||||
generations.append(gen)
|
generations.append(gen)
|
||||||
llm_output = {
|
llm_output = {
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"model_name": self.model_name,
|
|
||||||
"system_fingerprint": response.get("system_fingerprint", ""),
|
"system_fingerprint": response.get("system_fingerprint", ""),
|
||||||
}
|
}
|
||||||
return ChatResult(generations=generations, llm_output=llm_output)
|
return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|||||||
@@ -57,7 +57,9 @@ async def test_astream() -> None:
|
|||||||
full = token if full is None else full + token
|
full = token if full is None else full + token
|
||||||
if token.usage_metadata is not None:
|
if token.usage_metadata is not None:
|
||||||
chunks_with_token_counts += 1
|
chunks_with_token_counts += 1
|
||||||
if token.response_metadata:
|
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
|
||||||
|
{"model_provider", "output_version"}
|
||||||
|
):
|
||||||
chunks_with_response_metadata += 1
|
chunks_with_response_metadata += 1
|
||||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||||
msg = (
|
msg = (
|
||||||
@@ -76,6 +78,7 @@ async def test_astream() -> None:
|
|||||||
)
|
)
|
||||||
assert isinstance(full.response_metadata["model_name"], str)
|
assert isinstance(full.response_metadata["model_name"], str)
|
||||||
assert full.response_metadata["model_name"]
|
assert full.response_metadata["model_name"]
|
||||||
|
assert full.response_metadata["model_provider"] == "fireworks"
|
||||||
|
|
||||||
|
|
||||||
async def test_abatch_tags() -> None:
|
async def test_abatch_tags() -> None:
|
||||||
@@ -103,6 +106,7 @@ def test_invoke() -> None:
|
|||||||
|
|
||||||
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||||
assert isinstance(result.content, str)
|
assert isinstance(result.content, str)
|
||||||
|
assert result.response_metadata["model_provider"] == "fireworks"
|
||||||
|
|
||||||
|
|
||||||
def _get_joke_class(
|
def _get_joke_class(
|
||||||
|
|||||||
@@ -1498,8 +1498,8 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
[("human", "Hello. Please respond in the style of {answer_style}.")]
|
[("human", "Hello. Please respond in the style of {answer_style}.")]
|
||||||
)
|
)
|
||||||
model = GenericFakeChatModel(messages=iter(["hello matey"]))
|
llm = GenericFakeChatModel(messages=iter(["hello matey"]))
|
||||||
chain = prompt | model | StrOutputParser()
|
chain = prompt | llm | StrOutputParser()
|
||||||
tool_ = chain.as_tool(
|
tool_ = chain.as_tool(
|
||||||
name="greeting_generator",
|
name="greeting_generator",
|
||||||
description="Generate a greeting in a particular style of speaking.",
|
description="Generate a greeting in a particular style of speaking.",
|
||||||
|
|||||||
Reference in New Issue
Block a user