fireworks[patch]: ruff fixes and rules (#31903)

* bump ruff deps
* add more thorough ruff rules
* fix said rules
Mason Daugherty 2025-07-07 22:14:59 -04:00 committed by GitHub
parent 63e3f2dea6
commit 06ab2972e3
12 changed files with 164 additions and 91 deletions
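
Most of the mechanical changes below follow ruff's flake8-errmsg (EM) rules: bind the exception message to a variable before raising, rather than putting a string literal or f-string inside the `raise` itself. A minimal sketch of the pattern, reusing a message from the validators changed below (the function name is illustrative):

    def validate_n(n: int) -> int:
        """Reject values that violate the model's constraint."""
        if n < 1:
            # EM101: assign the literal to a variable first, then raise it,
            # so the traceback does not repeat the long literal
            msg = "n must be at least 1."
            raise ValueError(msg)
        return n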

View File

@@ -4,8 +4,8 @@ from langchain_fireworks.llms import Fireworks
 from langchain_fireworks.version import __version__

 __all__ = [
+    "__version__",
     "ChatFireworks",
     "Fireworks",
     "FireworksEmbeddings",
-    "__version__",
 ]
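
(The `__all__` reshuffle is ruff's RUF022 auto-sort; its default isort-style ordering places dunder names like `__version__` before the class names.)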

View File

@@ -2,6 +2,7 @@
 from __future__ import annotations

+import contextlib
 import json
 import logging
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
@@ -16,7 +17,7 @@ from typing import (
     cast,
 )

-from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
+from fireworks.client import AsyncFireworks, Fireworks  # type: ignore[import-untyped]
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -94,11 +95,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     Returns:
         The LangChain message.
+
     """
     role = _dict.get("role")
     if role == "user":
         return HumanMessage(content=_dict.get("content", ""))
-    elif role == "assistant":
+    if role == "assistant":
         # Fix for azure
         # Also Fireworks returns None for tool invocations
         content = _dict.get("content", "") or ""
@@ -122,13 +124,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    if role == "system":
         return SystemMessage(content=_dict.get("content", ""))
-    elif role == "function":
+    if role == "function":
         return FunctionMessage(
             content=_dict.get("content", ""), name=_dict.get("name", "")
         )
-    elif role == "tool":
+    if role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]
@@ -137,8 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_call_id=_dict.get("tool_call_id", ""),
             additional_kwargs=additional_kwargs,
         )
-    else:
-        return ChatMessage(content=_dict.get("content", ""), role=role or "")
+    return ChatMessage(content=_dict.get("content", ""), role=role or "")


 def _convert_message_to_dict(message: BaseMessage) -> dict:
@@ -149,6 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     Returns:
         The dictionary.
+
     """
     message_dict: dict[str, Any]
     if isinstance(message, ChatMessage):
@@ -191,7 +193,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             "tool_call_id": message.tool_call_id,
         }
     else:
-        raise TypeError(f"Got unknown type {message}")
+        msg = f"Got unknown type {message}"
+        raise TypeError(msg)
     if "name" in message.additional_kwargs:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
@@ -214,7 +217,7 @@ def _convert_chunk_to_message_chunk(
     if raw_tool_calls := _dict.get("tool_calls"):
         additional_kwargs["tool_calls"] = raw_tool_calls
         for rtc in raw_tool_calls:
-            try:
+            with contextlib.suppress(KeyError):
                 tool_call_chunks.append(
                     create_tool_call_chunk(
                         name=rtc["function"].get("name"),
@@ -223,11 +226,9 @@ def _convert_chunk_to_message_chunk(
                         index=rtc.get("index"),
                     )
                 )
-            except KeyError:
-                pass
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         if usage := chunk.get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
@@ -244,16 +245,15 @@ def _convert_chunk_to_message_chunk(
             tool_call_chunks=tool_call_chunks,
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
+    if role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
+    if role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore
+    return default_class(content=content)  # type: ignore[call-arg]


 class _FunctionCall(TypedDict):
@@ -280,6 +280,7 @@ class ChatFireworks(BaseChatModel):
         from langchain_fireworks.chat_models import ChatFireworks

         fireworks = ChatFireworks(
             model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
+
     """

     @property
@@ -326,14 +327,14 @@ class ChatFireworks(BaseChatModel):
         ),
     )
     """Fireworks API key.

     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
     fireworks_api_base: Optional[str] = Field(
         alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None)
     )
     """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""

     request_timeout: Union[float, tuple[float, float], Any, None] = Field(
         default=None, alias="timeout"
@@ -358,16 +359,17 @@ class ChatFireworks(BaseChatModel):
     def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
-        values = _build_model_kwargs(values, all_required_field_names)
-        return values
+        return _build_model_kwargs(values, all_required_field_names)

     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
         if self.n < 1:
-            raise ValueError("n must be at least 1.")
+            msg = "n must be at least 1."
+            raise ValueError(msg)
         if self.n > 1 and self.streaming:
-            raise ValueError("n must be 1 when streaming.")
+            msg = "n must be 1 when streaming."
+            raise ValueError(msg)

         client_params = {
             "api_key": (
@@ -522,7 +524,7 @@ class ChatFireworks(BaseChatModel):
                 "output_tokens": token_usage.get("completion_tokens", 0),
                 "total_tokens": token_usage.get("total_tokens", 0),
             }
-            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generation_info = {"finish_reason": res.get("finish_reason")}
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
             gen = ChatGeneration(
@@ -628,7 +630,7 @@ class ChatFireworks(BaseChatModel):
         self,
         functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         function_call: Optional[
-            Union[_FunctionCall, str, Literal["auto", "none"]]
+            Union[_FunctionCall, str, Literal["auto", "none"]]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -651,8 +653,8 @@ class ChatFireworks(BaseChatModel):
                 (if any).
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
-        """

+        """
         formatted_functions = [convert_to_openai_function(fn) for fn in functions]
         if function_call is not None:
             function_call = (
@@ -662,18 +664,20 @@ class ChatFireworks(BaseChatModel):
                 else function_call
             )
             if isinstance(function_call, dict) and len(formatted_functions) != 1:
-                raise ValueError(
+                msg = (
                     "When specifying `function_call`, you must provide exactly one "
                     "function."
                 )
+                raise ValueError(msg)
             if (
                 isinstance(function_call, dict)
                 and formatted_functions[0]["name"] != function_call["name"]
             ):
-                raise ValueError(
+                msg = (
                     f"Function call {function_call} was specified, but the only "
                     f"provided function was {formatted_functions[0]['name']}."
                 )
+                raise ValueError(msg)
             kwargs = {**kwargs, "function_call": function_call}
         return super().bind(
             functions=formatted_functions,
@@ -685,7 +689,7 @@ class ChatFireworks(BaseChatModel):
         tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "any", "none"], bool]
+            Union[dict, str, Literal["auto", "any", "none"], bool]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -705,8 +709,8 @@ class ChatFireworks(BaseChatModel):
                 ``{"type": "function", "function": {"name": <<tool_name>>}}``.
             **kwargs: Any additional parameters to pass to
                 :meth:`~langchain_fireworks.chat_models.ChatFireworks.bind`
-        """

+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None and tool_choice:
             if isinstance(tool_choice, str) and (
@@ -715,10 +719,11 @@ class ChatFireworks(BaseChatModel):
                 tool_choice = {"type": "function", "function": {"name": tool_choice}}
             if isinstance(tool_choice, bool):
                 if len(tools) > 1:
-                    raise ValueError(
+                    msg = (
                         "tool_choice can only be True when there is one tool. Received "
                         f"{len(tools)} tools."
                     )
+                    raise ValueError(msg)
                 tool_name = formatted_tools[0]["function"]["name"]
                 tool_choice = {
                     "type": "function",
@@ -779,6 +784,9 @@ class ChatFireworks(BaseChatModel):
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
+            kwargs:
+                Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.

         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@@ -964,17 +972,20 @@ class ChatFireworks(BaseChatModel):
            #     },
            #     'parsing_error': None
            # }
+
        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
-                raise ValueError(
+                msg = (
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
+                raise ValueError(msg)
            formatted_tool = convert_to_openai_tool(schema)
            tool_name = formatted_tool["function"]["name"]
            llm = self.bind_tools(
@@ -996,10 +1007,11 @@ class ChatFireworks(BaseChatModel):
             )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'json_schema'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_schema = convert_to_json_schema(schema)
             llm = self.bind(
                 response_format={"type": "json_object", "schema": formatted_schema},
@@ -1027,10 +1039,11 @@ class ChatFireworks(BaseChatModel):
                 else JsonOutputParser()
             )
         else:
-            raise ValueError(
+            msg = (
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
                 f"'json_mode'. Received: '{method}'"
             )
+            raise ValueError(msg)

         if include_raw:
             parser_assign = RunnablePassthrough.assign(
@@ -1041,8 +1054,7 @@ class ChatFireworks(BaseChatModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser


 def _is_pydantic_class(obj: Any) -> bool:
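
One behavior-preserving rewrite above is worth a gloss: the `try`/`except KeyError: pass` around the tool-call chunk append became `contextlib.suppress(KeyError)`, ruff's SIM105 fix. A minimal self-contained sketch of the equivalence (the payload dict is a hypothetical stand-in for a raw tool-call chunk):

    import contextlib

    rtc: dict[str, dict] = {}  # stand-in payload missing the "function" key

    # Before (SIM105): an empty except body exists only to swallow the error
    try:
        name = rtc["function"]
    except KeyError:
        pass

    # After: identical behavior, without the try/except boilerplate
    with contextlib.suppress(KeyError):
        name = rtc["function"]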

View File

@@ -4,8 +4,6 @@ from openai import OpenAI
 from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self

-# type: ignore
-

 class FireworksEmbeddings(BaseModel, Embeddings):
     """Fireworks embedding model integration.
@@ -78,7 +76,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
         ),
     )
     """Fireworks API key.

     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
     model: str = "nomic-ai/nomic-embed-text-v1.5"

View File

@@ -1,5 +1,7 @@
 """Wrapper around Fireworks AI's Completion API."""

+from __future__ import annotations
+
 import logging
 from typing import Any, Optional
@@ -49,7 +51,7 @@ class Fireworks(LLM):
         ),
     )
     """Fireworks API key.

     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
     model: str
@@ -60,14 +62,14 @@ class Fireworks(LLM):
     """Used to dynamically adjust the number of choices for each predicted token based
     on the cumulative probabilities. A value of ``1`` will always yield the same output.
     A temperature less than ``1`` favors more correctness and is appropriate for
     question answering or summarization. A value greater than ``1`` introduces more
     randomness in the output.
     """
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for ``create`` call not explicitly specified."""
     top_k: Optional[int] = None
     """Used to limit the number of choices for the next predicted word or token. It
     specifies the maximum number of tokens to consider at each step, based on their
     probability of occurrence. This technique helps to speed up the generation process
     and can improve the quality of the generated text by focusing on the most likely
     options.
@@ -79,7 +81,7 @@ class Fireworks(LLM):
     of repeated sequences. Higher values decrease repetition.
     """
     logprobs: Optional[int] = None
     """An integer that specifies how many top token log probabilities are included in
     the response for each token generation step.
     """
     timeout: Optional[int] = 30
@@ -95,8 +97,7 @@ class Fireworks(LLM):
     def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
-        values = _build_model_kwargs(values, all_required_field_names)
-        return values
+        return _build_model_kwargs(values, all_required_field_names)

     @property
     def _llm_type(self) -> str:
@@ -132,9 +133,13 @@ class Fireworks(LLM):
         Args:
             prompt: The prompt to pass into the model.
+            stop: Optional list of stop sequences to use.
+            run_manager: (Not used) Optional callback manager for LLM run.
+            kwargs: Additional parameters to pass to the model.

         Returns:
             The string generated by the model.
+
         """
         headers = {
             "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
@@ -155,19 +160,20 @@ class Fireworks(LLM):
         )

         if response.status_code >= 500:
-            raise Exception(f"Fireworks Server: Error {response.status_code}")
-        elif response.status_code >= 400:
-            raise ValueError(f"Fireworks received an invalid payload: {response.text}")
-        elif response.status_code != 200:
-            raise Exception(
+            msg = f"Fireworks Server: Error {response.status_code}"
+            raise Exception(msg)
+        if response.status_code >= 400:
+            msg = f"Fireworks received an invalid payload: {response.text}"
+            raise ValueError(msg)
+        if response.status_code != 200:
+            msg = (
                 f"Fireworks returned an unexpected response with status "
                 f"{response.status_code}: {response.text}"
             )
+            raise Exception(msg)

         data = response.json()

-        output = self._format_output(data)
-        return output
+        return self._format_output(data)

     async def _acall(
         self,
@@ -180,9 +186,13 @@ class Fireworks(LLM):
         Args:
             prompt: The prompt to pass into the model.
+            stop: Optional list of strings to stop generation when encountered.
+            run_manager: (Not used) Optional callback manager for async runs.
+            kwargs: Additional parameters to pass to the model.

         Returns:
             The string generated by the model.
+
         """
         headers = {
             "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
@@ -198,25 +208,27 @@ class Fireworks(LLM):
         # filter None values to not pass them to the http payload
         payload = {k: v for k, v in payload.items() if v is not None}

-        async with ClientSession() as session:
-            async with session.post(
+        async with (
+            ClientSession() as session,
+            session.post(
                 self.base_url,
                 json=payload,
                 headers=headers,
                 timeout=ClientTimeout(total=self.timeout),
-            ) as response:
-                if response.status >= 500:
-                    raise Exception(f"Fireworks Server: Error {response.status}")
-                elif response.status >= 400:
-                    raise ValueError(
-                        f"Fireworks received an invalid payload: {response.text}"
-                    )
-                elif response.status != 200:
-                    raise Exception(
-                        f"Fireworks returned an unexpected response with status "
-                        f"{response.status}: {response.text}"
-                    )
+            ) as response,
+        ):
+            if response.status >= 500:
+                msg = f"Fireworks Server: Error {response.status}"
+                raise Exception(msg)
+            if response.status >= 400:
+                msg = f"Fireworks received an invalid payload: {response.text}"
+                raise ValueError(msg)
+            if response.status != 200:
+                msg = (
+                    f"Fireworks returned an unexpected response with status "
+                    f"{response.status}: {response.text}"
+                )
+                raise Exception(msg)

-                response_json = await response.json()
-                output = self._format_output(response_json)
-                return output
+            response_json = await response.json()
+            return self._format_output(response_json)

View File

@@ -52,8 +52,58 @@ disallow_untyped_defs = "True"
 target-version = "py39"

 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = [ "UP007", ]
+select = [
+    "A",      # flake8-builtins
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # flake8-implicit-str-concat
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # perflint
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-raise
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",    # Missing docstring in public module
+    "D101",    # Missing docstring in public class
+    "D102",    # Missing docstring in public method
+    "D103",    # Missing docstring in public function
+    "D104",    # Missing docstring in public package
+    "D105",    # Missing docstring in magic method
+    "D107",    # Missing docstring in __init__
+    "COM812",  # Messes with the formatter
+    "ISC001",  # Messes with the formatter
+    "PERF203", # Rarely useful
+    "S112",    # Rarely useful
+    "RUF012",  # Doesn't play well with Pydantic
+    "SLF001",  # Private member access
+    "UP007",   # pyupgrade: non-pep604-annotation-union
+]

 [tool.coverage.run]
 omit = ["tests/*"]
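
With this configuration, much of the diff above can be produced by `ruff check --fix`; the `UP007` entry under `ignore` is the notable opt-out. A hedged sketch of what that opt-out leaves alone (example code is illustrative, not from this package):

    from typing import Union


    def first_id(ids: Union[list[int], tuple[int, ...]]) -> int:
        """Return the first id, or -1 if the sequence is empty."""
        # UP007 would rewrite Union[...] as `list[int] | tuple[int, ...]`,
        # but the rule stays ignored while target-version is pinned to py39.
        return ids[0] if ids else -1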

View File

@@ -1,8 +1,10 @@
-"""Test ChatFireworks API wrapper
+"""Test ChatFireworks API wrapper.

 You will need FIREWORKS_API_KEY set in your environment to run these tests.
 """

+from __future__ import annotations
+
 import json
 from typing import Annotated, Any, Literal, Optional
def test_tool_choice_bool() -> None: def test_tool_choice_bool() -> None:
"""Test that tool choice is respected just passing in True.""" """Test that tool choice is respected just passing in True."""
llm = ChatFireworks( llm = ChatFireworks(
model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0 model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
) )
@@ -59,11 +60,12 @@ async def test_astream() -> None:
         if token.response_metadata:
             chunks_with_response_metadata += 1
     if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
-        raise AssertionError(
+        msg = (
             "Expected exactly one chunk with token counts or response_metadata. "
             "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
             "this is behaving properly."
         )
+        raise AssertionError(msg)
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
     assert full.usage_metadata["input_tokens"] > 0
@@ -99,7 +101,7 @@ def test_invoke() -> None:
     """Test invoke tokens from ChatFireworks."""
     llm = ChatFireworks(model=_MODEL)

-    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
@@ -122,18 +124,18 @@ def _get_joke_class(
         punchline: Annotated[str, ..., "answer to resolve the joke"]

     def validate_joke_dict(result: Any) -> bool:
-        return all(key in ["setup", "punchline"] for key in result.keys())
+        return all(key in ["setup", "punchline"] for key in result)

     if schema_type == "pydantic":
         return Joke, validate_joke
-    elif schema_type == "typeddict":
+    if schema_type == "typeddict":
         return JokeDict, validate_joke_dict
-    elif schema_type == "json_schema":
+    if schema_type == "json_schema":
         return Joke.model_json_schema(), validate_joke_dict
-    else:
-        raise ValueError("Invalid schema type")
+    msg = "Invalid schema type"
+    raise ValueError(msg)


 @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])

View File

@@ -4,4 +4,3 @@ import pytest

 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass

View File

@@ -100,5 +100,5 @@ def test_invoke() -> None:
     """Test invoke tokens from Fireworks."""
     llm = Fireworks(model=_MODEL)

-    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result, str)

View File

@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 import pytest
 from langchain_core.language_models import BaseChatModel

View File

@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 from langchain_core.embeddings import Embeddings
 from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests

View File

@@ -1,4 +1,4 @@
-"""Test Fireworks LLM"""
+"""Test Fireworks LLM."""

 from typing import cast

View File

@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 from langchain_core.language_models import BaseChatModel
 from langchain_tests.unit_tests import (  # type: ignore[import-not-found]