Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-23 15:19:33 +00:00)
community[patch]: upgrade to recent version of mypy (#21616)
This PR upgrades community to a recent version of mypy (1.10.0, per the lockfile change below) and inserts `# type: ignore[...]` comments on all existing failures.
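Every suppression in this diff is scoped to a specific error code (e.g. `[arg-type]`, `[call-arg]`) rather than a bare `# type: ignore`. As a minimal illustration of how a scoped ignore behaves — the `lookup`/`greet` functions are hypothetical, not code from this repo:

```python
# Illustrative sketch, not from the PR: `# type: ignore[code]` silences
# only the named mypy error code on that line; any other error on the
# same line is still reported.
from typing import Optional


def lookup(d: dict) -> Optional[str]:
    return d.get("name")


def greet(name: str) -> str:
    return "hello " + name


# Without the comment, mypy reports:
#   Argument 1 to "greet" has incompatible type "Optional[str]";
#   expected "str"  [arg-type]
print(greet(lookup({"name": "mypy"})))  # type: ignore[arg-type]
```

Scoped ignores also keep the upgrade honest: if a later mypy release stops emitting the named code, the `warn_unused_ignores` option can flag the stale comment for removal.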
This commit is contained in: parent b923951062, commit 25fbe356b4
@@ -95,18 +95,18 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     elif role == "system":
         return SystemMessage(content=_dict.get("content", ""))
     elif role == "function":
-        return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name"))
+        return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name"))  # type: ignore[arg-type]
     elif role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]
         return ToolMessage(
             content=_dict.get("content", ""),
-            tool_call_id=_dict.get("tool_call_id"),
+            tool_call_id=_dict.get("tool_call_id"),  # type: ignore[arg-type]
             additional_kwargs=additional_kwargs,
         )
     else:
-        return ChatMessage(content=_dict.get("content", ""), role=role)
+        return ChatMessage(content=_dict.get("content", ""), role=role)  # type: ignore[arg-type]


 def convert_message_to_dict(message: BaseMessage) -> dict:
@@ -21,11 +21,11 @@ class AzureAiServicesToolkit(BaseToolkit):
         """Get the tools in the toolkit."""

         tools: List[BaseTool] = [
-            AzureAiServicesDocumentIntelligenceTool(),
-            AzureAiServicesImageAnalysisTool(),
-            AzureAiServicesSpeechToTextTool(),
-            AzureAiServicesTextToSpeechTool(),
-            AzureAiServicesTextAnalyticsForHealthTool(),
+            AzureAiServicesDocumentIntelligenceTool(),  # type: ignore[call-arg]
+            AzureAiServicesImageAnalysisTool(),  # type: ignore[call-arg]
+            AzureAiServicesSpeechToTextTool(),  # type: ignore[call-arg]
+            AzureAiServicesTextToSpeechTool(),  # type: ignore[call-arg]
+            AzureAiServicesTextAnalyticsForHealthTool(),  # type: ignore[call-arg]
         ]

         return tools
@@ -21,13 +21,13 @@ class AzureCognitiveServicesToolkit(BaseToolkit):
         """Get the tools in the toolkit."""

         tools: List[BaseTool] = [
-            AzureCogsFormRecognizerTool(),
-            AzureCogsSpeech2TextTool(),
-            AzureCogsText2SpeechTool(),
-            AzureCogsTextAnalyticsHealthTool(),
+            AzureCogsFormRecognizerTool(),  # type: ignore[call-arg]
+            AzureCogsSpeech2TextTool(),  # type: ignore[call-arg]
+            AzureCogsText2SpeechTool(),  # type: ignore[call-arg]
+            AzureCogsTextAnalyticsHealthTool(),  # type: ignore[call-arg]
         ]

         # TODO: Remove check once azure-ai-vision supports MacOS.
         if sys.platform.startswith("linux") or sys.platform.startswith("win"):
-            tools.append(AzureCogsImageAnalysisTool())
+            tools.append(AzureCogsImageAnalysisTool())  # type: ignore[call-arg]
         return tools
@@ -102,7 +102,7 @@ class ClickupToolkit(BaseToolkit):
             )
             for action in operations
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -45,6 +45,6 @@ class ConneryToolkit(BaseToolkit):
             ConneryToolkit: The Connery Toolkit.
         """

-        instance = cls(tools=connery_service.list_actions())
+        instance = cls(tools=connery_service.list_actions())  # type: ignore[arg-type]

         return instance
@@ -73,7 +73,7 @@ class FileManagementToolkit(BaseToolkit):
         tools: List[BaseTool] = []
         for tool in allowed_tools:
             tool_cls = _FILE_TOOLS_MAP[tool]
-            tools.append(tool_cls(root_dir=self.root_dir))
+            tools.append(tool_cls(root_dir=self.root_dir))  # type: ignore[call-arg]
         return tools

@@ -308,7 +308,7 @@ class GitHubToolkit(BaseToolkit):
             )
             for action in operations
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -88,7 +88,7 @@ class GitLabToolkit(BaseToolkit):
             )
             for action in operations
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -64,7 +64,7 @@ class JiraToolkit(BaseToolkit):
             )
             for action in operations
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -51,7 +51,7 @@ class NasaToolkit(BaseToolkit):
             )
             for action in operations
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -262,12 +262,12 @@ def _create_api_controller_agent(
     get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
     post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
     tools: List[BaseTool] = [
-        RequestsGetToolWithParsing(
+        RequestsGetToolWithParsing(  # type: ignore[call-arg]
             requests_wrapper=requests_wrapper,
             llm_chain=get_llm_chain,
             allow_dangerous_requests=allow_dangerous_requests,
         ),
-        RequestsPostToolWithParsing(
+        RequestsPostToolWithParsing(  # type: ignore[call-arg]
             requests_wrapper=requests_wrapper,
             llm_chain=post_llm_chain,
             allow_dangerous_requests=allow_dangerous_requests,
@@ -66,7 +66,7 @@ class PowerBIToolkit(BaseToolkit):
                 powerbi=self.powerbi,
                 examples=self.examples,
                 max_iterations=self.max_iterations,
-                output_token_limit=self.output_token_limit,
+                output_token_limit=self.output_token_limit,  # type: ignore[arg-type]
                 tiktoken_model_name=self.tiktoken_model_name,
             ),
             InfoPowerBITool(powerbi=self.powerbi),
@@ -136,7 +136,7 @@ def create_sql_agent(
             "Must provide exactly one of 'toolkit' or 'db'. Received both."
         )

-    toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
+    toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)  # type: ignore[arg-type]
     agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
     tools = toolkit.get_tools() + list(extra_tools)
     if prompt is None:
@@ -42,7 +42,7 @@ class SteamToolkit(BaseToolkit):
             )
             for action in operations
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -29,7 +29,7 @@ class ZapierToolkit(BaseToolkit):
             )
             for action in actions
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     @classmethod
     async def async_from_zapier_nla_wrapper(
@@ -46,7 +46,7 @@ class ZapierToolkit(BaseToolkit):
             )
             for action in actions
         ]
-        return cls(tools=tools)
+        return cls(tools=tools)  # type: ignore[arg-type]

     def get_tools(self) -> List[BaseTool]:
         """Get the tools in the toolkit."""
@@ -420,7 +420,7 @@ class _RedisCacheBase(BaseCache, ABC):
                 )
             # In a previous life we stored the raw text directly
             # in the table, so assume it's in that format.
-            generations.append(Generation(text=text))
+            generations.append(Generation(text=text))  # type: ignore[arg-type]
         return generations if generations else None

     @staticmethod
@@ -376,7 +376,7 @@ def create_ernie_fn_chain(
     output_key: str = "function",
     output_parser: Optional[BaseLLMOutputParser] = None,
     **kwargs: Any,
-) -> LLMChain:
+) -> LLMChain:  # type: ignore[valid-type]
     """[Legacy] Create an LLM chain that uses Ernie functions.

     Args:
@@ -453,7 +453,7 @@ def create_ernie_fn_chain(
     }
     if len(ernie_functions) == 1:
         llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
-    llm_chain = LLMChain(
+    llm_chain = LLMChain(  # type: ignore[misc]
         llm=llm,
         prompt=prompt,
        output_parser=output_parser,
@@ -472,7 +472,7 @@ def create_structured_output_chain(
     output_key: str = "function",
     output_parser: Optional[BaseLLMOutputParser] = None,
     **kwargs: Any,
-) -> LLMChain:
+) -> LLMChain:  # type: ignore[valid-type]
     """[Legacy] Create an LLMChain that uses an Ernie function to get a structured output.

     Args:
@@ -148,7 +148,7 @@ class IMessageChatLoader(BaseChatLoader):
                 continue

             results.append(
-                HumanMessage(
+                HumanMessage(  # type: ignore[call-arg]
                     role=sender,
                     content=content,
                     additional_kwargs={
@@ -51,7 +51,7 @@ class SlackChatLoader(BaseChatLoader):
                 )
             else:
                 results.append(
-                    HumanMessage(
+                    HumanMessage(  # type: ignore[call-arg]
                         role=sender,
                         content=text,
                         additional_kwargs={
@@ -77,7 +77,7 @@ def map_ai_messages_in_session(chat_sessions: ChatSession, sender: str) -> ChatS
             message = AIMessage(
                 content=message.content,
                 additional_kwargs=message.additional_kwargs.copy(),
-                example=getattr(message, "example", None),
+                example=getattr(message, "example", None),  # type: ignore[arg-type]
             )
             num_converted += 1
         messages.append(message)
@@ -73,7 +73,7 @@ class WhatsAppChatLoader(BaseChatLoader):
                 timestamp, sender, text = result.groups()
                 if not self._ignore_lines.match(text.strip()):
                     results.append(
-                        HumanMessage(
+                        HumanMessage(  # type: ignore[call-arg]
                             role=sender,
                             content=text,
                             additional_kwargs={
@@ -419,4 +419,4 @@ def _convert_delta_to_message_chunk(
     elif role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]
@@ -66,9 +66,9 @@ def _convert_delta_to_message_chunk(
     elif role == "assistant" or default_class == AIMessageChunk:
         return AIMessageChunk(content=content)
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 class ChatBaichuan(BaseChatModel):
@@ -383,7 +383,7 @@ class QianfanChatEndpoint(BaseChatModel):
             additional_kwargs = msg.additional_kwargs.get("function_call", {})
             chunk = ChatGenerationChunk(
                 text=res["result"],
-                message=AIMessageChunk(
+                message=AIMessageChunk(  # type: ignore[call-arg]
                     content=msg.content,
                     role="assistant",
                     additional_kwargs=additional_kwargs,
@@ -410,7 +410,7 @@ class QianfanChatEndpoint(BaseChatModel):
             additional_kwargs = msg.additional_kwargs.get("function_call", {})
             chunk = ChatGenerationChunk(
                 text=res["result"],
-                message=AIMessageChunk(
+                message=AIMessageChunk(  # type: ignore[call-arg]
                     content=msg.content,
                     role="assistant",
                     additional_kwargs=additional_kwargs,
@@ -552,7 +552,8 @@ class QianfanChatEndpoint(BaseChatModel):
         llm = self.bind_tools([schema])
         if is_pydantic_schema:
             output_parser: OutputParserLike = PydanticToolsParser(
-                tools=[schema], first_tool_only=True
+                tools=[schema],  # type: ignore[list-item]
+                first_tool_only=True,  # type: ignore[list-item]
             )
         else:
             key_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -69,7 +69,7 @@ def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChun
     elif role == "assistant":
         return AIMessageChunk(content=content)
     else:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]


 class ChatCoze(BaseChatModel):
@@ -118,9 +118,9 @@ def _convert_delta_to_message_chunk(
     elif role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 def _convert_message_to_dict(message: BaseMessage) -> dict:
@@ -58,7 +58,7 @@ def _convert_delta_to_message_chunk(
     elif role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 def convert_dict_to_message(_dict: Any) -> BaseMessage:
@@ -108,9 +108,9 @@ def _convert_delta_to_message_chunk(
     elif role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 class GigaChat(_BaseGigaChat, BaseChatModel):
@@ -72,9 +72,9 @@ def _convert_delta_to_message_chunk(
     elif role == "assistant" or default_class == AIMessageChunk:
         return AIMessageChunk(content=content)
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 # signature generation
@@ -103,9 +103,9 @@ def _convert_delta_to_message_chunk(
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
@@ -131,9 +131,9 @@ def _convert_delta_to_message_chunk(
     elif role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 def _convert_message_to_dict(message: BaseMessage) -> dict:
@@ -64,9 +64,9 @@ def _convert_delta_to_message_chunk(
     elif role == "assistant" or default_class == AIMessageChunk:
         return AIMessageChunk(content=content)
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 class LlamaEdgeChatService(BaseChatModel):
@@ -82,7 +82,7 @@ class MiniMaxChat(MinimaxCommon, BaseChatModel):

         # This is required since the stop are not enforced by the model parameters
         text = text if stop is None else enforce_stop_tokens(text, stop)
-        return ChatResult(generations=[ChatGeneration(message=AIMessage(text))])
+        return ChatResult(generations=[ChatGeneration(message=AIMessage(text))])  # type: ignore[misc]

     async def _agenerate(
         self,
@@ -139,9 +139,9 @@ def _convert_delta_to_message_chunk(
     elif role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 @deprecated(
@@ -198,9 +198,9 @@ class ChatPerplexity(BaseChatModel):
         elif role == "tool" or default_class == ToolMessageChunk:
             return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
         elif role or default_class == ChatMessageChunk:
-            return ChatMessageChunk(content=content, role=role)
+            return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
         else:
-            return default_class(content=content)
+            return default_class(content=content)  # type: ignore[call-arg]

     def _stream(
         self,
@@ -136,7 +136,7 @@ def _convert_delta_response_to_message_chunk(
     elif role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role), finish_reasons
     else:
-        return default_class(content=content), finish_reasons
+        return default_class(content=content), finish_reasons  # type: ignore[call-arg]


 def _messages_to_prompt_dict(
@@ -10,7 +10,7 @@ from langchain_community.chat_models import ChatOpenAI
 from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon


-@deprecated(
+@deprecated(  # type: ignore[arg-type]
     since="0.0.34", removal="0.3.0", alternative_import="langchain_upstage.ChatUpstage"
 )
 class SolarChat(SolarCommon, ChatOpenAI):
@@ -85,7 +85,7 @@ def _convert_delta_to_message_chunk(
     elif msg_role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=msg_content, role=msg_role)
     else:
-        return default_class(content=msg_content)
+        return default_class(content=msg_content)  # type: ignore[call-arg]


 class ChatSparkLLM(BaseChatModel):
@@ -382,10 +382,10 @@ class _SparkLLMClient:
             on_close=self.on_close,
             on_open=self.on_open,
         )
-        ws.messages = messages
-        ws.user_id = user_id
-        ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs
-        ws.streaming = streaming
+        ws.messages = messages  # type: ignore[attr-defined]
+        ws.user_id = user_id  # type: ignore[attr-defined]
+        ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs  # type: ignore[attr-defined]
+        ws.streaming = streaming  # type: ignore[attr-defined]
         ws.run_forever()

     def arun(
@@ -94,7 +94,7 @@ def convert_dict_to_message(
         else AIMessage(
             content=content,
            additional_kwargs=additional_kwargs,
-            tool_calls=tool_calls,
+            tool_calls=tool_calls,  # type: ignore[arg-type]
             invalid_tool_calls=invalid_tool_calls,
         )
     )
@@ -437,9 +437,9 @@ def _convert_delta_to_message_chunk(
     elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
     else:
-        return default_class(content=content)
+        return default_class(content=content)  # type: ignore[call-arg]


 def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
@@ -451,7 +451,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     elif role == "system":
         return SystemMessage(content=_dict.get("content", ""))
     else:
-        return ChatMessage(content=_dict.get("content", ""), role=role)
+        return ChatMessage(content=_dict.get("content", ""), role=role)  # type: ignore[arg-type]


 def _convert_message_to_dict(message: BaseMessage) -> dict:
@@ -101,7 +101,7 @@ def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
         if tool_calls is not None:
             additional_kwargs["tool_calls"] = tool_calls
         return AIMessage(content=content, additional_kwargs=additional_kwargs)
-    return ChatMessage(role=role, content=content)
+    return ChatMessage(role=role, content=content)  # type: ignore[arg-type]


 def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
@@ -144,8 +144,8 @@ def _convert_delta_to_message_chunk(
     if role == "assistant" or default_class == AIMessageChunk:
         return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
     if role or default_class == ChatMessageChunk:
-        return ChatMessageChunk(content=content, role=role)
-    return default_class(content=content)
+        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
+    return default_class(content=content)  # type: ignore[call-arg]


 def _truncate_params(payload: Dict[str, Any]) -> None:
@@ -70,7 +70,7 @@ def fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
 class O365BaseLoader(BaseLoader, BaseModel):
     """Base class for all loaders that uses O365 Package"""

-    settings: _O365Settings = Field(default_factory=_O365Settings)
+    settings: _O365Settings = Field(default_factory=_O365Settings)  # type: ignore[arg-type]
     """Settings for the Office365 API client."""
     auth_with_token: bool = False
     """Whether to authenticate with a token or not. Defaults to False."""
@@ -86,7 +86,7 @@ class KineticaLoader(BaseLoader):
         query_result = self._execute_query()
         if isinstance(query_result, Exception):
             print(f"An error occurred during the query: {query_result}")  # noqa: T201
-            return []
+            return []  # type: ignore[return-value]
         page_content_columns, metadata_columns = self._get_columns(query_result)
         if "*" in page_content_columns:
             page_content_columns = list(query_result[0].keys())
@@ -58,8 +58,8 @@ class MHTMLLoader(BaseLoader):
             parts = [message]

         for part in parts:
-            if part.get_content_type() == "text/html":
-                html = part.get_payload(decode=True).decode()
+            if part.get_content_type() == "text/html":  # type: ignore[union-attr]
+                html = part.get_payload(decode=True).decode()  # type: ignore[union-attr]

                 soup = BeautifulSoup(html, **self.bs_kwargs)
                 text = soup.get_text(self.get_text_separator)
@@ -31,7 +31,7 @@ class _OneNoteGraphSettings(BaseSettings):
 class OneNoteLoader(BaseLoader, BaseModel):
     """Load pages from OneNote notebooks."""

-    settings: _OneNoteGraphSettings = Field(default_factory=_OneNoteGraphSettings)
+    settings: _OneNoteGraphSettings = Field(default_factory=_OneNoteGraphSettings)  # type: ignore[arg-type]
     """Settings for the Microsoft Graph API client."""
     auth_with_token: bool = False
     """Whether to authenticate with a token or not. Defaults to False."""
@@ -691,7 +691,7 @@ class AmazonTextractPDFLoader(BasePDFLoader):
         # raises ValueError when multi-page and not on S3"""

         if self.web_path and self._is_s3_url(self.web_path):
-            blob = Blob(path=self.web_path)  # type: ignore[misc]
+            blob = Blob(path=self.web_path)  # type: ignore[call-arg] # type: ignore[misc]
         else:
             blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
         if AmazonTextractPDFLoader._get_number_of_pages(blob) > 1:
@@ -28,8 +28,8 @@ class PubMedLoader(BaseLoader):
         """
         self.query = query
         self.load_max_docs = load_max_docs
-        self._client = PubMedAPIWrapper(
-            top_k_results=load_max_docs,
+        self._client = PubMedAPIWrapper(  # type: ignore[call-arg]
+            top_k_results=load_max_docs,  # type: ignore[arg-type]
         )

     def lazy_load(self) -> Iterator[Document]:
@@ -111,7 +111,7 @@ class SnowflakeLoader(BaseLoader):
         query_result = self._execute_query()
         if isinstance(query_result, Exception):
             print(f"An error occurred during the query: {query_result}")  # noqa: T201
-            return []
+            return []  # type: ignore[return-value]
         page_content_columns, metadata_columns = self._get_columns(query_result)
         if "*" in page_content_columns:
             page_content_columns = list(query_result[0].keys())
@@ -66,10 +66,10 @@ class TensorflowDatasetLoader(BaseLoader):
         ] = sample_to_document_function
         """Custom function that transform a dataset sample into a Document."""

-        self._tfds_client = TensorflowDatasets(
+        self._tfds_client = TensorflowDatasets(  # type: ignore[call-arg]
             dataset_name=self.dataset_name,
             split_name=self.split_name,
-            load_max_docs=self.load_max_docs,
+            load_max_docs=self.load_max_docs,  # type: ignore[arg-type]
             sample_to_document_function=self.sample_to_document_function,
         )

@@ -32,7 +32,7 @@ class WeatherDataLoader(BaseLoader):
     def from_params(
         cls, places: Sequence[str], *, openweathermap_api_key: Optional[str] = None
     ) -> WeatherDataLoader:
-        client = OpenWeatherMapAPIWrapper(openweathermap_api_key=openweathermap_api_key)
+        client = OpenWeatherMapAPIWrapper(openweathermap_api_key=openweathermap_api_key)  # type: ignore[call-arg]
         return cls(client, places)

     def lazy_load(
@@ -50,10 +50,10 @@ class WikipediaLoader(BaseLoader):
             A list of Document objects representing the loaded
             Wikipedia pages.
         """
-        client = WikipediaAPIWrapper(
+        client = WikipediaAPIWrapper(  # type: ignore[call-arg]
             lang=self.lang,
-            top_k_results=self.load_max_docs,
-            load_all_available_meta=self.load_all_available_meta,
-            doc_content_chars_max=self.doc_content_chars_max,
+            top_k_results=self.load_max_docs,  # type: ignore[arg-type]
+            load_all_available_meta=self.load_all_available_meta,  # type: ignore[arg-type]
+            doc_content_chars_max=self.doc_content_chars_max,  # type: ignore[arg-type]
         )
         yield from client.load(self.query)
@@ -312,7 +312,7 @@ class SQLRecordManager(RecordManager):

             # Note: uses SQLite insert to make on_conflict_do_update work.
             # This code needs to be generalized a bit to work with more dialects.
-            insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
+            insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)  # type: ignore[assignment]
             stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                 "uix_key_namespace",  # Name of constraint
                 set_=dict(
@@ -387,7 +387,7 @@ class SQLRecordManager(RecordManager):

             # Note: uses SQLite insert to make on_conflict_do_update work.
             # This code needs to be generalized a bit to work with more dialects.
-            insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
+            insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)  # type: ignore[assignment]
             stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                 "uix_key_namespace",  # Name of constraint
                 set_=dict(
@@ -470,7 +470,7 @@ class SQLRecordManager(RecordManager):
             if limit:
                 query = query.limit(limit)  # type: ignore[attr-defined]
             records = query.all()  # type: ignore[attr-defined]
-            return [r.key for r in records]
+            return [r.key for r in records]  # type: ignore[misc]

     async def alist_keys(
         self,
@@ -282,6 +282,6 @@ class AlephAlpha(LLM):


 if __name__ == "__main__":
-    aa = AlephAlpha()
+    aa = AlephAlpha()  # type: ignore[call-arg]

     print(aa.invoke("How are you?"))  # noqa: T201
@@ -490,7 +490,7 @@ class Databricks(LLM):
                 task=self.task,
             )
         elif self.cluster_id and self.cluster_driver_port:
-            self._client = _DatabricksClusterDriverProxyClient(
+            self._client = _DatabricksClusterDriverProxyClient(  # type: ignore[call-arg]
                 host=self.host,
                 api_token=self.api_token,
                 cluster_id=self.cluster_id,
@@ -87,7 +87,7 @@ class MinimaxCommon(BaseModel):
             "MINIMAX_API_HOST",
             default="https://api.minimax.chat",
         )
-        values["_client"] = _MinimaxEndpointClient(
+        values["_client"] = _MinimaxEndpointClient(  # type: ignore[call-arg]
             host=values["minimax_api_host"],
             api_key=values["minimax_api_key"],
             group_id=values["minimax_group_id"],
@@ -423,7 +423,7 @@ class Ollama(BaseLLM, _OllamaCommon):
                 **kwargs,
             )
             generations.append([final_chunk])
-        return LLMResult(generations=generations)
+        return LLMResult(generations=generations)  # type: ignore[arg-type]

     async def _agenerate(  # type: ignore[override]
         self,
@@ -459,7 +459,7 @@ class Ollama(BaseLLM, _OllamaCommon):
                 **kwargs,
             )
             generations.append([final_chunk])
-        return LLMResult(generations=generations)
+        return LLMResult(generations=generations)  # type: ignore[arg-type]

     def _stream(
         self,
@@ -155,7 +155,7 @@ class OpenLLM(LLM):
             client = client_cls(server_url, timeout)

             super().__init__(
-                **{
+                **{  # type: ignore[arg-type]
                     "server_url": server_url,
                     "timeout": timeout,
                     "server_type": server_type,
@@ -180,7 +180,7 @@ class OpenLLM(LLM):
                 **llm_kwargs,
             )
             super().__init__(
-                **{
+                **{  # type: ignore[arg-type]
                     "model_name": model_name,
                     "model_id": model_id,
                     "embedded": embedded,
@@ -274,10 +274,10 @@ class _SparkLLMClient:
             on_close=self.on_close,
             on_open=self.on_open,
         )
-        ws.messages = messages
-        ws.user_id = user_id
-        ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs
-        ws.streaming = streaming
+        ws.messages = messages  # type: ignore[attr-defined]
+        ws.user_id = user_id  # type: ignore[attr-defined]
+        ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs  # type: ignore[attr-defined]
+        ws.streaming = streaming  # type: ignore[attr-defined]
         ws.run_forever()

     def arun(
@@ -330,13 +330,13 @@ class TextGen(LLM):
            result = websocket_client.recv()
            result = json.loads(result)

-           if result["event"] == "text_stream":
+           if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
-                   text=result["text"],
+                   text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                yield chunk
-           elif result["event"] == "stream_end":
+           elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return

@@ -403,13 +403,13 @@ class TextGen(LLM):
            result = websocket_client.recv()
            result = json.loads(result)

-           if result["event"] == "text_stream":
+           if result["event"] == "text_stream":  # type: ignore[call-overload, index]
                chunk = GenerationChunk(
-                   text=result["text"],
+                   text=result["text"],  # type: ignore[call-overload, index]
                    generation_info=None,
                )
                yield chunk
-           elif result["event"] == "stream_end":
+           elif result["event"] == "stream_end":  # type: ignore[call-overload, index]
                websocket_client.close()
                return

@@ -137,7 +137,7 @@ class TitanTakeoff(LLM):
             ImportError: If you haven't installed takeoff-client, you will
                 get an ImportError. To remedy run `pip install 'takeoff-client==0.4.0'`
         """
-        super().__init__(
+        super().__init__(  # type: ignore[call-arg]
             base_url=base_url, port=port, mgmt_port=mgmt_port, streaming=streaming
         )
         try:
@@ -363,7 +363,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
             generations.append(
                 [self._response_to_generation(r) for r in res.candidates]
             )
-        return LLMResult(generations=generations)
+        return LLMResult(generations=generations)  # type: ignore[arg-type]

     def _stream(
         self,
@@ -100,7 +100,7 @@ class Xinference(LLM):
         model_kwargs = model_kwargs or {}

         super().__init__(
-            **{
+            **{  # type: ignore[arg-type]
                 "server_url": server_url,
                 "model_uid": model_uid,
                 "model_kwargs": model_kwargs,
@@ -21,7 +21,7 @@ class BreebsRetriever(BaseRetriever):
     url = "https://breebs.promptbreeders.com/knowledge"

     def __init__(self, breeb_key: str):
-        super().__init__(breeb_key=breeb_key)
+        super().__init__(breeb_key=breeb_key)  # type: ignore[call-arg]
         self.breeb_key = breeb_key

     def _get_relevant_documents(
@@ -23,7 +23,7 @@ class DriaRetriever(BaseRetriever):
             contract_id: The contract ID of the knowledge base to interact with.
         """
         api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
-        super().__init__(api_wrapper=api_wrapper, **kwargs)
+        super().__init__(api_wrapper=api_wrapper, **kwargs)  # type: ignore[call-arg]

     def create_knowledge_base(
         self,
@@ -73,7 +73,7 @@ class NeuralDBRetriever(BaseRetriever):
         NeuralDBRetriever._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb

-        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB(**model_kwargs))
+        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB(**model_kwargs))  # type: ignore[arg-type]

     @classmethod
     def from_checkpoint(
@@ -108,7 +108,7 @@ class NeuralDBRetriever(BaseRetriever):
         NeuralDBRetriever._verify_thirdai_library(thirdai_key)
         from thirdai import neural_db as ndb

-        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB.from_checkpoint(checkpoint))
+        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB.from_checkpoint(checkpoint))  # type: ignore[arg-type]

     @root_validator()
     def validate_environments(cls, values: Dict) -> Dict:
@@ -27,7 +27,7 @@ class ArxivQueryRun(BaseTool):
         "from scientific articles on arxiv.org. "
         "Input should be a search query."
     )
-    api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
+    api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)  # type: ignore[arg-type]
     args_schema: Type[BaseModel] = ArxivInput

     def _run(
@@ -72,7 +72,7 @@ class HuggingFaceTextToSpeechModelInference(BaseTool):
                 f"Invalid value for 'file_naming_func': {file_naming_func}"
             )

-        super().__init__(
+        super().__init__(  # type: ignore[call-arg]
             model=model,
             file_extension=file_extension,
             api_url=f"{self._HUGGINGFACE_API_URL_ROOT}/{model}",
@@ -19,7 +19,7 @@ from langchain_community.utilities.github import GitHubAPIWrapper
 class GitHubAction(BaseTool):
     """Tool for interacting with the GitHub API."""

-    api_wrapper: GitHubAPIWrapper = Field(default_factory=GitHubAPIWrapper)
+    api_wrapper: GitHubAPIWrapper = Field(default_factory=GitHubAPIWrapper)  # type: ignore[arg-type]
     mode: str
     name: str = ""
     description: str = ""
@@ -19,7 +19,7 @@ from langchain_community.utilities.gitlab import GitLabAPIWrapper
 class GitLabAction(BaseTool):
     """Tool for interacting with the GitLab API."""

-    api_wrapper: GitLabAPIWrapper = Field(default_factory=GitLabAPIWrapper)
+    api_wrapper: GitLabAPIWrapper = Field(default_factory=GitLabAPIWrapper)  # type: ignore[arg-type]
     mode: str
     name: str = ""
     description: str = ""
@@ -34,4 +34,4 @@ class GmailBaseTool(BaseTool):
         Returns:
             A tool.
         """
-        return cls(service=api_resource)
+        return cls(service=api_resource)  # type: ignore[call-arg]
@@ -53,10 +53,10 @@ class GmailGetMessage(GmailBaseTool):
                 ctype = part.get_content_type()
                 cdispo = str(part.get("Content-Disposition"))
                 if ctype == "text/plain" and "attachment" not in cdispo:
-                    message_body = part.get_payload(decode=True).decode("utf-8")
+                    message_body = part.get_payload(decode=True).decode("utf-8")  # type: ignore[union-attr]
                     break
         else:
-            message_body = email_msg.get_payload(decode=True).decode("utf-8")
+            message_body = email_msg.get_payload(decode=True).decode("utf-8")  # type: ignore[union-attr]

         body = clean_email_body(message_body)
@@ -99,14 +99,14 @@ class GmailSearch(GmailBaseTool):
                 cdispo = str(part.get("Content-Disposition"))
                 if ctype == "text/plain" and "attachment" not in cdispo:
                     try:
-                        message_body = part.get_payload(decode=True).decode("utf-8")
+                        message_body = part.get_payload(decode=True).decode("utf-8")  # type: ignore[union-attr]
                     except UnicodeDecodeError:
-                        message_body = part.get_payload(decode=True).decode(
+                        message_body = part.get_payload(decode=True).decode(  # type: ignore[union-attr]
                             "latin-1"
                         )
                     break
         else:
-            message_body = email_msg.get_payload(decode=True).decode("utf-8")
+            message_body = email_msg.get_payload(decode=True).decode("utf-8")  # type: ignore[union-attr]

         body = clean_email_body(message_body)

@@ -31,7 +31,7 @@ class GooglePlacesTool(BaseTool):
         "discover addressed from ambiguous text. "
         "Input should be a search query."
     )
-    api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper)
+    api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper)  # type: ignore[arg-type]
     args_schema: Type[BaseModel] = GooglePlacesSchema

     def _run(
@@ -30,7 +30,7 @@ from langchain_community.utilities.jira import JiraAPIWrapper
 class JiraAction(BaseTool):
     """Tool that queries the Atlassian Jira API."""

-    api_wrapper: JiraAPIWrapper = Field(default_factory=JiraAPIWrapper)
+    api_wrapper: JiraAPIWrapper = Field(default_factory=JiraAPIWrapper)  # type: ignore[arg-type]
     mode: str
     name: str = ""
     description: str = ""
@@ -75,7 +75,7 @@ class NucliaUnderstandingAPI(BaseTool):
         else:
             self._config["NUA_KEY"] = key
         self._config["enable_ml"] = enable_ml
-        super().__init__()
+        super().__init__()  # type: ignore[call-arg]

     def _run(
         self,
@@ -530,7 +530,7 @@ class APIOperation(BaseModel):
             description=description or "",
             base_url=spec.base_url,
             path=path,
-            method=method,
+            method=method,  # type: ignore[arg-type]
             properties=properties,
             request_body=api_request_body,
         )
@@ -13,7 +13,7 @@ class OpenWeatherMapQueryRun(BaseTool):
     """Tool that queries the OpenWeatherMap API."""

     api_wrapper: OpenWeatherMapAPIWrapper = Field(
-        default_factory=OpenWeatherMapAPIWrapper
+        default_factory=OpenWeatherMapAPIWrapper  # type: ignore[arg-type]
     )

     name: str = "open_weather_map"
@@ -54,4 +54,4 @@ class BaseBrowserTool(BaseTool):
     ) -> BaseBrowserTool:
         """Instantiate the tool."""
         lazy_import_playwright_browsers()
-        return cls(sync_browser=sync_browser, async_browser=async_browser)
+        return cls(sync_browser=sync_browser, async_browser=async_browser)  # type: ignore[call-arg]
@@ -18,7 +18,7 @@ class PubmedQueryRun(BaseTool):
         "from biomedical literature, MEDLINE, life science journals, and online books. "
         "Input should be a search query."
     )
-    api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
+    api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)  # type: ignore[arg-type]

     def _run(
         self,
@@ -42,7 +42,7 @@ class RedditSearchRun(BaseTool):
         "A tool that searches for posts on Reddit."
         "Useful when you need to know post information on a subreddit."
     )
-    api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper)
+    api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper)  # type: ignore[arg-type]
     args_schema: Type[BaseModel] = RedditSearchSchema

     def _run(
@@ -23,7 +23,7 @@ class SceneXplainTool(BaseTool):
         "for an image. The input can be an image file of any format, and "
         "the output will be a text description that covers every detail of the image."
     )
-    api_wrapper: SceneXplainAPIWrapper = Field(default_factory=SceneXplainAPIWrapper)
+    api_wrapper: SceneXplainAPIWrapper = Field(default_factory=SceneXplainAPIWrapper)  # type: ignore[arg-type]

     def _run(
         self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
@@ -26,7 +26,7 @@ class SemanticScholarQueryRun(BaseTool):
         "Input should be a search query."
     )
     api_wrapper: SemanticScholarAPIWrapper = Field(
-        default_factory=SemanticScholarAPIWrapper
+        default_factory=SemanticScholarAPIWrapper  # type: ignore[arg-type]
     )
     args_schema: Type[BaseModel] = SemantscholarInput

@@ -97,7 +97,7 @@ class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
         from langchain.chains.llm import LLMChain

         values["llm_chain"] = LLMChain(
-            llm=values.get("llm"),
+            llm=values.get("llm"),  # type: ignore[arg-type]
             prompt=PromptTemplate(
                 template=QUERY_CHECKER, input_variables=["query"]
             ),
@@ -122,7 +122,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
         from langchain.chains.llm import LLMChain

         values["llm_chain"] = LLMChain(
-            llm=values.get("llm"),
+            llm=values.get("llm"),  # type: ignore[arg-type]
             prompt=PromptTemplate(
                 template=QUERY_CHECKER, input_variables=["dialect", "query"]
             ),
@@ -27,7 +27,7 @@ class TavilySearchResults(BaseTool):
         "Useful for when you need to answer questions about current events. "
         "Input should be a search query."
     )
-    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)
+    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
     max_results: int = 5
     args_schema: Type[BaseModel] = TavilyInput

@@ -70,7 +70,7 @@ class TavilyAnswer(BaseTool):
         "Input should be a search query. "
         "This returns only the answer - not the original source data."
     )
-    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)
+    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
     args_schema: Type[BaseModel] = TavilyInput

     def _run(
@@ -95,7 +95,7 @@ class ZapierNLARunAction(BaseTool):

     """

-    api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
+    api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)  # type: ignore[arg-type]
     action_id: str
     params: Optional[dict] = None
     base_prompt: str = BASE_ZAPIER_TOOL_PROMPT
@@ -174,7 +174,7 @@ class ZapierNLAListActions(BaseTool):
     description: str = BASE_ZAPIER_TOOL_PROMPT + (
         "This tool returns a list of the user's exposed actions."
     )
-    api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
+    api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)  # type: ignore[arg-type]

     def _run(
         self,
@@ -214,4 +214,4 @@ def _check_for_cluster(redis_client: RedisType) -> bool:
 def _redis_cluster_client(redis_url: str, **kwargs: Any) -> RedisType:
     from redis.cluster import RedisCluster

-    return RedisCluster.from_url(redis_url, **kwargs)
+    return RedisCluster.from_url(redis_url, **kwargs)  # type: ignore[return-value]
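A recurring theme in the hunks above: most new ignores land on pydantic models whose required-looking fields are actually populated at runtime by validators or environment variables. The sketch below uses hypothetical names (`SomeAPIWrapper`, `SomeTool`) and assumes pydantic v1 with its mypy plugin enabled; it is illustrative of the pattern, not code from this repo.

```python
# Hypothetical names; a sketch of the pattern behind many ignores above,
# assuming pydantic v1 and its mypy plugin.
from pydantic import BaseModel, Field, root_validator  # pydantic v1 API


class SomeAPIWrapper(BaseModel):
    api_key: str  # required in the signature mypy sees...

    @root_validator(pre=True)
    def _resolve_key(cls, values: dict) -> dict:
        # ...but actually resolved at runtime (e.g. from an env var).
        values.setdefault("api_key", "from-env")
        return values


# Direct no-arg construction works at runtime, but mypy reports:
#   Missing named argument "api_key" for "SomeAPIWrapper"  [call-arg]
wrapper = SomeAPIWrapper()  # type: ignore[call-arg]


class SomeTool(BaseModel):
    # As a default_factory, SomeAPIWrapper is not a zero-argument
    # callable from mypy's point of view, hence [arg-type]:
    api_wrapper: SomeAPIWrapper = Field(default_factory=SomeAPIWrapper)  # type: ignore[arg-type]
```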
libs/community/poetry.lock (generated, 78 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

 [[package]]
 name = "aenum"
@@ -3455,6 +3455,7 @@ files = [
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
+    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -4740,52 +4741,49 @@ para = ">=0.0.1"

 [[package]]
 name = "mypy"
-version = "0.991"
+version = "1.10.0"
 description = "Optional static typing for Python"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "mypy-0.991-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7d17e0a9707d0772f4a7b878f04b4fd11f6f5bcb9b3813975a9b13c9332153ab"},
-    {file = "mypy-0.991-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0714258640194d75677e86c786e80ccf294972cc76885d3ebbb560f11db0003d"},
-    {file = "mypy-0.991-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c8f3be99e8a8bd403caa8c03be619544bc2c77a7093685dcf308c6b109426c6"},
-    {file = "mypy-0.991-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9ec663ed6c8f15f4ae9d3c04c989b744436c16d26580eaa760ae9dd5d662eb"},
-    {file = "mypy-0.991-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4307270436fd7694b41f913eb09210faff27ea4979ecbcd849e57d2da2f65305"},
-    {file = "mypy-0.991-cp310-cp310-win_amd64.whl", hash = "sha256:901c2c269c616e6cb0998b33d4adbb4a6af0ac4ce5cd078afd7bc95830e62c1c"},
-    {file = "mypy-0.991-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d13674f3fb73805ba0c45eb6c0c3053d218aa1f7abead6e446d474529aafc372"},
-    {file = "mypy-0.991-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1c8cd4fb70e8584ca1ed5805cbc7c017a3d1a29fb450621089ffed3e99d1857f"},
-    {file = "mypy-0.991-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:209ee89fbb0deed518605edddd234af80506aec932ad28d73c08f1400ef80a33"},
-    {file = "mypy-0.991-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37bd02ebf9d10e05b00d71302d2c2e6ca333e6c2a8584a98c00e038db8121f05"},
-    {file = "mypy-0.991-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:26efb2fcc6b67e4d5a55561f39176821d2adf88f2745ddc72751b7890f3194ad"},
-    {file = "mypy-0.991-cp311-cp311-win_amd64.whl", hash = "sha256:3a700330b567114b673cf8ee7388e949f843b356a73b5ab22dd7cff4742a5297"},
-    {file = "mypy-0.991-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1f7d1a520373e2272b10796c3ff721ea1a0712288cafaa95931e66aa15798813"},
-    {file = "mypy-0.991-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:641411733b127c3e0dab94c45af15fea99e4468f99ac88b39efb1ad677da5711"},
-    {file = "mypy-0.991-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3d80e36b7d7a9259b740be6d8d906221789b0d836201af4234093cae89ced0cd"},
-    {file = "mypy-0.991-cp37-cp37m-win_amd64.whl", hash = "sha256:e62ebaad93be3ad1a828a11e90f0e76f15449371ffeecca4a0a0b9adc99abcef"},
-    {file = "mypy-0.991-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b86ce2c1866a748c0f6faca5232059f881cda6dda2a893b9a8373353cfe3715a"},
-    {file = "mypy-0.991-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac6e503823143464538efda0e8e356d871557ef60ccd38f8824a4257acc18d93"},
-    {file = "mypy-0.991-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cca5adf694af539aeaa6ac633a7afe9bbd760df9d31be55ab780b77ab5ae8bf"},
-    {file = "mypy-0.991-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12c56bf73cdab116df96e4ff39610b92a348cc99a1307e1da3c3768bbb5b135"},
-    {file = "mypy-0.991-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:652b651d42f155033a1967739788c436491b577b6a44e4c39fb340d0ee7f0d70"},
-    {file = "mypy-0.991-cp38-cp38-win_amd64.whl", hash = "sha256:4175593dc25d9da12f7de8de873a33f9b2b8bdb4e827a7cae952e5b1a342e243"},
-    {file = "mypy-0.991-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:98e781cd35c0acf33eb0295e8b9c55cdbef64fcb35f6d3aa2186f289bed6e80d"},
-    {file = "mypy-0.991-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6d7464bac72a85cb3491c7e92b5b62f3dcccb8af26826257760a552a5e244aa5"},
-    {file = "mypy-0.991-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c9166b3f81a10cdf9b49f2d594b21b31adadb3d5e9db9b834866c3258b695be3"},
-    {file = "mypy-0.991-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8472f736a5bfb159a5e36740847808f6f5b659960115ff29c7cecec1741c648"},
-    {file = "mypy-0.991-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e80e758243b97b618cdf22004beb09e8a2de1af481382e4d84bc52152d1c476"},
-    {file = "mypy-0.991-cp39-cp39-win_amd64.whl", hash = "sha256:74e259b5c19f70d35fcc1ad3d56499065c601dfe94ff67ae48b85596b9ec1461"},
-    {file = "mypy-0.991-py3-none-any.whl", hash = "sha256:de32edc9b0a7e67c2775e574cb061a537660e51210fbf6006b0b36ea695ae9bb"},
-    {file = "mypy-0.991.tar.gz", hash = "sha256:3c0165ba8f354a6d9881809ef29f1a9318a236a6d81c690094c5df32107bde06"},
+    {file = "mypy-1.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:da1cbf08fb3b851ab3b9523a884c232774008267b1f83371ace57f412fe308c2"},
+    {file = "mypy-1.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:12b6bfc1b1a66095ab413160a6e520e1dc076a28f3e22f7fb25ba3b000b4ef99"},
+    {file = "mypy-1.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e36fb078cce9904c7989b9693e41cb9711e0600139ce3970c6ef814b6ebc2b2"},
+    {file = "mypy-1.10.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2b0695d605ddcd3eb2f736cd8b4e388288c21e7de85001e9f85df9187f2b50f9"},
+    {file = "mypy-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:cd777b780312ddb135bceb9bc8722a73ec95e042f911cc279e2ec3c667076051"},
+    {file = "mypy-1.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3be66771aa5c97602f382230165b856c231d1277c511c9a8dd058be4784472e1"},
+    {file = "mypy-1.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8b2cbaca148d0754a54d44121b5825ae71868c7592a53b7292eeb0f3fdae95ee"},
+    {file = "mypy-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ec404a7cbe9fc0e92cb0e67f55ce0c025014e26d33e54d9e506a0f2d07fe5de"},
+    {file = "mypy-1.10.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e22e1527dc3d4aa94311d246b59e47f6455b8729f4968765ac1eacf9a4760bc7"},
+    {file = "mypy-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:a87dbfa85971e8d59c9cc1fcf534efe664d8949e4c0b6b44e8ca548e746a8d53"},
+    {file = "mypy-1.10.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a781f6ad4bab20eef8b65174a57e5203f4be627b46291f4589879bf4e257b97b"},
+    {file = "mypy-1.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b808e12113505b97d9023b0b5e0c0705a90571c6feefc6f215c1df9381256e30"},
+    {file = "mypy-1.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f55583b12156c399dce2df7d16f8a5095291354f1e839c252ec6c0611e86e2e"},
+    {file = "mypy-1.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4cf18f9d0efa1b16478c4c129eabec36148032575391095f73cae2e722fcf9d5"},
+    {file = "mypy-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:bc6ac273b23c6b82da3bb25f4136c4fd42665f17f2cd850771cb600bdd2ebeda"},
+    {file = "mypy-1.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9fd50226364cd2737351c79807775136b0abe084433b55b2e29181a4c3c878c0"},
+    {file = "mypy-1.10.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f90cff89eea89273727d8783fef5d4a934be2fdca11b47def50cf5d311aff727"},
+    {file = "mypy-1.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcfc70599efde5c67862a07a1aaf50e55bce629ace26bb19dc17cece5dd31ca4"},
+    {file = "mypy-1.10.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:075cbf81f3e134eadaf247de187bd604748171d6b79736fa9b6c9685b4083061"},
+    {file = "mypy-1.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:3f298531bca95ff615b6e9f2fc0333aae27fa48052903a0ac90215021cdcfa4f"},
+    {file = "mypy-1.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fa7ef5244615a2523b56c034becde4e9e3f9b034854c93639adb667ec9ec2976"},
+    {file = "mypy-1.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3236a4c8f535a0631f85f5fcdffba71c7feeef76a6002fcba7c1a8e57c8be1ec"},
+    {file = "mypy-1.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a2b5cdbb5dd35aa08ea9114436e0d79aceb2f38e32c21684dcf8e24e1e92821"},
+    {file = "mypy-1.10.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92f93b21c0fe73dc00abf91022234c79d793318b8a96faac147cd579c1671746"},
+    {file = "mypy-1.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:28d0e038361b45f099cc086d9dd99c15ff14d0188f44ac883010e172ce86c38a"},
+    {file = "mypy-1.10.0-py3-none-any.whl", hash = "sha256:f8c083976eb530019175aabadb60921e73b4f45736760826aa1689dda8208aee"},
+    {file = "mypy-1.10.0.tar.gz", hash = "sha256:3d087fcbec056c4ee34974da493a826ce316947485cef3901f511848e687c131"},
 ]

 [package.dependencies]
-mypy-extensions = ">=0.4.3"
+mypy-extensions = ">=1.0.0"
 tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
-typing-extensions = ">=3.10"
+typing-extensions = ">=4.1.0"

 [package.extras]
 dmypy = ["psutil (>=4.0)"]
 install-types = ["pip"]
-python2 = ["typed-ast (>=1.4.0,<2)"]
+mypyc = ["setuptools (>=50)"]
 reports = ["lxml"]

 [[package]]
@ -6107,8 +6105,6 @@ files = [
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@ -6151,7 +6147,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@ -6160,8 +6155,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@ -7159,7 +7152,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -10092,4 +10084,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "657797396fff40d21216962c43171ba5605725e1047339443d7f346e2f5f0b80"
content-hash = "66b00eea10e05312fcafa5f68e4b863942c344051bdd93b575b0b26fce9fce21"

@ -174,7 +174,7 @@ optional = true
ruff = "^0.1.5"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
mypy = "^1"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"
types-toml = "^0.10.8.1"

@ -15,7 +15,7 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.scheduled
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
chat = ChatAnthropic(model="test")
chat = ChatAnthropic(model="test") # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -25,7 +25,7 @@ def test_anthropic_call() -> None:
@pytest.mark.scheduled
def test_anthropic_generate() -> None:
"""Test generate method of anthropic."""
chat = ChatAnthropic(model="test")
chat = ChatAnthropic(model="test") # type: ignore[call-arg]
chat_messages: List[List[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")]
]
@ -42,7 +42,7 @@ def test_anthropic_generate() -> None:
@pytest.mark.scheduled
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
chat = ChatAnthropic(model="test", streaming=True)
chat = ChatAnthropic(model="test", streaming=True) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -54,7 +54,7 @@ def test_anthropic_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
chat = ChatAnthropic( # type: ignore[call-arg]
model="test",
streaming=True,
callback_manager=callback_manager,
@ -70,7 +70,7 @@ async def test_anthropic_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
chat = ChatAnthropic( # type: ignore[call-arg]
model="test",
streaming=True,
callback_manager=callback_manager,

@ -20,7 +20,7 @@ DEPLOYMENT_NAME = os.environ.get(


def _get_llm(**kwargs: Any) -> AzureChatOpenAI:
return AzureChatOpenAI(
return AzureChatOpenAI( # type: ignore[call-arg]
deployment_name=DEPLOYMENT_NAME,
openai_api_version=OPENAI_API_VERSION,
azure_endpoint=OPENAI_API_BASE,

@ -23,7 +23,7 @@ def test_chat_baichuan_default_non_streaming() -> None:


def test_chat_baichuan_turbo() -> None:
chat = ChatBaichuan(model="Baichuan2-Turbo", streaming=True)
chat = ChatBaichuan(model="Baichuan2-Turbo", streaming=True) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -31,7 +31,7 @@ def test_chat_baichuan_turbo() -> None:


def test_chat_baichuan_turbo_non_streaming() -> None:
chat = ChatBaichuan(model="Baichuan2-Turbo")
chat = ChatBaichuan(model="Baichuan2-Turbo") # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)

@ -17,7 +17,7 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

@pytest.fixture
def chat() -> BedrockChat:
return BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0})
return BedrockChat(model_id="anthropic.claude-v2", model_kwargs={"temperature": 0}) # type: ignore[call-arg]


@pytest.mark.scheduled
@ -63,7 +63,7 @@ def test_chat_bedrock_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = BedrockChat(
chat = BedrockChat( # type: ignore[call-arg]
model_id="anthropic.claude-v2",
streaming=True,
callback_manager=callback_manager,
@ -92,7 +92,7 @@ def test_chat_bedrock_streaming_generation_info() -> None:

callback = _FakeCallback()
callback_manager = CallbackManager([callback])
chat = BedrockChat(
chat = BedrockChat( # type: ignore[call-arg]
model_id="anthropic.claude-v2",
callback_manager=callback_manager,
)

@ -9,7 +9,7 @@ from langchain_community.chat_models.coze import ChatCoze
def test_chat_coze_default() -> None:
chat = ChatCoze(
coze_api_base="https://api.coze.com",
coze_api_key="pat_...",
coze_api_key="pat_...", # type: ignore[arg-type]
bot_id="7....",
user="123",
conversation_id="",
@ -24,7 +24,7 @@ def test_chat_coze_default() -> None:
def test_chat_coze_default_non_streaming() -> None:
chat = ChatCoze(
coze_api_base="https://api.coze.com",
coze_api_key="pat_...",
coze_api_key="pat_...", # type: ignore[arg-type]
bot_id="7....",
user="123",
conversation_id="",

@ -12,7 +12,7 @@ from langchain_community.chat_models.dappier import (
@pytest.mark.scheduled
def test_dappier_chat() -> None:
"""Test ChatDappierAI wrapper."""
chat = ChatDappierAI(
chat = ChatDappierAI( # type: ignore[call-arg]
dappier_endpoint="https://api.dappier.com/app/datamodelconversation",
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt",
)
@ -25,7 +25,7 @@ def test_dappier_chat() -> None:
@pytest.mark.scheduled
def test_dappier_generate() -> None:
"""Test generate method of Dappier AI."""
chat = ChatDappierAI(
chat = ChatDappierAI( # type: ignore[call-arg]
dappier_endpoint="https://api.dappier.com/app/datamodelconversation",
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt",
)
@ -45,7 +45,7 @@ def test_dappier_generate() -> None:
@pytest.mark.scheduled
async def test_dappier_agenerate() -> None:
"""Test async generation."""
chat = ChatDappierAI(
chat = ChatDappierAI( # type: ignore[call-arg]
dappier_endpoint="https://api.dappier.com/app/datamodelconversation",
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt",
)

@ -13,7 +13,7 @@ from langchain_community.chat_models.edenai import (
@pytest.mark.scheduled
def test_chat_edenai() -> None:
"""Test ChatEdenAI wrapper."""
chat = ChatEdenAI(
chat = ChatEdenAI( # type: ignore[call-arg]
provider="openai", model="gpt-3.5-turbo", temperature=0, max_tokens=1000
)
message = HumanMessage(content="Who are you ?")
@ -25,7 +25,7 @@ def test_chat_edenai() -> None:
@pytest.mark.scheduled
def test_edenai_generate() -> None:
"""Test generate method of edenai."""
chat = ChatEdenAI(provider="google")
chat = ChatEdenAI(provider="google") # type: ignore[call-arg]
chat_messages: List[List[BaseMessage]] = [
[HumanMessage(content="What is the meaning of life?")]
]
@ -42,7 +42,7 @@ def test_edenai_generate() -> None:
@pytest.mark.scheduled
async def test_edenai_async_generate() -> None:
"""Test async generation."""
chat = ChatEdenAI(provider="google", max_tokens=50)
chat = ChatEdenAI(provider="google", max_tokens=50) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
result: LLMResult = await chat.agenerate([[message], [message]])
assert isinstance(result, LLMResult)
@ -55,7 +55,7 @@ async def test_edenai_async_generate() -> None:
@pytest.mark.scheduled
def test_edenai_streaming() -> None:
"""Test streaming EdenAI chat."""
llm = ChatEdenAI(provider="openai", max_tokens=50)
llm = ChatEdenAI(provider="openai", max_tokens=50) # type: ignore[call-arg]

for chunk in llm.stream("Generate a high fantasy story."):
assert isinstance(chunk.content, str)
@ -64,7 +64,7 @@ def test_edenai_streaming() -> None:
@pytest.mark.scheduled
async def test_edenai_astream() -> None:
"""Test streaming from EdenAI."""
llm = ChatEdenAI(provider="openai", max_tokens=50)
llm = ChatEdenAI(provider="openai", max_tokens=50) # type: ignore[call-arg]

async for token in llm.astream("Generate a high fantasy story."):
assert isinstance(token.content, str)

@ -12,7 +12,7 @@ from langchain_community.chat_models import ChatGooglePalm

def test_chat_google_palm() -> None:
"""Test Google PaLM Chat API wrapper."""
chat = ChatGooglePalm()
chat = ChatGooglePalm() # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, BaseMessage)
@ -21,7 +21,7 @@ def test_chat_google_palm() -> None:

def test_chat_google_palm_system_message() -> None:
"""Test Google PaLM Chat API wrapper with system message."""
chat = ChatGooglePalm()
chat = ChatGooglePalm() # type: ignore[call-arg]
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat.invoke([system_message, human_message])
@ -31,7 +31,7 @@ def test_chat_google_palm_system_message() -> None:

def test_chat_google_palm_generate() -> None:
"""Test Google PaLM Chat API wrapper with generate."""
chat = ChatGooglePalm(n=2, temperature=1.0)
chat = ChatGooglePalm(n=2, temperature=1.0) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
@ -48,7 +48,7 @@ def test_chat_google_palm_multiple_completions() -> None:
"""Test Google PaLM Chat API wrapper with multiple completions."""
# The API de-dupes duplicate responses, so set temperature higher. This
# could be a flakey test though...
chat = ChatGooglePalm(n=5, temperature=1.0)
chat = ChatGooglePalm(n=5, temperature=1.0) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat._generate([message])
assert isinstance(response, ChatResult)
@ -60,7 +60,7 @@ def test_chat_google_palm_multiple_completions() -> None:

async def test_async_chat_google_palm() -> None:
"""Test async generation."""
chat = ChatGooglePalm(n=2, temperature=1.0)
chat = ChatGooglePalm(n=2, temperature=1.0) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)

@ -16,9 +16,9 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_api_key_is_string() -> None:
gpt_router = GPTRouter(
gpt_router = GPTRouter( # type: ignore[call-arg]
gpt_router_api_base="https://example.com",
gpt_router_api_key="secret-api-key",
gpt_router_api_key="secret-api-key", # type: ignore[arg-type]
)
assert isinstance(gpt_router.gpt_router_api_key, SecretStr)

@ -26,9 +26,9 @@ def test_api_key_is_string() -> None:
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
gpt_router = GPTRouter(
gpt_router = GPTRouter( # type: ignore[call-arg]
gpt_router_api_base="https://example.com",
gpt_router_api_key="secret-api-key",
gpt_router_api_key="secret-api-key", # type: ignore[arg-type]
)
print(gpt_router.gpt_router_api_key, end="") # noqa: T201
captured = capsys.readouterr()

@ -14,7 +14,7 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_jinachat_api_key_is_secret_string() -> None:
llm = JinaChat(jinachat_api_key="secret-api-key")
llm = JinaChat(jinachat_api_key="secret-api-key") # type: ignore[arg-type, call-arg]
assert isinstance(llm.jinachat_api_key, SecretStr)


@ -23,7 +23,7 @@ def test_jinachat_api_key_masked_when_passed_from_env(
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("JINACHAT_API_KEY", "secret-api-key")
llm = JinaChat()
llm = JinaChat() # type: ignore[call-arg]
print(llm.jinachat_api_key, end="") # noqa: T201
captured = capsys.readouterr()

@ -34,7 +34,7 @@ def test_jinachat_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = JinaChat(jinachat_api_key="secret-api-key")
llm = JinaChat(jinachat_api_key="secret-api-key") # type: ignore[arg-type, call-arg]
print(llm.jinachat_api_key, end="") # noqa: T201
captured = capsys.readouterr()

@ -43,13 +43,13 @@ def test_jinachat_api_key_masked_when_passed_via_constructor(

def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = JinaChat(jinachat_api_key="secret-api-key")
llm = JinaChat(jinachat_api_key="secret-api-key") # type: ignore[arg-type, call-arg]
assert cast(SecretStr, llm.jinachat_api_key).get_secret_value() == "secret-api-key"


def test_jinachat() -> None:
"""Test JinaChat wrapper."""
chat = JinaChat(max_tokens=10)
chat = JinaChat(max_tokens=10) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, BaseMessage)
@ -58,7 +58,7 @@ def test_jinachat() -> None:

def test_jinachat_system_message() -> None:
"""Test JinaChat wrapper with system message."""
chat = JinaChat(max_tokens=10)
chat = JinaChat(max_tokens=10) # type: ignore[call-arg]
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat.invoke([system_message, human_message])
@ -68,7 +68,7 @@ def test_jinachat_system_message() -> None:

def test_jinachat_generate() -> None:
"""Test JinaChat wrapper with generate."""
chat = JinaChat(max_tokens=10)
chat = JinaChat(max_tokens=10) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
@ -85,7 +85,7 @@ def test_jinachat_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = JinaChat(
chat = JinaChat( # type: ignore[call-arg]
max_tokens=10,
streaming=True,
temperature=0,
@ -100,7 +100,7 @@ def test_jinachat_streaming() -> None:

async def test_async_jinachat() -> None:
"""Test async generation."""
chat = JinaChat(max_tokens=102)
chat = JinaChat(max_tokens=102) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
@ -117,7 +117,7 @@ async def test_async_jinachat_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = JinaChat(
chat = JinaChat( # type: ignore[call-arg]
max_tokens=10,
streaming=True,
temperature=0,
@ -140,18 +140,18 @@ async def test_async_jinachat_streaming() -> None:
def test_jinachat_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.
llm = JinaChat(foo=3, max_tokens=10)
llm = JinaChat(foo=3, max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}

# Test that if extra_kwargs are provided, they are added to it.
llm = JinaChat(foo=3, model_kwargs={"bar": 2})
llm = JinaChat(foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": 3, "bar": 2}

# Test that if provided twice it errors
with pytest.raises(ValueError):
JinaChat(foo=3, model_kwargs={"foo": 2})
JinaChat(foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]

# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
JinaChat(model_kwargs={"temperature": 0.2})
JinaChat(model_kwargs={"temperature": 0.2}) # type: ignore[call-arg]

@ -74,7 +74,7 @@ class TestChatKinetica:
"""Create an LLM instance."""
import gpudb

kinetica_llm = ChatKinetica()
kinetica_llm = ChatKinetica() # type: ignore[call-arg]
LOG.info(kinetica_llm._identifying_params)

assert isinstance(kinetica_llm.kdbc, gpudb.GPUdb)
@ -83,7 +83,7 @@ class TestChatKinetica:
@pytest.mark.vcr()
def test_load_context(self) -> None:
"""Load the LLM context from the DB."""
kinetica_llm = ChatKinetica()
kinetica_llm = ChatKinetica() # type: ignore[call-arg]
ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)

system_message = ctx_messages[0]
@ -96,7 +96,7 @@ class TestChatKinetica:
@pytest.mark.vcr()
def test_generate(self) -> None:
"""Generate SQL from a chain."""
kinetica_llm = ChatKinetica()
kinetica_llm = ChatKinetica() # type: ignore[call-arg]

# create chain
ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
@ -113,7 +113,7 @@ class TestChatKinetica:
@pytest.mark.vcr()
def test_full_chain(self) -> None:
"""Generate SQL from a chain and execute the query."""
kinetica_llm = ChatKinetica()
kinetica_llm = ChatKinetica() # type: ignore[call-arg]

# create chain
ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)