diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 85b200e6ace..c374dea3cce 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -83,7 +83,7 @@ from langchain_groq.version import __version__ class ChatGroq(BaseChatModel): - """Groq Chat large language models API. + r"""Groq Chat large language models API. To use, you should have the environment variable ``GROQ_API_KEY`` set with your API key. @@ -412,7 +412,8 @@ class ChatGroq(BaseChatModel): extra = values.get("model_kwargs", {}) for field_name in list(values): if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") + msg = f"Found {field_name} supplied twice." + raise ValueError(msg) if field_name not in all_required_field_names: warnings.warn( f"""WARNING! {field_name} is not default parameter. @@ -423,10 +424,11 @@ class ChatGroq(BaseChatModel): invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) if invalid_model_kwargs: - raise ValueError( + msg = ( f"Parameters {invalid_model_kwargs} should be specified explicitly. " f"Instead they were passed in as part of `model_kwargs` parameter." ) + raise ValueError(msg) values["model_kwargs"] = extra return values @@ -435,9 +437,11 @@ class ChatGroq(BaseChatModel): def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" if self.n < 1: - raise ValueError("n must be at least 1.") + msg = "n must be at least 1." + raise ValueError(msg) if self.n > 1 and self.streaming: - raise ValueError("n must be 1 when streaming.") + msg = "n must be 1 when streaming." + raise ValueError(msg) if self.temperature == 0: self.temperature = 1e-8 @@ -470,10 +474,11 @@ class ChatGroq(BaseChatModel): **client_params, **async_specific ).chat.completions except ImportError as exc: - raise ImportError( + msg = ( "Could not import groq python package. " "Please install it with `pip install groq`." - ) from exc + ) + raise ImportError(msg) from exc return self # @@ -680,7 +685,7 @@ class ChatGroq(BaseChatModel): return params def _create_chat_result( - self, response: Union[dict, BaseModel], params: dict + self, response: dict | BaseModel, params: dict ) -> ChatResult: generations = [] if not isinstance(response, dict): @@ -698,7 +703,7 @@ class ChatGroq(BaseChatModel): "total_tokens", input_tokens + output_tokens ), } - generation_info = dict(finish_reason=res.get("finish_reason")) + generation_info = {"finish_reason": res.get("finish_reason")} if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] gen = ChatGeneration( @@ -755,7 +760,7 @@ class ChatGroq(BaseChatModel): self, functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], function_call: Optional[ - Union[_FunctionCall, str, Literal["auto", "none"]] + Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051 ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: @@ -777,8 +782,8 @@ class ChatGroq(BaseChatModel): (if any). **kwargs: Any additional parameters to pass to :meth:`~langchain_groq.chat_models.ChatGroq.bind`. - """ + """ formatted_functions = [convert_to_openai_function(fn) for fn in functions] if function_call is not None: function_call = ( @@ -788,18 +793,20 @@ class ChatGroq(BaseChatModel): else function_call ) if isinstance(function_call, dict) and len(formatted_functions) != 1: - raise ValueError( + msg = ( "When specifying `function_call`, you must provide exactly one " "function." ) + raise ValueError(msg) if ( isinstance(function_call, dict) and formatted_functions[0]["name"] != function_call["name"] ): - raise ValueError( + msg = ( f"Function call {function_call} was specified, but the only " f"provided function was {formatted_functions[0]['name']}." ) + raise ValueError(msg) kwargs = {**kwargs, "function_call": function_call} return super().bind( functions=formatted_functions, @@ -811,7 +818,7 @@ class ChatGroq(BaseChatModel): tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], *, tool_choice: Optional[ - Union[dict, str, Literal["auto", "any", "none"], bool] + Union[dict, str, Literal["auto", "any", "none"], bool] # noqa: PYI051 ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: @@ -829,8 +836,8 @@ class ChatGroq(BaseChatModel): {"type": "function", "function": {"name": <>}}. **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. - """ + """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] if tool_choice is not None and tool_choice: if tool_choice == "any": @@ -841,10 +848,11 @@ class ChatGroq(BaseChatModel): tool_choice = {"type": "function", "function": {"name": tool_choice}} if isinstance(tool_choice, bool): if len(tools) > 1: - raise ValueError( + msg = ( "tool_choice can only be True when there is one tool. Received " f"{len(tools)} tools." ) + raise ValueError(msg) tool_name = formatted_tools[0]["function"]["name"] tool_choice = { "type": "function", @@ -861,8 +869,8 @@ class ChatGroq(BaseChatModel): method: Literal["function_calling", "json_mode"] = "function_calling", include_raw: bool = False, **kwargs: Any, - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: - """Model wrapper that returns outputs formatted to match the given schema. + ) -> Runnable[LanguageModelInput, dict | BaseModel]: + r"""Model wrapper that returns outputs formatted to match the given schema. Args: schema: @@ -895,6 +903,9 @@ class ChatGroq(BaseChatModel): response will be returned. If an error occurs during output parsing it will be caught and returned as well. The final output is always a dict with keys "raw", "parsed", and "parsing_error". + kwargs: + Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. Returns: A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. @@ -1075,10 +1086,12 @@ class ChatGroq(BaseChatModel): # }, # 'parsing_error': None # } + """ # noqa: E501 _ = kwargs.pop("strict", None) if kwargs: - raise ValueError(f"Received unsupported arguments {kwargs}") + msg = f"Received unsupported arguments {kwargs}" + raise ValueError(msg) is_pydantic_schema = _is_pydantic_class(schema) if method == "json_schema": # Some applications require that incompatible parameters (e.g., unsupported @@ -1086,10 +1099,11 @@ class ChatGroq(BaseChatModel): method = "function_calling" if method == "function_calling": if schema is None: - raise ValueError( + msg = ( "schema must be specified when method is 'function_calling'. " "Received None." ) + raise ValueError(msg) formatted_tool = convert_to_openai_tool(schema) tool_name = formatted_tool["function"]["name"] llm = self.bind_tools( @@ -1123,10 +1137,11 @@ class ChatGroq(BaseChatModel): else JsonOutputParser() ) else: - raise ValueError( + msg = ( f"Unrecognized method argument. Expected one of 'function_calling' or " f"'json_mode'. Received: '{method}'" ) + raise ValueError(msg) if include_raw: parser_assign = RunnablePassthrough.assign( @@ -1137,8 +1152,7 @@ class ChatGroq(BaseChatModel): [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser + return llm | output_parser def _is_pydantic_class(obj: Any) -> bool: @@ -1160,6 +1174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: Returns: The dictionary. + """ message_dict: dict[str, Any] if isinstance(message, ChatMessage): @@ -1200,7 +1215,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: "tool_call_id": message.tool_call_id, } else: - raise TypeError(f"Got unknown type {message}") + msg = f"Got unknown type {message}" + raise TypeError(msg) if "name" in message.additional_kwargs: message_dict["name"] = message.additional_kwargs["name"] return message_dict @@ -1224,7 +1240,7 @@ def _convert_chunk_to_message_chunk( if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: + if role == "assistant" or default_class == AIMessageChunk: if reasoning := _dict.get("reasoning"): additional_kwargs["reasoning_content"] = reasoning if usage := (chunk.get("x_groq") or {}).get("usage"): @@ -1242,16 +1258,15 @@ def _convert_chunk_to_message_chunk( additional_kwargs=additional_kwargs, usage_metadata=usage_metadata, # type: ignore[arg-type] ) - elif role == "system" or default_class == SystemMessageChunk: + if role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: + if role == "function" or default_class == FunctionMessageChunk: return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role == "tool" or default_class == ToolMessageChunk: + if role == "tool" or default_class == ToolMessageChunk: return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) - elif role or default_class == ChatMessageChunk: + if role or default_class == ChatMessageChunk: return ChatMessageChunk(content=content, role=role) - else: - return default_class(content=content) # type: ignore + return default_class(content=content) # type: ignore[call-arg] def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: @@ -1262,12 +1277,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: Returns: The LangChain message. + """ id_ = _dict.get("id") role = _dict.get("role") if role == "user": return HumanMessage(content=_dict.get("content", "")) - elif role == "assistant": + if role == "assistant": content = _dict.get("content", "") or "" additional_kwargs: dict = {} if reasoning := _dict.get("reasoning"): @@ -1292,11 +1308,11 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: tool_calls=tool_calls, invalid_tool_calls=invalid_tool_calls, ) - elif role == "system": + if role == "system": return SystemMessage(content=_dict.get("content", "")) - elif role == "function": + if role == "function": return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) # type: ignore[arg-type] - elif role == "tool": + if role == "tool": additional_kwargs = {} if "name" in _dict: additional_kwargs["name"] = _dict["name"] @@ -1305,8 +1321,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: tool_call_id=_dict.get("tool_call_id"), additional_kwargs=additional_kwargs, ) - else: - return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type] + return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type] def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict: diff --git a/libs/partners/groq/pyproject.toml b/libs/partners/groq/pyproject.toml index 9c999104497..3fd3d38c03a 100644 --- a/libs/partners/groq/pyproject.toml +++ b/libs/partners/groq/pyproject.toml @@ -44,8 +44,58 @@ disallow_untyped_defs = "True" target-version = "py39" [tool.ruff.lint] -select = ["E", "F", "I", "W", "UP", "S"] -ignore = [ "UP007", ] +select = [ + "A", # flake8-builtins + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "COM", # flake8-commas + "D", # pydocstyle + "DOC", # pydoclint + "E", # pycodestyle error + "EM", # flake8-errmsg + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT", # flake8-boolean-trap + "FLY", # flake8-flynt + "I", # isort + "ICN", # flake8-import-conventions + "INT", # flake8-gettext + "ISC", # isort-comprehensions + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PERF", # flake8-perf + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-rst-docstrings + "RUF", # ruff + "S", # flake8-bandit + "SLF", # flake8-self + "SLOT", # flake8-slots + "SIM", # flake8-simplify + "T10", # flake8-debugger + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 +] +ignore = [ + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D107", # Missing docstring in __init__ + "COM812", # Messes with the formatter + "ISC001", # Messes with the formatter + "PERF203", # Rarely useful + "S112", # Rarely useful + "RUF012", # Doesn't play well with Pydantic + "SLF001", # Private member access + "UP007", # pyupgrade: non-pep604-annotation-union +] [tool.coverage.run] omit = ["tests/*"] diff --git a/libs/partners/groq/scripts/check_imports.py b/libs/partners/groq/scripts/check_imports.py index ba8de50118c..ec3fc6e95f5 100644 --- a/libs/partners/groq/scripts/check_imports.py +++ b/libs/partners/groq/scripts/check_imports.py @@ -10,8 +10,6 @@ if __name__ == "__main__": SourceFileLoader("x", file).load_module() except Exception: has_failure = True - print(file) traceback.print_exc() - print() sys.exit(1 if has_failure else 0) diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index 5ef2c3045d9..fd4ad486637 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -1,5 +1,7 @@ """Test ChatGroq chat model.""" +from __future__ import annotations + import json from typing import Any, Optional, cast @@ -109,11 +111,12 @@ async def test_astream() -> None: if token.response_metadata: chunks_with_response_metadata += 1 if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: - raise AssertionError( + msg = ( "Expected exactly one chunk with token counts or metadata. " "AIMessageChunk aggregation adds / appends these metadata. Check that " "this is behaving properly." ) + raise AssertionError(msg) assert isinstance(full, AIMessageChunk) assert full.usage_metadata is not None assert full.usage_metadata["input_tokens"] > 0 @@ -451,7 +454,7 @@ async def test_astreaming_tool_call() -> None: @pytest.mark.scheduled def test_json_mode_structured_output() -> None: - """Test with_structured_output with json""" + """Test with_structured_output with json.""" class Joke(BaseModel): """Joke to tell user.""" @@ -496,9 +499,9 @@ def test_setting_service_tier_class() -> None: assert response.response_metadata.get("service_tier") == "on_demand" with pytest.raises(ValueError): - ChatGroq(model=MODEL_NAME, service_tier=None) # type: ignore + ChatGroq(model=MODEL_NAME, service_tier=None) # type: ignore[arg-type] with pytest.raises(ValueError): - ChatGroq(model=MODEL_NAME, service_tier="invalid") # type: ignore + ChatGroq(model=MODEL_NAME, service_tier="invalid") # type: ignore[arg-type] def test_setting_service_tier_request() -> None: diff --git a/libs/partners/groq/tests/integration_tests/test_compile.py b/libs/partners/groq/tests/integration_tests/test_compile.py index 33ecccdfa0f..f315e45f521 100644 --- a/libs/partners/groq/tests/integration_tests/test_compile.py +++ b/libs/partners/groq/tests/integration_tests/test_compile.py @@ -4,4 +4,3 @@ import pytest @pytest.mark.compile def test_placeholder() -> None: """Used for compiling integration tests without running any real tests.""" - pass diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py index 90b9f16a5e5..4d3eb6b0d7a 100644 --- a/libs/partners/groq/tests/integration_tests/test_standard.py +++ b/libs/partners/groq/tests/integration_tests/test_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" import pytest from langchain_core.language_models import BaseChatModel diff --git a/libs/partners/groq/tests/unit_tests/fake/callbacks.py b/libs/partners/groq/tests/unit_tests/fake/callbacks.py index 8b25a400b43..f38e8b56dc8 100644 --- a/libs/partners/groq/tests/unit_tests/fake/callbacks.py +++ b/libs/partners/groq/tests/unit_tests/fake/callbacks.py @@ -1,5 +1,7 @@ """A fake callback handler for testing purposes.""" +from __future__ import annotations + from itertools import chain from typing import Any, Optional, Union from uuid import UUID @@ -257,7 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): self.on_retriever_error_common() # Overriding since BaseModel has __deepcopy__ method as well - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore + def __deepcopy__(self, memo: dict) -> FakeCallbackHandler: # type: ignore[override] return self @@ -392,5 +394,5 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi self.on_text_common() # Overriding since BaseModel has __deepcopy__ method as well - def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore + def __deepcopy__(self, memo: dict) -> FakeAsyncCallbackHandler: # type: ignore[override] return self diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index ef623c947d1..eec86d508a4 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -244,7 +244,7 @@ def test_chat_groq_invalid_streaming_params() -> None: def test_chat_groq_secret() -> None: - """Test that secret is not printed""" + """Test that secret is not printed.""" secret = "secretKey" # noqa: S105 not_secret = "safe" # noqa: S105 llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret}) # type: ignore[call-arg, arg-type] @@ -255,7 +255,7 @@ def test_chat_groq_secret() -> None: @pytest.mark.filterwarnings("ignore:The function `loads` is in beta") def test_groq_serialization() -> None: - """Test that ChatGroq can be successfully serialized and deserialized""" + """Test that ChatGroq can be successfully serialized and deserialized.""" api_key1 = "top secret" api_key2 = "topest secret" llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5) # type: ignore[call-arg, arg-type] diff --git a/libs/partners/groq/tests/unit_tests/test_standard.py b/libs/partners/groq/tests/unit_tests/test_standard.py index 4ab0bf4fedd..30e4b1679fb 100644 --- a/libs/partners/groq/tests/unit_tests/test_standard.py +++ b/libs/partners/groq/tests/unit_tests/test_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" from langchain_core.language_models import BaseChatModel from langchain_tests.unit_tests.chat_models import (