fireworks[patch]: ruff fixes and rules (#31903)

* bump ruff deps
* add more thorough ruff rules
* fix said rules
This commit is contained in:
Mason Daugherty
2025-07-07 22:14:59 -04:00
committed by GitHub
parent 63e3f2dea6
commit 06ab2972e3
12 changed files with 164 additions and 91 deletions

View File

@@ -1,5 +1,7 @@
"""Wrapper around Fireworks AI's Completion API."""
from __future__ import annotations
import logging
from typing import Any, Optional
@@ -49,7 +51,7 @@ class Fireworks(LLM):
),
)
"""Fireworks API key.
Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
"""
model: str
@@ -60,14 +62,14 @@ class Fireworks(LLM):
"""Used to dynamically adjust the number of choices for each predicted token based
on the cumulative probabilities. A value of ``1`` will always yield the same output.
A temperature less than ``1`` favors more correctness and is appropriate for
question answering or summarization. A value greater than ``1`` introduces more
question answering or summarization. A value greater than ``1`` introduces more
randomness in the output.
"""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for ``create`` call not explicitly specified."""
top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It
specifies the maximum number of tokens to consider at each step, based on their
"""Used to limit the number of choices for the next predicted word or token. It
specifies the maximum number of tokens to consider at each step, based on their
probability of occurrence. This technique helps to speed up the generation process
and can improve the quality of the generated text by focusing on the most likely
options.
@@ -79,7 +81,7 @@ class Fireworks(LLM):
of repeated sequences. Higher values decrease repetition.
"""
logprobs: Optional[int] = None
"""An integer that specifies how many top token log probabilities are included in
"""An integer that specifies how many top token log probabilities are included in
the response for each token generation step.
"""
timeout: Optional[int] = 30
@@ -95,8 +97,7 @@ class Fireworks(LLM):
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
return _build_model_kwargs(values, all_required_field_names)
@property
def _llm_type(self) -> str:
@@ -132,9 +133,13 @@ class Fireworks(LLM):
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop sequences to use.
run_manager: (Not used) Optional callback manager for LLM run.
kwargs: Additional parameters to pass to the model.
Returns:
The string generated by the model.
"""
headers = {
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
@@ -155,19 +160,20 @@ class Fireworks(LLM):
)
if response.status_code >= 500:
raise Exception(f"Fireworks Server: Error {response.status_code}")
elif response.status_code >= 400:
raise ValueError(f"Fireworks received an invalid payload: {response.text}")
elif response.status_code != 200:
raise Exception(
msg = f"Fireworks Server: Error {response.status_code}"
raise Exception(msg)
if response.status_code >= 400:
msg = f"Fireworks received an invalid payload: {response.text}"
raise ValueError(msg)
if response.status_code != 200:
msg = (
f"Fireworks returned an unexpected response with status "
f"{response.status_code}: {response.text}"
)
raise Exception(msg)
data = response.json()
output = self._format_output(data)
return output
return self._format_output(data)
async def _acall(
self,
@@ -180,9 +186,13 @@ class Fireworks(LLM):
Args:
prompt: The prompt to pass into the model.
stop: Optional list of strings to stop generation when encountered.
run_manager: (Not used) Optional callback manager for async runs.
kwargs: Additional parameters to pass to the model.
Returns:
The string generated by the model.
"""
headers = {
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
@@ -198,25 +208,27 @@ class Fireworks(LLM):
# filter None values to not pass them to the http payload
payload = {k: v for k, v in payload.items() if v is not None}
async with ClientSession() as session:
async with session.post(
async with (
ClientSession() as session,
session.post(
self.base_url,
json=payload,
headers=headers,
timeout=ClientTimeout(total=self.timeout),
) as response:
if response.status >= 500:
raise Exception(f"Fireworks Server: Error {response.status}")
elif response.status >= 400:
raise ValueError(
f"Fireworks received an invalid payload: {response.text}"
)
elif response.status != 200:
raise Exception(
f"Fireworks returned an unexpected response with status "
f"{response.status}: {response.text}"
)
) as response,
):
if response.status >= 500:
msg = f"Fireworks Server: Error {response.status}"
raise Exception(msg)
if response.status >= 400:
msg = f"Fireworks received an invalid payload: {response.text}"
raise ValueError(msg)
if response.status != 200:
msg = (
f"Fireworks returned an unexpected response with status "
f"{response.status}: {response.text}"
)
raise Exception(msg)
response_json = await response.json()
output = self._format_output(response_json)
return output
response_json = await response.json()
return self._format_output(response_json)