mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 15:59:56 +00:00
groq[patch]: ruff fixes and rules (#31904)
* bump ruff deps * add more thorough ruff rules * fix said rules
This commit is contained in:
parent
750721b4c3
commit
dd76209bbd
@ -83,7 +83,7 @@ from langchain_groq.version import __version__
|
|||||||
|
|
||||||
|
|
||||||
class ChatGroq(BaseChatModel):
|
class ChatGroq(BaseChatModel):
|
||||||
"""Groq Chat large language models API.
|
r"""Groq Chat large language models API.
|
||||||
|
|
||||||
To use, you should have the
|
To use, you should have the
|
||||||
environment variable ``GROQ_API_KEY`` set with your API key.
|
environment variable ``GROQ_API_KEY`` set with your API key.
|
||||||
@ -412,7 +412,8 @@ class ChatGroq(BaseChatModel):
|
|||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
raise ValueError(f"Found {field_name} supplied twice.")
|
msg = f"Found {field_name} supplied twice."
|
||||||
|
raise ValueError(msg)
|
||||||
if field_name not in all_required_field_names:
|
if field_name not in all_required_field_names:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"""WARNING! {field_name} is not default parameter.
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
@ -423,10 +424,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
|
|
||||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
if invalid_model_kwargs:
|
if invalid_model_kwargs:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
values["model_kwargs"] = extra
|
values["model_kwargs"] = extra
|
||||||
return values
|
return values
|
||||||
@ -435,9 +437,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
def validate_environment(self) -> Self:
|
def validate_environment(self) -> Self:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
if self.n < 1:
|
if self.n < 1:
|
||||||
raise ValueError("n must be at least 1.")
|
msg = "n must be at least 1."
|
||||||
|
raise ValueError(msg)
|
||||||
if self.n > 1 and self.streaming:
|
if self.n > 1 and self.streaming:
|
||||||
raise ValueError("n must be 1 when streaming.")
|
msg = "n must be 1 when streaming."
|
||||||
|
raise ValueError(msg)
|
||||||
if self.temperature == 0:
|
if self.temperature == 0:
|
||||||
self.temperature = 1e-8
|
self.temperature = 1e-8
|
||||||
|
|
||||||
@ -470,10 +474,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
**client_params, **async_specific
|
**client_params, **async_specific
|
||||||
).chat.completions
|
).chat.completions
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ImportError(
|
msg = (
|
||||||
"Could not import groq python package. "
|
"Could not import groq python package. "
|
||||||
"Please install it with `pip install groq`."
|
"Please install it with `pip install groq`."
|
||||||
) from exc
|
)
|
||||||
|
raise ImportError(msg) from exc
|
||||||
return self
|
return self
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -680,7 +685,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
def _create_chat_result(
|
def _create_chat_result(
|
||||||
self, response: Union[dict, BaseModel], params: dict
|
self, response: dict | BaseModel, params: dict
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
generations = []
|
generations = []
|
||||||
if not isinstance(response, dict):
|
if not isinstance(response, dict):
|
||||||
@ -698,7 +703,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
"total_tokens", input_tokens + output_tokens
|
"total_tokens", input_tokens + output_tokens
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
generation_info = dict(finish_reason=res.get("finish_reason"))
|
generation_info = {"finish_reason": res.get("finish_reason")}
|
||||||
if "logprobs" in res:
|
if "logprobs" in res:
|
||||||
generation_info["logprobs"] = res["logprobs"]
|
generation_info["logprobs"] = res["logprobs"]
|
||||||
gen = ChatGeneration(
|
gen = ChatGeneration(
|
||||||
@ -755,7 +760,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
self,
|
self,
|
||||||
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||||
function_call: Optional[
|
function_call: Optional[
|
||||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
Union[_FunctionCall, str, Literal["auto", "none"]] # noqa: PYI051
|
||||||
] = None,
|
] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
@ -777,8 +782,8 @@ class ChatGroq(BaseChatModel):
|
|||||||
(if any).
|
(if any).
|
||||||
**kwargs: Any additional parameters to pass to
|
**kwargs: Any additional parameters to pass to
|
||||||
:meth:`~langchain_groq.chat_models.ChatGroq.bind`.
|
:meth:`~langchain_groq.chat_models.ChatGroq.bind`.
|
||||||
"""
|
|
||||||
|
|
||||||
|
"""
|
||||||
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
||||||
if function_call is not None:
|
if function_call is not None:
|
||||||
function_call = (
|
function_call = (
|
||||||
@ -788,18 +793,20 @@ class ChatGroq(BaseChatModel):
|
|||||||
else function_call
|
else function_call
|
||||||
)
|
)
|
||||||
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
||||||
raise ValueError(
|
msg = (
|
||||||
"When specifying `function_call`, you must provide exactly one "
|
"When specifying `function_call`, you must provide exactly one "
|
||||||
"function."
|
"function."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
if (
|
if (
|
||||||
isinstance(function_call, dict)
|
isinstance(function_call, dict)
|
||||||
and formatted_functions[0]["name"] != function_call["name"]
|
and formatted_functions[0]["name"] != function_call["name"]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Function call {function_call} was specified, but the only "
|
f"Function call {function_call} was specified, but the only "
|
||||||
f"provided function was {formatted_functions[0]['name']}."
|
f"provided function was {formatted_functions[0]['name']}."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
kwargs = {**kwargs, "function_call": function_call}
|
kwargs = {**kwargs, "function_call": function_call}
|
||||||
return super().bind(
|
return super().bind(
|
||||||
functions=formatted_functions,
|
functions=formatted_functions,
|
||||||
@ -811,7 +818,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||||
*,
|
*,
|
||||||
tool_choice: Optional[
|
tool_choice: Optional[
|
||||||
Union[dict, str, Literal["auto", "any", "none"], bool]
|
Union[dict, str, Literal["auto", "any", "none"], bool] # noqa: PYI051
|
||||||
] = None,
|
] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
@ -829,8 +836,8 @@ class ChatGroq(BaseChatModel):
|
|||||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||||
**kwargs: Any additional parameters to pass to the
|
**kwargs: Any additional parameters to pass to the
|
||||||
:class:`~langchain.runnable.Runnable` constructor.
|
:class:`~langchain.runnable.Runnable` constructor.
|
||||||
"""
|
|
||||||
|
|
||||||
|
"""
|
||||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||||
if tool_choice is not None and tool_choice:
|
if tool_choice is not None and tool_choice:
|
||||||
if tool_choice == "any":
|
if tool_choice == "any":
|
||||||
@ -841,10 +848,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
tool_choice = {"type": "function", "function": {"name": tool_choice}}
|
tool_choice = {"type": "function", "function": {"name": tool_choice}}
|
||||||
if isinstance(tool_choice, bool):
|
if isinstance(tool_choice, bool):
|
||||||
if len(tools) > 1:
|
if len(tools) > 1:
|
||||||
raise ValueError(
|
msg = (
|
||||||
"tool_choice can only be True when there is one tool. Received "
|
"tool_choice can only be True when there is one tool. Received "
|
||||||
f"{len(tools)} tools."
|
f"{len(tools)} tools."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
tool_name = formatted_tools[0]["function"]["name"]
|
tool_name = formatted_tools[0]["function"]["name"]
|
||||||
tool_choice = {
|
tool_choice = {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -861,8 +869,8 @@ class ChatGroq(BaseChatModel):
|
|||||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||||
include_raw: bool = False,
|
include_raw: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
r"""Model wrapper that returns outputs formatted to match the given schema.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
schema:
|
schema:
|
||||||
@ -895,6 +903,9 @@ class ChatGroq(BaseChatModel):
|
|||||||
response will be returned. If an error occurs during output parsing it
|
response will be returned. If an error occurs during output parsing it
|
||||||
will be caught and returned as well. The final output is always a dict
|
will be caught and returned as well. The final output is always a dict
|
||||||
with keys "raw", "parsed", and "parsing_error".
|
with keys "raw", "parsed", and "parsing_error".
|
||||||
|
kwargs:
|
||||||
|
Any additional parameters to pass to the
|
||||||
|
:class:`~langchain.runnable.Runnable` constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||||
@ -1075,10 +1086,12 @@ class ChatGroq(BaseChatModel):
|
|||||||
# },
|
# },
|
||||||
# 'parsing_error': None
|
# 'parsing_error': None
|
||||||
# }
|
# }
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
_ = kwargs.pop("strict", None)
|
_ = kwargs.pop("strict", None)
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
msg = f"Received unsupported arguments {kwargs}"
|
||||||
|
raise ValueError(msg)
|
||||||
is_pydantic_schema = _is_pydantic_class(schema)
|
is_pydantic_schema = _is_pydantic_class(schema)
|
||||||
if method == "json_schema":
|
if method == "json_schema":
|
||||||
# Some applications require that incompatible parameters (e.g., unsupported
|
# Some applications require that incompatible parameters (e.g., unsupported
|
||||||
@ -1086,10 +1099,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
method = "function_calling"
|
method = "function_calling"
|
||||||
if method == "function_calling":
|
if method == "function_calling":
|
||||||
if schema is None:
|
if schema is None:
|
||||||
raise ValueError(
|
msg = (
|
||||||
"schema must be specified when method is 'function_calling'. "
|
"schema must be specified when method is 'function_calling'. "
|
||||||
"Received None."
|
"Received None."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
formatted_tool = convert_to_openai_tool(schema)
|
formatted_tool = convert_to_openai_tool(schema)
|
||||||
tool_name = formatted_tool["function"]["name"]
|
tool_name = formatted_tool["function"]["name"]
|
||||||
llm = self.bind_tools(
|
llm = self.bind_tools(
|
||||||
@ -1123,10 +1137,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
else JsonOutputParser()
|
else JsonOutputParser()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||||
f"'json_mode'. Received: '{method}'"
|
f"'json_mode'. Received: '{method}'"
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
if include_raw:
|
if include_raw:
|
||||||
parser_assign = RunnablePassthrough.assign(
|
parser_assign = RunnablePassthrough.assign(
|
||||||
@ -1137,8 +1152,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
[parser_none], exception_key="parsing_error"
|
[parser_none], exception_key="parsing_error"
|
||||||
)
|
)
|
||||||
return RunnableMap(raw=llm) | parser_with_fallback
|
return RunnableMap(raw=llm) | parser_with_fallback
|
||||||
else:
|
return llm | output_parser
|
||||||
return llm | output_parser
|
|
||||||
|
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
@ -1160,6 +1174,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The dictionary.
|
The dictionary.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
message_dict: dict[str, Any]
|
message_dict: dict[str, Any]
|
||||||
if isinstance(message, ChatMessage):
|
if isinstance(message, ChatMessage):
|
||||||
@ -1200,7 +1215,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
"tool_call_id": message.tool_call_id,
|
"tool_call_id": message.tool_call_id,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Got unknown type {message}")
|
msg = f"Got unknown type {message}"
|
||||||
|
raise TypeError(msg)
|
||||||
if "name" in message.additional_kwargs:
|
if "name" in message.additional_kwargs:
|
||||||
message_dict["name"] = message.additional_kwargs["name"]
|
message_dict["name"] = message.additional_kwargs["name"]
|
||||||
return message_dict
|
return message_dict
|
||||||
@ -1224,7 +1240,7 @@ def _convert_chunk_to_message_chunk(
|
|||||||
|
|
||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
if role == "assistant" or default_class == AIMessageChunk:
|
||||||
if reasoning := _dict.get("reasoning"):
|
if reasoning := _dict.get("reasoning"):
|
||||||
additional_kwargs["reasoning_content"] = reasoning
|
additional_kwargs["reasoning_content"] = reasoning
|
||||||
if usage := (chunk.get("x_groq") or {}).get("usage"):
|
if usage := (chunk.get("x_groq") or {}).get("usage"):
|
||||||
@ -1242,16 +1258,15 @@ def _convert_chunk_to_message_chunk(
|
|||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
elif role == "system" or default_class == SystemMessageChunk:
|
if role == "system" or default_class == SystemMessageChunk:
|
||||||
return SystemMessageChunk(content=content)
|
return SystemMessageChunk(content=content)
|
||||||
elif role == "function" or default_class == FunctionMessageChunk:
|
if role == "function" or default_class == FunctionMessageChunk:
|
||||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||||
elif role == "tool" or default_class == ToolMessageChunk:
|
if role == "tool" or default_class == ToolMessageChunk:
|
||||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
||||||
elif role or default_class == ChatMessageChunk:
|
if role or default_class == ChatMessageChunk:
|
||||||
return ChatMessageChunk(content=content, role=role)
|
return ChatMessageChunk(content=content, role=role)
|
||||||
else:
|
return default_class(content=content) # type: ignore[call-arg]
|
||||||
return default_class(content=content) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||||
@ -1262,12 +1277,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The LangChain message.
|
The LangChain message.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
id_ = _dict.get("id")
|
id_ = _dict.get("id")
|
||||||
role = _dict.get("role")
|
role = _dict.get("role")
|
||||||
if role == "user":
|
if role == "user":
|
||||||
return HumanMessage(content=_dict.get("content", ""))
|
return HumanMessage(content=_dict.get("content", ""))
|
||||||
elif role == "assistant":
|
if role == "assistant":
|
||||||
content = _dict.get("content", "") or ""
|
content = _dict.get("content", "") or ""
|
||||||
additional_kwargs: dict = {}
|
additional_kwargs: dict = {}
|
||||||
if reasoning := _dict.get("reasoning"):
|
if reasoning := _dict.get("reasoning"):
|
||||||
@ -1292,11 +1308,11 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
invalid_tool_calls=invalid_tool_calls,
|
invalid_tool_calls=invalid_tool_calls,
|
||||||
)
|
)
|
||||||
elif role == "system":
|
if role == "system":
|
||||||
return SystemMessage(content=_dict.get("content", ""))
|
return SystemMessage(content=_dict.get("content", ""))
|
||||||
elif role == "function":
|
if role == "function":
|
||||||
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) # type: ignore[arg-type]
|
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) # type: ignore[arg-type]
|
||||||
elif role == "tool":
|
if role == "tool":
|
||||||
additional_kwargs = {}
|
additional_kwargs = {}
|
||||||
if "name" in _dict:
|
if "name" in _dict:
|
||||||
additional_kwargs["name"] = _dict["name"]
|
additional_kwargs["name"] = _dict["name"]
|
||||||
@ -1305,8 +1321,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
tool_call_id=_dict.get("tool_call_id"),
|
tool_call_id=_dict.get("tool_call_id"),
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type]
|
||||||
return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
|
|
||||||
def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict:
|
def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict:
|
||||||
|
@ -44,8 +44,58 @@ disallow_untyped_defs = "True"
|
|||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "I", "W", "UP", "S"]
|
select = [
|
||||||
ignore = [ "UP007", ]
|
"A", # flake8-builtins
|
||||||
|
"ASYNC", # flake8-async
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"COM", # flake8-commas
|
||||||
|
"D", # pydocstyle
|
||||||
|
"DOC", # pydoclint
|
||||||
|
"E", # pycodestyle error
|
||||||
|
"EM", # flake8-errmsg
|
||||||
|
"F", # pyflakes
|
||||||
|
"FA", # flake8-future-annotations
|
||||||
|
"FBT", # flake8-boolean-trap
|
||||||
|
"FLY", # flake8-flynt
|
||||||
|
"I", # isort
|
||||||
|
"ICN", # flake8-import-conventions
|
||||||
|
"INT", # flake8-gettext
|
||||||
|
"ISC", # isort-comprehensions
|
||||||
|
"PGH", # pygrep-hooks
|
||||||
|
"PIE", # flake8-pie
|
||||||
|
"PERF", # flake8-perf
|
||||||
|
"PYI", # flake8-pyi
|
||||||
|
"Q", # flake8-quotes
|
||||||
|
"RET", # flake8-return
|
||||||
|
"RSE", # flake8-rst-docstrings
|
||||||
|
"RUF", # ruff
|
||||||
|
"S", # flake8-bandit
|
||||||
|
"SLF", # flake8-self
|
||||||
|
"SLOT", # flake8-slots
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"T10", # flake8-debugger
|
||||||
|
"T20", # flake8-print
|
||||||
|
"TID", # flake8-tidy-imports
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"W", # pycodestyle warning
|
||||||
|
"YTT", # flake8-2020
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"D100", # Missing docstring in public module
|
||||||
|
"D101", # Missing docstring in public class
|
||||||
|
"D102", # Missing docstring in public method
|
||||||
|
"D103", # Missing docstring in public function
|
||||||
|
"D104", # Missing docstring in public package
|
||||||
|
"D105", # Missing docstring in magic method
|
||||||
|
"D107", # Missing docstring in __init__
|
||||||
|
"COM812", # Messes with the formatter
|
||||||
|
"ISC001", # Messes with the formatter
|
||||||
|
"PERF203", # Rarely useful
|
||||||
|
"S112", # Rarely useful
|
||||||
|
"RUF012", # Doesn't play well with Pydantic
|
||||||
|
"SLF001", # Private member access
|
||||||
|
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||||
|
]
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
omit = ["tests/*"]
|
omit = ["tests/*"]
|
||||||
|
@ -10,8 +10,6 @@ if __name__ == "__main__":
|
|||||||
SourceFileLoader("x", file).load_module()
|
SourceFileLoader("x", file).load_module()
|
||||||
except Exception:
|
except Exception:
|
||||||
has_failure = True
|
has_failure = True
|
||||||
print(file)
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print()
|
|
||||||
|
|
||||||
sys.exit(1 if has_failure else 0)
|
sys.exit(1 if has_failure else 0)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test ChatGroq chat model."""
|
"""Test ChatGroq chat model."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
@ -109,11 +111,12 @@ async def test_astream() -> None:
|
|||||||
if token.response_metadata:
|
if token.response_metadata:
|
||||||
chunks_with_response_metadata += 1
|
chunks_with_response_metadata += 1
|
||||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||||
raise AssertionError(
|
msg = (
|
||||||
"Expected exactly one chunk with token counts or metadata. "
|
"Expected exactly one chunk with token counts or metadata. "
|
||||||
"AIMessageChunk aggregation adds / appends these metadata. Check that "
|
"AIMessageChunk aggregation adds / appends these metadata. Check that "
|
||||||
"this is behaving properly."
|
"this is behaving properly."
|
||||||
)
|
)
|
||||||
|
raise AssertionError(msg)
|
||||||
assert isinstance(full, AIMessageChunk)
|
assert isinstance(full, AIMessageChunk)
|
||||||
assert full.usage_metadata is not None
|
assert full.usage_metadata is not None
|
||||||
assert full.usage_metadata["input_tokens"] > 0
|
assert full.usage_metadata["input_tokens"] > 0
|
||||||
@ -451,7 +454,7 @@ async def test_astreaming_tool_call() -> None:
|
|||||||
|
|
||||||
@pytest.mark.scheduled
|
@pytest.mark.scheduled
|
||||||
def test_json_mode_structured_output() -> None:
|
def test_json_mode_structured_output() -> None:
|
||||||
"""Test with_structured_output with json"""
|
"""Test with_structured_output with json."""
|
||||||
|
|
||||||
class Joke(BaseModel):
|
class Joke(BaseModel):
|
||||||
"""Joke to tell user."""
|
"""Joke to tell user."""
|
||||||
@ -496,9 +499,9 @@ def test_setting_service_tier_class() -> None:
|
|||||||
assert response.response_metadata.get("service_tier") == "on_demand"
|
assert response.response_metadata.get("service_tier") == "on_demand"
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ChatGroq(model=MODEL_NAME, service_tier=None) # type: ignore
|
ChatGroq(model=MODEL_NAME, service_tier=None) # type: ignore[arg-type]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ChatGroq(model=MODEL_NAME, service_tier="invalid") # type: ignore
|
ChatGroq(model=MODEL_NAME, service_tier="invalid") # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def test_setting_service_tier_request() -> None:
|
def test_setting_service_tier_request() -> None:
|
||||||
|
@ -4,4 +4,3 @@ import pytest
|
|||||||
@pytest.mark.compile
|
@pytest.mark.compile
|
||||||
def test_placeholder() -> None:
|
def test_placeholder() -> None:
|
||||||
"""Used for compiling integration tests without running any real tests."""
|
"""Used for compiling integration tests without running any real tests."""
|
||||||
pass
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""Standard LangChain interface tests"""
|
"""Standard LangChain interface tests."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""A fake callback handler for testing purposes."""
|
"""A fake callback handler for testing purposes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -257,7 +259,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
self.on_retriever_error_common()
|
self.on_retriever_error_common()
|
||||||
|
|
||||||
# Overriding since BaseModel has __deepcopy__ method as well
|
# Overriding since BaseModel has __deepcopy__ method as well
|
||||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore
|
def __deepcopy__(self, memo: dict) -> FakeCallbackHandler: # type: ignore[override]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -392,5 +394,5 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
self.on_text_common()
|
self.on_text_common()
|
||||||
|
|
||||||
# Overriding since BaseModel has __deepcopy__ method as well
|
# Overriding since BaseModel has __deepcopy__ method as well
|
||||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore
|
def __deepcopy__(self, memo: dict) -> FakeAsyncCallbackHandler: # type: ignore[override]
|
||||||
return self
|
return self
|
||||||
|
@ -244,7 +244,7 @@ def test_chat_groq_invalid_streaming_params() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_chat_groq_secret() -> None:
|
def test_chat_groq_secret() -> None:
|
||||||
"""Test that secret is not printed"""
|
"""Test that secret is not printed."""
|
||||||
secret = "secretKey" # noqa: S105
|
secret = "secretKey" # noqa: S105
|
||||||
not_secret = "safe" # noqa: S105
|
not_secret = "safe" # noqa: S105
|
||||||
llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret}) # type: ignore[call-arg, arg-type]
|
llm = ChatGroq(model="foo", api_key=secret, model_kwargs={"not_secret": not_secret}) # type: ignore[call-arg, arg-type]
|
||||||
@ -255,7 +255,7 @@ def test_chat_groq_secret() -> None:
|
|||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore:The function `loads` is in beta")
|
@pytest.mark.filterwarnings("ignore:The function `loads` is in beta")
|
||||||
def test_groq_serialization() -> None:
|
def test_groq_serialization() -> None:
|
||||||
"""Test that ChatGroq can be successfully serialized and deserialized"""
|
"""Test that ChatGroq can be successfully serialized and deserialized."""
|
||||||
api_key1 = "top secret"
|
api_key1 = "top secret"
|
||||||
api_key2 = "topest secret"
|
api_key2 = "topest secret"
|
||||||
llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5) # type: ignore[call-arg, arg-type]
|
llm = ChatGroq(model="foo", api_key=api_key1, temperature=0.5) # type: ignore[call-arg, arg-type]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""Standard LangChain interface tests"""
|
"""Standard LangChain interface tests."""
|
||||||
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_tests.unit_tests.chat_models import (
|
from langchain_tests.unit_tests.chat_models import (
|
||||||
|
Loading…
Reference in New Issue
Block a user