mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 06:53:59 +00:00
fireworks[patch]: ruff fixes and rules (#31903)
* bump ruff deps * add more thorough ruff rules * fix said rules
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Wrapper around Fireworks AI's Completion API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -49,7 +51,7 @@ class Fireworks(LLM):
|
||||
),
|
||||
)
|
||||
"""Fireworks API key.
|
||||
|
||||
|
||||
Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
|
||||
"""
|
||||
model: str
|
||||
@@ -60,14 +62,14 @@ class Fireworks(LLM):
|
||||
"""Used to dynamically adjust the number of choices for each predicted token based
|
||||
on the cumulative probabilities. A value of ``1`` will always yield the same output.
|
||||
A temperature less than ``1`` favors more correctness and is appropriate for
|
||||
question answering or summarization. A value greater than ``1`` introduces more
|
||||
question answering or summarization. A value greater than ``1`` introduces more
|
||||
randomness in the output.
|
||||
"""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for ``create`` call not explicitly specified."""
|
||||
top_k: Optional[int] = None
|
||||
"""Used to limit the number of choices for the next predicted word or token. It
|
||||
specifies the maximum number of tokens to consider at each step, based on their
|
||||
"""Used to limit the number of choices for the next predicted word or token. It
|
||||
specifies the maximum number of tokens to consider at each step, based on their
|
||||
probability of occurrence. This technique helps to speed up the generation process
|
||||
and can improve the quality of the generated text by focusing on the most likely
|
||||
options.
|
||||
@@ -79,7 +81,7 @@ class Fireworks(LLM):
|
||||
of repeated sequences. Higher values decrease repetition.
|
||||
"""
|
||||
logprobs: Optional[int] = None
|
||||
"""An integer that specifies how many top token log probabilities are included in
|
||||
"""An integer that specifies how many top token log probabilities are included in
|
||||
the response for each token generation step.
|
||||
"""
|
||||
timeout: Optional[int] = 30
|
||||
@@ -95,8 +97,7 @@ class Fireworks(LLM):
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
return _build_model_kwargs(values, all_required_field_names)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@@ -132,9 +133,13 @@ class Fireworks(LLM):
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop sequences to use.
|
||||
run_manager: (Not used) Optional callback manager for LLM run.
|
||||
kwargs: Additional parameters to pass to the model.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
|
||||
@@ -155,19 +160,20 @@ class Fireworks(LLM):
|
||||
)
|
||||
|
||||
if response.status_code >= 500:
|
||||
raise Exception(f"Fireworks Server: Error {response.status_code}")
|
||||
elif response.status_code >= 400:
|
||||
raise ValueError(f"Fireworks received an invalid payload: {response.text}")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
msg = f"Fireworks Server: Error {response.status_code}"
|
||||
raise Exception(msg)
|
||||
if response.status_code >= 400:
|
||||
msg = f"Fireworks received an invalid payload: {response.text}"
|
||||
raise ValueError(msg)
|
||||
if response.status_code != 200:
|
||||
msg = (
|
||||
f"Fireworks returned an unexpected response with status "
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
raise Exception(msg)
|
||||
|
||||
data = response.json()
|
||||
output = self._format_output(data)
|
||||
|
||||
return output
|
||||
return self._format_output(data)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
@@ -180,9 +186,13 @@ class Fireworks(LLM):
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of strings to stop generation when encountered.
|
||||
run_manager: (Not used) Optional callback manager for async runs.
|
||||
kwargs: Additional parameters to pass to the model.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
|
||||
@@ -198,25 +208,27 @@ class Fireworks(LLM):
|
||||
|
||||
# filter None values to not pass them to the http payload
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
async with ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
ClientSession() as session,
|
||||
session.post(
|
||||
self.base_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=ClientTimeout(total=self.timeout),
|
||||
) as response:
|
||||
if response.status >= 500:
|
||||
raise Exception(f"Fireworks Server: Error {response.status}")
|
||||
elif response.status >= 400:
|
||||
raise ValueError(
|
||||
f"Fireworks received an invalid payload: {response.text}"
|
||||
)
|
||||
elif response.status != 200:
|
||||
raise Exception(
|
||||
f"Fireworks returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
) as response,
|
||||
):
|
||||
if response.status >= 500:
|
||||
msg = f"Fireworks Server: Error {response.status}"
|
||||
raise Exception(msg)
|
||||
if response.status >= 400:
|
||||
msg = f"Fireworks received an invalid payload: {response.text}"
|
||||
raise ValueError(msg)
|
||||
if response.status != 200:
|
||||
msg = (
|
||||
f"Fireworks returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
raise Exception(msg)
|
||||
|
||||
response_json = await response.json()
|
||||
output = self._format_output(response_json)
|
||||
return output
|
||||
response_json = await response.json()
|
||||
return self._format_output(response_json)
|
||||
|
Reference in New Issue
Block a user