From 06ab2972e3a854e094724c7fd630a6ff913f76f7 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 7 Jul 2025 22:14:59 -0400 Subject: [PATCH] fireworks[patch]: ruff fixes and rules (#31903) * bump ruff deps * add more thorough ruff rules * fix said rules --- .../fireworks/langchain_fireworks/__init__.py | 2 +- .../langchain_fireworks/chat_models.py | 88 +++++++++++-------- .../langchain_fireworks/embeddings.py | 4 +- .../fireworks/langchain_fireworks/llms.py | 76 +++++++++------- libs/partners/fireworks/pyproject.toml | 54 +++++++++++- .../integration_tests/test_chat_models.py | 20 +++-- .../tests/integration_tests/test_compile.py | 1 - .../tests/integration_tests/test_llms.py | 2 +- .../tests/integration_tests/test_standard.py | 2 +- .../unit_tests/test_embeddings_standard.py | 2 +- .../fireworks/tests/unit_tests/test_llms.py | 2 +- .../tests/unit_tests/test_standard.py | 2 +- 12 files changed, 164 insertions(+), 91 deletions(-) diff --git a/libs/partners/fireworks/langchain_fireworks/__init__.py b/libs/partners/fireworks/langchain_fireworks/__init__.py index 2cc7be5fbd8..02e1a5b4fb6 100644 --- a/libs/partners/fireworks/langchain_fireworks/__init__.py +++ b/libs/partners/fireworks/langchain_fireworks/__init__.py @@ -4,8 +4,8 @@ from langchain_fireworks.llms import Fireworks from langchain_fireworks.version import __version__ __all__ = [ - "__version__", "ChatFireworks", "Fireworks", "FireworksEmbeddings", + "__version__", ] diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index d1dd62032ba..986d89818f4 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import json import logging from collections.abc import AsyncIterator, Iterator, Mapping, Sequence @@ -16,7 +17,7 @@ from typing import ( cast, ) -from fireworks.client import AsyncFireworks, Fireworks # type: ignore +from fireworks.client import AsyncFireworks, Fireworks # type: ignore[import-untyped] from langchain_core._api import deprecated from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -94,11 +95,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: Returns: The LangChain message. + """ role = _dict.get("role") if role == "user": return HumanMessage(content=_dict.get("content", "")) - elif role == "assistant": + if role == "assistant": # Fix for azure # Also Fireworks returns None for tool invocations content = _dict.get("content", "") or "" @@ -122,13 +124,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: tool_calls=tool_calls, invalid_tool_calls=invalid_tool_calls, ) - elif role == "system": + if role == "system": return SystemMessage(content=_dict.get("content", "")) - elif role == "function": + if role == "function": return FunctionMessage( content=_dict.get("content", ""), name=_dict.get("name", "") ) - elif role == "tool": + if role == "tool": additional_kwargs = {} if "name" in _dict: additional_kwargs["name"] = _dict["name"] @@ -137,8 +139,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: tool_call_id=_dict.get("tool_call_id", ""), additional_kwargs=additional_kwargs, ) - else: - return ChatMessage(content=_dict.get("content", ""), role=role or "") + return ChatMessage(content=_dict.get("content", ""), role=role or "") def _convert_message_to_dict(message: BaseMessage) -> dict: @@ -149,6 +150,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: Returns: The dictionary. + """ message_dict: dict[str, Any] if isinstance(message, ChatMessage): @@ -191,7 +193,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: "tool_call_id": message.tool_call_id, } else: - raise TypeError(f"Got unknown type {message}") + msg = f"Got unknown type {message}" + raise TypeError(msg) if "name" in message.additional_kwargs: message_dict["name"] = message.additional_kwargs["name"] return message_dict @@ -214,7 +217,7 @@ def _convert_chunk_to_message_chunk( if raw_tool_calls := _dict.get("tool_calls"): additional_kwargs["tool_calls"] = raw_tool_calls for rtc in raw_tool_calls: - try: + with contextlib.suppress(KeyError): tool_call_chunks.append( create_tool_call_chunk( name=rtc["function"].get("name"), @@ -223,11 +226,9 @@ def _convert_chunk_to_message_chunk( index=rtc.get("index"), ) ) - except KeyError: - pass if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: + if role == "assistant" or default_class == AIMessageChunk: if usage := chunk.get("usage"): input_tokens = usage.get("prompt_tokens", 0) output_tokens = usage.get("completion_tokens", 0) @@ -244,16 +245,15 @@ def _convert_chunk_to_message_chunk( tool_call_chunks=tool_call_chunks, usage_metadata=usage_metadata, # type: ignore[arg-type] ) - elif role == "system" or default_class == SystemMessageChunk: + if role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: + if role == "function" or default_class == FunctionMessageChunk: return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role == "tool" or default_class == ToolMessageChunk: + if role == "tool" or default_class == ToolMessageChunk: return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) - elif role or default_class == ChatMessageChunk: + if role or default_class == ChatMessageChunk: return ChatMessageChunk(content=content, role=role) - else: - return default_class(content=content) # type: ignore + return default_class(content=content) # type: ignore[call-arg] class _FunctionCall(TypedDict): @@ -280,6 +280,7 @@ class ChatFireworks(BaseChatModel): from langchain_fireworks.chat_models import ChatFireworks fireworks = ChatFireworks( model_name="accounts/fireworks/models/llama-v3p1-8b-instruct") + """ @property @@ -326,14 +327,14 @@ class ChatFireworks(BaseChatModel): ), ) """Fireworks API key. - + Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided. """ fireworks_api_base: Optional[str] = Field( alias="base_url", default_factory=from_env("FIREWORKS_API_BASE", default=None) ) - """Base URL path for API requests, leave blank if not using a proxy or service + """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" request_timeout: Union[float, tuple[float, float], Any, None] = Field( default=None, alias="timeout" @@ -358,16 +359,17 @@ class ChatFireworks(BaseChatModel): def build_extra(cls, values: dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) - values = _build_model_kwargs(values, all_required_field_names) - return values + return _build_model_kwargs(values, all_required_field_names) @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" if self.n < 1: - raise ValueError("n must be at least 1.") + msg = "n must be at least 1." + raise ValueError(msg) if self.n > 1 and self.streaming: - raise ValueError("n must be 1 when streaming.") + msg = "n must be 1 when streaming." + raise ValueError(msg) client_params = { "api_key": ( @@ -522,7 +524,7 @@ class ChatFireworks(BaseChatModel): "output_tokens": token_usage.get("completion_tokens", 0), "total_tokens": token_usage.get("total_tokens", 0), } - generation_info = dict(finish_reason=res.get("finish_reason")) + generation_info = {"finish_reason": res.get("finish_reason")} if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] gen = ChatGeneration( @@ -628,7 +630,7 @@ class ChatFireworks(BaseChatModel): self, functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], function_call: Optional[ - Union[_FunctionCall, str, Literal["auto", "none"]] + Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051 ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: @@ -651,8 +653,8 @@ class ChatFireworks(BaseChatModel): (if any). **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. - """ + """ formatted_functions = [convert_to_openai_function(fn) for fn in functions] if function_call is not None: function_call = ( @@ -662,18 +664,20 @@ class ChatFireworks(BaseChatModel): else function_call ) if isinstance(function_call, dict) and len(formatted_functions) != 1: - raise ValueError( + msg = ( "When specifying `function_call`, you must provide exactly one " "function." ) + raise ValueError(msg) if ( isinstance(function_call, dict) and formatted_functions[0]["name"] != function_call["name"] ): - raise ValueError( + msg = ( f"Function call {function_call} was specified, but the only " f"provided function was {formatted_functions[0]['name']}." ) + raise ValueError(msg) kwargs = {**kwargs, "function_call": function_call} return super().bind( functions=formatted_functions, @@ -685,7 +689,7 @@ class ChatFireworks(BaseChatModel): tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], *, tool_choice: Optional[ - Union[dict, str, Literal["auto", "any", "none"], bool] + Union[dict, str, Literal["auto", "any", "none"], bool] # noqa: PYI051 ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: @@ -705,8 +709,8 @@ class ChatFireworks(BaseChatModel): ``{"type": "function", "function": {"name": <>}}``. **kwargs: Any additional parameters to pass to :meth:`~langchain_fireworks.chat_models.ChatFireworks.bind` - """ + """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] if tool_choice is not None and tool_choice: if isinstance(tool_choice, str) and ( @@ -715,10 +719,11 @@ class ChatFireworks(BaseChatModel): tool_choice = {"type": "function", "function": {"name": tool_choice}} if isinstance(tool_choice, bool): if len(tools) > 1: - raise ValueError( + msg = ( "tool_choice can only be True when there is one tool. Received " f"{len(tools)} tools." ) + raise ValueError(msg) tool_name = formatted_tools[0]["function"]["name"] tool_choice = { "type": "function", @@ -779,6 +784,9 @@ class ChatFireworks(BaseChatModel): will be caught and returned as well. The final output is always a dict with keys "raw", "parsed", and "parsing_error". + kwargs: + Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. + Returns: A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. @@ -964,17 +972,20 @@ class ChatFireworks(BaseChatModel): # }, # 'parsing_error': None # } + """ # noqa: E501 _ = kwargs.pop("strict", None) if kwargs: - raise ValueError(f"Received unsupported arguments {kwargs}") + msg = f"Received unsupported arguments {kwargs}" + raise ValueError(msg) is_pydantic_schema = _is_pydantic_class(schema) if method == "function_calling": if schema is None: - raise ValueError( + msg = ( "schema must be specified when method is 'function_calling'. " "Received None." ) + raise ValueError(msg) formatted_tool = convert_to_openai_tool(schema) tool_name = formatted_tool["function"]["name"] llm = self.bind_tools( @@ -996,10 +1007,11 @@ class ChatFireworks(BaseChatModel): ) elif method == "json_schema": if schema is None: - raise ValueError( + msg = ( "schema must be specified when method is 'json_schema'. " "Received None." ) + raise ValueError(msg) formatted_schema = convert_to_json_schema(schema) llm = self.bind( response_format={"type": "json_object", "schema": formatted_schema}, @@ -1027,10 +1039,11 @@ class ChatFireworks(BaseChatModel): else JsonOutputParser() ) else: - raise ValueError( + msg = ( f"Unrecognized method argument. Expected one of 'function_calling' or " f"'json_mode'. Received: '{method}'" ) + raise ValueError(msg) if include_raw: parser_assign = RunnablePassthrough.assign( @@ -1041,8 +1054,7 @@ class ChatFireworks(BaseChatModel): [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser + return llm | output_parser def _is_pydantic_class(obj: Any) -> bool: diff --git a/libs/partners/fireworks/langchain_fireworks/embeddings.py b/libs/partners/fireworks/langchain_fireworks/embeddings.py index 116cdd3ef4f..a1933961900 100644 --- a/libs/partners/fireworks/langchain_fireworks/embeddings.py +++ b/libs/partners/fireworks/langchain_fireworks/embeddings.py @@ -4,8 +4,6 @@ from openai import OpenAI from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self -# type: ignore - class FireworksEmbeddings(BaseModel, Embeddings): """Fireworks embedding model integration. @@ -78,7 +76,7 @@ class FireworksEmbeddings(BaseModel, Embeddings): ), ) """Fireworks API key. - + Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided. """ model: str = "nomic-ai/nomic-embed-text-v1.5" diff --git a/libs/partners/fireworks/langchain_fireworks/llms.py b/libs/partners/fireworks/langchain_fireworks/llms.py index cec176fcdf0..7f7fee8436f 100644 --- a/libs/partners/fireworks/langchain_fireworks/llms.py +++ b/libs/partners/fireworks/langchain_fireworks/llms.py @@ -1,5 +1,7 @@ """Wrapper around Fireworks AI's Completion API.""" +from __future__ import annotations + import logging from typing import Any, Optional @@ -49,7 +51,7 @@ class Fireworks(LLM): ), ) """Fireworks API key. - + Automatically read from env variable ``FIREWORKS_API_KEY`` if not provided. """ model: str @@ -60,14 +62,14 @@ class Fireworks(LLM): """Used to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. A value of ``1`` will always yield the same output. A temperature less than ``1`` favors more correctness and is appropriate for - question answering or summarization. A value greater than ``1`` introduces more + question answering or summarization. A value greater than ``1`` introduces more randomness in the output. """ model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for ``create`` call not explicitly specified.""" top_k: Optional[int] = None - """Used to limit the number of choices for the next predicted word or token. It - specifies the maximum number of tokens to consider at each step, based on their + """Used to limit the number of choices for the next predicted word or token. It + specifies the maximum number of tokens to consider at each step, based on their probability of occurrence. This technique helps to speed up the generation process and can improve the quality of the generated text by focusing on the most likely options. @@ -79,7 +81,7 @@ class Fireworks(LLM): of repeated sequences. Higher values decrease repetition. """ logprobs: Optional[int] = None - """An integer that specifies how many top token log probabilities are included in + """An integer that specifies how many top token log probabilities are included in the response for each token generation step. """ timeout: Optional[int] = 30 @@ -95,8 +97,7 @@ class Fireworks(LLM): def build_extra(cls, values: dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) - values = _build_model_kwargs(values, all_required_field_names) - return values + return _build_model_kwargs(values, all_required_field_names) @property def _llm_type(self) -> str: @@ -132,9 +133,13 @@ class Fireworks(LLM): Args: prompt: The prompt to pass into the model. + stop: Optional list of stop sequences to use. + run_manager: (Not used) Optional callback manager for LLM run. + kwargs: Additional parameters to pass to the model. Returns: The string generated by the model. + """ headers = { "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}", @@ -155,19 +160,20 @@ class Fireworks(LLM): ) if response.status_code >= 500: - raise Exception(f"Fireworks Server: Error {response.status_code}") - elif response.status_code >= 400: - raise ValueError(f"Fireworks received an invalid payload: {response.text}") - elif response.status_code != 200: - raise Exception( + msg = f"Fireworks Server: Error {response.status_code}" + raise Exception(msg) + if response.status_code >= 400: + msg = f"Fireworks received an invalid payload: {response.text}" + raise ValueError(msg) + if response.status_code != 200: + msg = ( f"Fireworks returned an unexpected response with status " f"{response.status_code}: {response.text}" ) + raise Exception(msg) data = response.json() - output = self._format_output(data) - - return output + return self._format_output(data) async def _acall( self, @@ -180,9 +186,13 @@ class Fireworks(LLM): Args: prompt: The prompt to pass into the model. + stop: Optional list of strings to stop generation when encountered. + run_manager: (Not used) Optional callback manager for async runs. + kwargs: Additional parameters to pass to the model. Returns: The string generated by the model. + """ headers = { "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}", @@ -198,25 +208,27 @@ class Fireworks(LLM): # filter None values to not pass them to the http payload payload = {k: v for k, v in payload.items() if v is not None} - async with ClientSession() as session: - async with session.post( + async with ( + ClientSession() as session, + session.post( self.base_url, json=payload, headers=headers, timeout=ClientTimeout(total=self.timeout), - ) as response: - if response.status >= 500: - raise Exception(f"Fireworks Server: Error {response.status}") - elif response.status >= 400: - raise ValueError( - f"Fireworks received an invalid payload: {response.text}" - ) - elif response.status != 200: - raise Exception( - f"Fireworks returned an unexpected response with status " - f"{response.status}: {response.text}" - ) + ) as response, + ): + if response.status >= 500: + msg = f"Fireworks Server: Error {response.status}" + raise Exception(msg) + if response.status >= 400: + msg = f"Fireworks received an invalid payload: {response.text}" + raise ValueError(msg) + if response.status != 200: + msg = ( + f"Fireworks returned an unexpected response with status " + f"{response.status}: {response.text}" + ) + raise Exception(msg) - response_json = await response.json() - output = self._format_output(response_json) - return output + response_json = await response.json() + return self._format_output(response_json) diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml index 8cb7e4aae98..5a9a24c42a9 100644 --- a/libs/partners/fireworks/pyproject.toml +++ b/libs/partners/fireworks/pyproject.toml @@ -52,8 +52,58 @@ disallow_untyped_defs = "True" target-version = "py39" [tool.ruff.lint] -select = ["E", "F", "I", "T201", "UP", "S"] -ignore = [ "UP007", ] +select = [ + "A", # flake8-builtins + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "COM", # flake8-commas + "D", # pydocstyle + "DOC", # pydoclint + "E", # pycodestyle error + "EM", # flake8-errmsg + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT", # flake8-boolean-trap + "FLY", # flake8-flynt + "I", # isort + "ICN", # flake8-import-conventions + "INT", # flake8-gettext + "ISC", # isort-comprehensions + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PERF", # flake8-perf + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-rst-docstrings + "RUF", # ruff + "S", # flake8-bandit + "SLF", # flake8-self + "SLOT", # flake8-slots + "SIM", # flake8-simplify + "T10", # flake8-debugger + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 +] +ignore = [ + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D107", # Missing docstring in __init__ + "COM812", # Messes with the formatter + "ISC001", # Messes with the formatter + "PERF203", # Rarely useful + "S112", # Rarely useful + "RUF012", # Doesn't play well with Pydantic + "SLF001", # Private member access + "UP007", # pyupgrade: non-pep604-annotation-union +] [tool.coverage.run] omit = ["tests/*"] diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py index 8de13e1f30f..ab3683e5490 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -1,8 +1,10 @@ -"""Test ChatFireworks API wrapper +"""Test ChatFireworks API wrapper. You will need FIREWORKS_API_KEY set in your environment to run these tests. """ +from __future__ import annotations + import json from typing import Annotated, Any, Literal, Optional @@ -18,7 +20,6 @@ _MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct" def test_tool_choice_bool() -> None: """Test that tool choice is respected just passing in True.""" - llm = ChatFireworks( model="accounts/fireworks/models/llama-v3p1-70b-instruct", temperature=0 ) @@ -59,11 +60,12 @@ async def test_astream() -> None: if token.response_metadata: chunks_with_response_metadata += 1 if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: - raise AssertionError( + msg = ( "Expected exactly one chunk with token counts or response_metadata. " "AIMessageChunk aggregation adds / appends counts and metadata. Check that " "this is behaving properly." ) + raise AssertionError(msg) assert isinstance(full, AIMessageChunk) assert full.usage_metadata is not None assert full.usage_metadata["input_tokens"] > 0 @@ -99,7 +101,7 @@ def test_invoke() -> None: """Test invoke tokens from ChatFireworks.""" llm = ChatFireworks(model=_MODEL) - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) + result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result.content, str) @@ -122,18 +124,18 @@ def _get_joke_class( punchline: Annotated[str, ..., "answer to resolve the joke"] def validate_joke_dict(result: Any) -> bool: - return all(key in ["setup", "punchline"] for key in result.keys()) + return all(key in ["setup", "punchline"] for key in result) if schema_type == "pydantic": return Joke, validate_joke - elif schema_type == "typeddict": + if schema_type == "typeddict": return JokeDict, validate_joke_dict - elif schema_type == "json_schema": + if schema_type == "json_schema": return Joke.model_json_schema(), validate_joke_dict - else: - raise ValueError("Invalid schema type") + msg = "Invalid schema type" + raise ValueError(msg) @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"]) diff --git a/libs/partners/fireworks/tests/integration_tests/test_compile.py b/libs/partners/fireworks/tests/integration_tests/test_compile.py index 33ecccdfa0f..f315e45f521 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_compile.py +++ b/libs/partners/fireworks/tests/integration_tests/test_compile.py @@ -4,4 +4,3 @@ import pytest @pytest.mark.compile def test_placeholder() -> None: """Used for compiling integration tests without running any real tests.""" - pass diff --git a/libs/partners/fireworks/tests/integration_tests/test_llms.py b/libs/partners/fireworks/tests/integration_tests/test_llms.py index 30e940faaea..553666f6560 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_llms.py +++ b/libs/partners/fireworks/tests/integration_tests/test_llms.py @@ -100,5 +100,5 @@ def test_invoke() -> None: """Test invoke tokens from Fireworks.""" llm = Fireworks(model=_MODEL) - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) + result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result, str) diff --git a/libs/partners/fireworks/tests/integration_tests/test_standard.py b/libs/partners/fireworks/tests/integration_tests/test_standard.py index 5c467a01649..45a9e69cbae 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_standard.py +++ b/libs/partners/fireworks/tests/integration_tests/test_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" import pytest from langchain_core.language_models import BaseChatModel diff --git a/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py b/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py index 916f9bf2ca7..de07d229970 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py +++ b/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" from langchain_core.embeddings import Embeddings from langchain_tests.unit_tests.embeddings import EmbeddingsUnitTests diff --git a/libs/partners/fireworks/tests/unit_tests/test_llms.py b/libs/partners/fireworks/tests/unit_tests/test_llms.py index 265df7ede83..d438f1d7048 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_llms.py +++ b/libs/partners/fireworks/tests/unit_tests/test_llms.py @@ -1,4 +1,4 @@ -"""Test Fireworks LLM""" +"""Test Fireworks LLM.""" from typing import cast diff --git a/libs/partners/fireworks/tests/unit_tests/test_standard.py b/libs/partners/fireworks/tests/unit_tests/test_standard.py index 3aee0335557..58b443b6722 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_standard.py +++ b/libs/partners/fireworks/tests/unit_tests/test_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" from langchain_core.language_models import BaseChatModel from langchain_tests.unit_tests import ( # type: ignore[import-not-found]