Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-13 08:27:03 +00:00
fireworks[patch]: ruff fixes and rules (#31903)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
This commit is contained in:
parent 63e3f2dea6
commit 06ab2972e3
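Most of the mechanical churn below comes from ruff's flake8-errmsg (EM) rules, which the new config enables: the exception message is bound to a variable before the raise instead of being written inline. A minimal sketch of the pattern, using a hypothetical check_n helper (the real call sites appear in the hunks below):

    def check_n(n: int) -> None:
        """Validate n the way the EM rules prefer."""
        if n < 1:
            # Before: raise ValueError("n must be at least 1.")  -- flagged as EM101
            msg = "n must be at least 1."
            raise ValueError(msg)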
@@ -4,8 +4,8 @@ from langchain_fireworks.llms import Fireworks
 from langchain_fireworks.version import __version__
 
 __all__ = [
+    "__version__",
     "ChatFireworks",
     "Fireworks",
     "FireworksEmbeddings",
-    "__version__",
 ]
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import json
 import logging
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
@@ -16,7 +17,7 @@ from typing import (
     cast,
 )
 
-from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
+from fireworks.client import AsyncFireworks, Fireworks  # type: ignore[import-untyped]
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
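The hunk above narrows a blanket ignore into a coded one, which is what ruff's PGH003 enforces: a bare # type: ignore silences every checker error on the line, while a coded ignore suppresses only the named error. A one-line sketch with a hypothetical variable:

    value: int = "not an int"  # type: ignore[assignment]  # only the assignment error is hidden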
@@ -94,11 +95,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
 
     Returns:
         The LangChain message.
+
     """
     role = _dict.get("role")
     if role == "user":
         return HumanMessage(content=_dict.get("content", ""))
-    elif role == "assistant":
+    if role == "assistant":
         # Fix for azure
         # Also Fireworks returns None for tool invocations
         content = _dict.get("content", "") or ""
@@ -122,13 +124,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    if role == "system":
         return SystemMessage(content=_dict.get("content", ""))
-    elif role == "function":
+    if role == "function":
         return FunctionMessage(
             content=_dict.get("content", ""), name=_dict.get("name", "")
         )
-    elif role == "tool":
+    if role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]
@@ -137,8 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_call_id=_dict.get("tool_call_id", ""),
             additional_kwargs=additional_kwargs,
         )
-    else:
-        return ChatMessage(content=_dict.get("content", ""), role=role or "")
+    return ChatMessage(content=_dict.get("content", ""), role=role or "")
 
 
 def _convert_message_to_dict(message: BaseMessage) -> dict:
@@ -149,6 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 
     Returns:
         The dictionary.
+
     """
     message_dict: dict[str, Any]
     if isinstance(message, ChatMessage):
@@ -191,7 +193,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             "tool_call_id": message.tool_call_id,
         }
     else:
-        raise TypeError(f"Got unknown type {message}")
+        msg = f"Got unknown type {message}"
+        raise TypeError(msg)
     if "name" in message.additional_kwargs:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
@@ -214,7 +217,7 @@ def _convert_chunk_to_message_chunk(
     if raw_tool_calls := _dict.get("tool_calls"):
         additional_kwargs["tool_calls"] = raw_tool_calls
         for rtc in raw_tool_calls:
-            try:
+            with contextlib.suppress(KeyError):
                 tool_call_chunks.append(
                     create_tool_call_chunk(
                         name=rtc["function"].get("name"),
@@ -223,11 +226,9 @@ def _convert_chunk_to_message_chunk(
                         index=rtc.get("index"),
                     )
                 )
-            except KeyError:
-                pass
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         if usage := chunk.get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
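Taken together, these two hunks swap a try/except KeyError: pass around the tool-call chunk for contextlib.suppress, ruff's SIM105 simplification; that is also why import contextlib was added near the top of the file. A self-contained sketch with a hypothetical rtc payload:

    import contextlib

    args: list[str] = []
    rtc = {"function": {"name": "get_weather"}}  # no "arguments" key present
    with contextlib.suppress(KeyError):
        # One statement instead of try: ... except KeyError: pass
        args.append(rtc["function"]["arguments"])
    assert args == []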
@@ -244,16 +245,15 @@ def _convert_chunk_to_message_chunk(
             tool_call_chunks=tool_call_chunks,
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
+    if role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
+    if role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore
+    return default_class(content=content)  # type: ignore[call-arg]
 
 
 class _FunctionCall(TypedDict):
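The repeated elif-to-if rewrites in this function follow ruff's flake8-return rules (RET505): when every branch ends in a return, the elif/else ladder flattens to successive ifs and the final else becomes a plain tail return. A toy sketch (hypothetical describe function, not from the diff):

    def describe(role: str) -> str:
        if role == "user":
            return "human"
        if role == "assistant":
            return "ai"
        return "other"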
@@ -280,6 +280,7 @@ class ChatFireworks(BaseChatModel):
         from langchain_fireworks.chat_models import ChatFireworks
 
         fireworks = ChatFireworks(
             model_name="accounts/fireworks/models/llama-v3p1-8b-instruct")
+
     """
 
     @property
@@ -326,14 +327,14 @@ class ChatFireworks(BaseChatModel):
         ),
     )
     """Fireworks API key.
 
     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
 
     fireworks_api_base: Optional[str] = Field(
         alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None)
     )
-    """Base URL path for API requests, leave blank if not using a proxy or service
+    """Base URL path for API requests, leave blank if not using a proxy or service
     emulator."""
     request_timeout: Union[float, tuple[float, float], Any, None] = Field(
         default=None, alias="timeout"
@@ -358,16 +359,17 @@ class ChatFireworks(BaseChatModel):
     def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
-        values = _build_model_kwargs(values, all_required_field_names)
-        return values
+        return _build_model_kwargs(values, all_required_field_names)
 
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
         if self.n < 1:
-            raise ValueError("n must be at least 1.")
+            msg = "n must be at least 1."
+            raise ValueError(msg)
         if self.n > 1 and self.streaming:
-            raise ValueError("n must be 1 when streaming.")
+            msg = "n must be 1 when streaming."
+            raise ValueError(msg)
 
         client_params = {
             "api_key": (
@@ -522,7 +524,7 @@ class ChatFireworks(BaseChatModel):
             "output_tokens": token_usage.get("completion_tokens", 0),
             "total_tokens": token_usage.get("total_tokens", 0),
         }
-        generation_info = dict(finish_reason=res.get("finish_reason"))
+        generation_info = {"finish_reason": res.get("finish_reason")}
         if "logprobs" in res:
             generation_info["logprobs"] = res["logprobs"]
         gen = ChatGeneration(
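The dict(...) call above becomes a literal per ruff's flake8-comprehensions rule C408; the literal form avoids a name lookup and a call. Sketch with a hypothetical finish_reason value:

    finish_reason = "stop"
    generation_info = {"finish_reason": finish_reason}  # rather than dict(finish_reason=finish_reason)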
@@ -628,7 +630,7 @@ class ChatFireworks(BaseChatModel):
         self,
         functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         function_call: Optional[
-            Union[_FunctionCall, str, Literal["auto", "none"]]
+            Union[_FunctionCall, str, Literal["auto", "none"]]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -651,8 +653,8 @@ class ChatFireworks(BaseChatModel):
                 (if any).
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
-        """
+
+        """
         formatted_functions = [convert_to_openai_function(fn) for fn in functions]
         if function_call is not None:
             function_call = (
@@ -662,18 +664,20 @@ class ChatFireworks(BaseChatModel):
                 else function_call
             )
             if isinstance(function_call, dict) and len(formatted_functions) != 1:
-                raise ValueError(
+                msg = (
                     "When specifying `function_call`, you must provide exactly one "
                     "function."
                 )
+                raise ValueError(msg)
             if (
                 isinstance(function_call, dict)
                 and formatted_functions[0]["name"] != function_call["name"]
             ):
-                raise ValueError(
+                msg = (
                     f"Function call {function_call} was specified, but the only "
                     f"provided function was {formatted_functions[0]['name']}."
                 )
+                raise ValueError(msg)
             kwargs = {**kwargs, "function_call": function_call}
         return super().bind(
             functions=formatted_functions,
@@ -685,7 +689,7 @@ class ChatFireworks(BaseChatModel):
         tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "any", "none"], bool]
+            Union[dict, str, Literal["auto", "any", "none"], bool]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
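The two # noqa: PYI051 annotations above silence ruff's redundant-literal-union check: a Literal[...] alongside str in a Union is technically subsumed by str, but here the literals document the accepted shorthand values, so the union is kept and the rule is suppressed per line. Sketch with a hypothetical alias:

    from typing import Literal, Union

    # PYI051 would flag this: Literal["auto", ...] is subsumed by str.
    ToolChoice = Union[dict, str, Literal["auto", "any", "none"], bool]  # noqa: PYI051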
@@ -705,8 +709,8 @@ class ChatFireworks(BaseChatModel):
                 ``{"type": "function", "function": {"name": <<tool_name>>}}``.
             **kwargs: Any additional parameters to pass to
                 :meth:`~langchain_fireworks.chat_models.ChatFireworks.bind`
-        """
+
+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None and tool_choice:
             if isinstance(tool_choice, str) and (
@@ -715,10 +719,11 @@ class ChatFireworks(BaseChatModel):
                 tool_choice = {"type": "function", "function": {"name": tool_choice}}
             if isinstance(tool_choice, bool):
                 if len(tools) > 1:
-                    raise ValueError(
+                    msg = (
                         "tool_choice can only be True when there is one tool. Received "
                         f"{len(tools)} tools."
                     )
+                    raise ValueError(msg)
                 tool_name = formatted_tools[0]["function"]["name"]
                 tool_choice = {
                     "type": "function",
@@ -779,6 +784,9 @@ class ChatFireworks(BaseChatModel):
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
 
+            kwargs:
+                Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor.
+
         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
 
@@ -964,17 +972,20 @@ class ChatFireworks(BaseChatModel):
         #     },
         #     'parsing_error': None
         # }
+
         """  # noqa: E501
         _ = kwargs.pop("strict", None)
         if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
         is_pydantic_schema = _is_pydantic_class(schema)
         if method == "function_calling":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_tool = convert_to_openai_tool(schema)
             tool_name = formatted_tool["function"]["name"]
             llm = self.bind_tools(
@@ -996,10 +1007,11 @@ class ChatFireworks(BaseChatModel):
             )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'json_schema'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_schema = convert_to_json_schema(schema)
             llm = self.bind(
                 response_format={"type": "json_object", "schema": formatted_schema},
@@ -1027,10 +1039,11 @@ class ChatFireworks(BaseChatModel):
                 else JsonOutputParser()
             )
         else:
-            raise ValueError(
+            msg = (
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
                 f"'json_mode'. Received: '{method}'"
             )
+            raise ValueError(msg)
 
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
@@ -1041,8 +1054,7 @@ class ChatFireworks(BaseChatModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser
 
 
 def _is_pydantic_class(obj: Any) -> bool:
@@ -4,8 +4,6 @@ from openai import OpenAI
 from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self
 
-# type: ignore
-
 
 class FireworksEmbeddings(BaseModel, Embeddings):
     """Fireworks embedding model integration.
@@ -78,7 +76,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
         ),
     )
     """Fireworks API key.
-
+
     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
     model: str = "nomic-ai/nomic-embed-text-v1.5"
@@ -1,5 +1,7 @@
 """Wrapper around Fireworks AI's Completion API."""
 
+from __future__ import annotations
+
 import logging
 from typing import Any, Optional
 
@@ -49,7 +51,7 @@ class Fireworks(LLM):
         ),
     )
     """Fireworks API key.
-
+
     Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided.
     """
     model: str
@@ -60,14 +62,14 @@ class Fireworks(LLM):
     """Used to dynamically adjust the number of choices for each predicted token based
     on the cumulative probabilities. A value of ``1`` will always yield the same output.
     A temperature less than ``1`` favors more correctness and is appropriate for
-    question answering or summarization. A value greater than ``1`` introduces more
+    question answering or summarization. A value greater than ``1`` introduces more
     randomness in the output.
     """
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for ``create`` call not explicitly specified."""
     top_k: Optional[int] = None
-    """Used to limit the number of choices for the next predicted word or token. It
-    specifies the maximum number of tokens to consider at each step, based on their
+    """Used to limit the number of choices for the next predicted word or token. It
+    specifies the maximum number of tokens to consider at each step, based on their
     probability of occurrence. This technique helps to speed up the generation process
     and can improve the quality of the generated text by focusing on the most likely
     options.
@@ -79,7 +81,7 @@ class Fireworks(LLM):
     of repeated sequences. Higher values decrease repetition.
     """
     logprobs: Optional[int] = None
-    """An integer that specifies how many top token log probabilities are included in
+    """An integer that specifies how many top token log probabilities are included in
     the response for each token generation step.
     """
     timeout: Optional[int] = 30
@@ -95,8 +97,7 @@ class Fireworks(LLM):
     def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
-        values = _build_model_kwargs(values, all_required_field_names)
-        return values
+        return _build_model_kwargs(values, all_required_field_names)
 
     @property
     def _llm_type(self) -> str:
@@ -132,9 +133,13 @@ class Fireworks(LLM):
 
         Args:
             prompt: The prompt to pass into the model.
+            stop: Optional list of stop sequences to use.
+            run_manager: (Not used) Optional callback manager for LLM run.
+            kwargs: Additional parameters to pass to the model.
 
         Returns:
             The string generated by the model.
+
         """
         headers = {
             "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
@@ -155,19 +160,20 @@ class Fireworks(LLM):
         )
 
         if response.status_code >= 500:
-            raise Exception(f"Fireworks Server: Error {response.status_code}")
-        elif response.status_code >= 400:
-            raise ValueError(f"Fireworks received an invalid payload: {response.text}")
-        elif response.status_code != 200:
-            raise Exception(
+            msg = f"Fireworks Server: Error {response.status_code}"
+            raise Exception(msg)
+        if response.status_code >= 400:
+            msg = f"Fireworks received an invalid payload: {response.text}"
+            raise ValueError(msg)
+        if response.status_code != 200:
+            msg = (
                 f"Fireworks returned an unexpected response with status "
                 f"{response.status_code}: {response.text}"
             )
+            raise Exception(msg)
 
         data = response.json()
-        output = self._format_output(data)
-
-        return output
+        return self._format_output(data)
@@ -180,9 +186,13 @@ class Fireworks(LLM):
 
         Args:
             prompt: The prompt to pass into the model.
+            stop: Optional list of strings to stop generation when encountered.
+            run_manager: (Not used) Optional callback manager for async runs.
+            kwargs: Additional parameters to pass to the model.
 
         Returns:
             The string generated by the model.
+
         """
         headers = {
             "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
@@ -198,25 +208,27 @@ class Fireworks(LLM):
 
         # filter None values to not pass them to the http payload
        payload = {k: v for k, v in payload.items() if v is not None}
-        async with ClientSession() as session:
-            async with session.post(
+        async with (
+            ClientSession() as session,
+            session.post(
                 self.base_url,
                 json=payload,
                 headers=headers,
                 timeout=ClientTimeout(total=self.timeout),
-            ) as response:
-                if response.status >= 500:
-                    raise Exception(f"Fireworks Server: Error {response.status}")
-                elif response.status >= 400:
-                    raise ValueError(
-                        f"Fireworks received an invalid payload: {response.text}"
-                    )
-                elif response.status != 200:
-                    raise Exception(
-                        f"Fireworks returned an unexpected response with status "
-                        f"{response.status}: {response.text}"
-                    )
+            ) as response,
+        ):
+            if response.status >= 500:
+                msg = f"Fireworks Server: Error {response.status}"
+                raise Exception(msg)
+            if response.status >= 400:
+                msg = f"Fireworks received an invalid payload: {response.text}"
+                raise ValueError(msg)
+            if response.status != 200:
+                msg = (
+                    f"Fireworks returned an unexpected response with status "
+                    f"{response.status}: {response.text}"
+                )
+                raise Exception(msg)
 
-                response_json = await response.json()
-                output = self._format_output(response_json)
-                return output
+            response_json = await response.json()
+            return self._format_output(response_json)
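The nested async with blocks collapse into a single parenthesized statement, the form ruff's SIM117 prefers; parenthesized context managers are officially part of the grammar from Python 3.10, though CPython 3.9's parser already accepts them. A runnable sketch with a hypothetical _Session manager:

    import asyncio

    class _Session:
        async def __aenter__(self) -> "_Session":
            return self

        async def __aexit__(self, *exc: object) -> None:
            return None

    async def main() -> None:
        async with (
            _Session() as outer,
            _Session() as inner,
        ):
            assert outer is not inner

    asyncio.run(main())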
@@ -52,8 +52,58 @@ disallow_untyped_defs = "True"
 target-version = "py39"
 
 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = [ "UP007", ]
+select = [
+    "A",      # flake8-builtins
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flake8-flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # isort-comprehensions
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # flake8-perf
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-rst-docstrings
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",    # Missing docstring in public module
+    "D101",    # Missing docstring in public class
+    "D102",    # Missing docstring in public method
+    "D103",    # Missing docstring in public function
+    "D104",    # Missing docstring in public package
+    "D105",    # Missing docstring in magic method
+    "D107",    # Missing docstring in __init__
+    "COM812",  # Messes with the formatter
+    "ISC001",  # Messes with the formatter
+    "PERF203", # Rarely useful
+    "S112",    # Rarely useful
+    "RUF012",  # Doesn't play well with Pydantic
+    "SLF001",  # Private member access
+    "UP007",   # pyupgrade: non-pep604-annotation-union
+]
 
 [tool.coverage.run]
 omit = ["tests/*"]
@@ -1,8 +1,10 @@
-"""Test ChatFireworks API wrapper
+"""Test ChatFireworks API wrapper.
 
 You will need FIREWORKS_API_KEY set in your environment to run these tests.
 """
 
+from __future__ import annotations
+
 import json
 from typing import Annotated, Any, Literal, Optional
 
@@ -18,7 +20,6 @@ _MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"
 
 def test_tool_choice_bool() -> None:
     """Test that tool choice is respected just passing in True."""
-
     llm = ChatFireworks(
         model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0
     )
@@ -59,11 +60,12 @@ async def test_astream() -> None:
         if token.response_metadata:
             chunks_with_response_metadata += 1
     if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
-        raise AssertionError(
+        msg = (
             "Expected exactly one chunk with token counts or response_metadata. "
             "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
             "this is behaving properly."
         )
+        raise AssertionError(msg)
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
     assert full.usage_metadata["input_tokens"] > 0
@@ -99,7 +101,7 @@ def test_invoke() -> None:
     """Test invoke tokens from ChatFireworks."""
     llm = ChatFireworks(model=_MODEL)
 
-    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
@@ -122,18 +124,18 @@ def _get_joke_class(
         punchline: Annotated[str, ..., "answer to resolve the joke"]
 
     def validate_joke_dict(result: Any) -> bool:
-        return all(key in ["setup", "punchline"] for key in result.keys())
+        return all(key in ["setup", "punchline"] for key in result)
 
     if schema_type == "pydantic":
         return Joke, validate_joke
 
-    elif schema_type == "typeddict":
+    if schema_type == "typeddict":
         return JokeDict, validate_joke_dict
 
-    elif schema_type == "json_schema":
+    if schema_type == "json_schema":
         return Joke.model_json_schema(), validate_joke_dict
-    else:
-        raise ValueError("Invalid schema type")
+    msg = "Invalid schema type"
+    raise ValueError(msg)
 
 
 @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
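The .keys() removal above is ruff's SIM118: iterating a mapping already yields its keys, so the extra method call adds nothing. Tiny sketch:

    result = {"setup": "a joke", "punchline": "the answer"}
    assert all(key in ["setup", "punchline"] for key in result)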
@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
@@ -100,5 +100,5 @@ def test_invoke() -> None:
     """Test invoke tokens from Fireworks."""
     llm = Fireworks(model=_MODEL)
 
-    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result, str)
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""
 
 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""
 
 from langchain_core.embeddings import Embeddings
 from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests
@@ -1,4 +1,4 @@
-"""Test Fireworks LLM"""
+"""Test Fireworks LLM."""
 
 from typing import cast
 
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""
 
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.unit_tests import (  # type: ignore[import-not-found]