Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-03 11:47:49 +00:00)

commit aee7988a94 (parent a2863f8757)

    community: add mypy warn_unused_ignores rule (#30816)
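For context: mypy's warn_unused_ignores option (for example, warn_unused_ignores = true under [tool.mypy]) makes the checker report any "# type: ignore" comment that no longer suppresses a real error, which is why the hunks below mostly delete stale ignores and occasionally narrow them to the error codes that still fire. A minimal sketch of what the rule catches — not taken from this commit, the function and variable names are hypothetical:

    from typing import Optional


    def first_char(text: Optional[str]) -> str:
        if text is None:
            return ""
        # "text" is narrowed to str here, so indexing type-checks and the
        # suppression below suppresses nothing. With warn_unused_ignores
        # enabled, mypy reports:
        #     error: Unused "type: ignore" comment  [unused-ignore]
        return text[0]  # type: ignore[index]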
@@ -104,7 +104,7 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             additional_kwargs["name"] = _dict["name"]
         return ToolMessage(
             content=_dict.get("content", ""),
-            tool_call_id=_dict.get("tool_call_id"),  # type: ignore[arg-type]
+            tool_call_id=_dict.get("tool_call_id"),
             additional_kwargs=additional_kwargs,
         )
     else:
@@ -22,7 +22,7 @@ class AzureAiServicesToolkit(BaseToolkit):
 
         tools: List[BaseTool] = [
             AzureAiServicesDocumentIntelligenceTool(),  # type: ignore[call-arg]
-            AzureAiServicesImageAnalysisTool(),  # type: ignore[call-arg]
+            AzureAiServicesImageAnalysisTool(),
             AzureAiServicesSpeechToTextTool(),  # type: ignore[call-arg]
             AzureAiServicesTextToSpeechTool(),  # type: ignore[call-arg]
             AzureAiServicesTextAnalyticsForHealthTool(),  # type: ignore[call-arg]
@@ -81,7 +81,7 @@ class FileManagementToolkit(BaseToolkit):
         tools: List[BaseTool] = []
         for tool in allowed_tools:
             tool_cls = _FILE_TOOLS_MAP[tool]
-            tools.append(tool_cls(root_dir=self.root_dir))  # type: ignore[call-arg]
+            tools.append(tool_cls(root_dir=self.root_dir))
         return tools
 
 
@@ -13,7 +13,7 @@ from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
 from langchain_community.utilities.requests import Requests
 
 
-class NLATool(Tool):  # type: ignore[override]
+class NLATool(Tool):
     """Natural Language API Tool."""
 
     @classmethod
@@ -64,7 +64,7 @@ def _get_default_llm_chain_factory(
     return partial(_get_default_llm_chain, prompt)
 
 
-class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):  # type: ignore[override]
+class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
     """Requests GET tool with LLM-instructed extraction of truncated responses."""
 
     name: str = "requests_get"
@@ -98,7 +98,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
         raise NotImplementedError()
 
 
-class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):  # type: ignore[override]
+class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
     """Requests POST tool with LLM-instructed extraction of truncated responses."""
 
     name: str = "requests_post"
@@ -129,7 +129,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
         raise NotImplementedError()
 
 
-class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):  # type: ignore[override]
+class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
     """Requests PATCH tool with LLM-instructed extraction of truncated responses."""
 
     name: str = "requests_patch"
@@ -162,7 +162,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
         raise NotImplementedError()
 
 
-class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):  # type: ignore[override]
+class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
     """Requests PUT tool with LLM-instructed extraction of truncated responses."""
 
     name: str = "requests_put"
@@ -193,7 +193,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
         raise NotImplementedError()
 
 
-class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):  # type: ignore[override]
+class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
     """Tool that sends a DELETE request and parses the response."""
 
     name: str = "requests_delete"
@@ -266,7 +266,7 @@ def _create_api_controller_agent(
     if "GET" in allowed_operations:
         get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
         tools.append(
-            RequestsGetToolWithParsing(  # type: ignore[call-arg]
+            RequestsGetToolWithParsing(
                 requests_wrapper=requests_wrapper,
                 llm_chain=get_llm_chain,
                 allow_dangerous_requests=allow_dangerous_requests,
@@ -275,7 +275,7 @@ def _create_api_controller_agent(
     if "POST" in allowed_operations:
         post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
         tools.append(
-            RequestsPostToolWithParsing(  # type: ignore[call-arg]
+            RequestsPostToolWithParsing(
                 requests_wrapper=requests_wrapper,
                 llm_chain=post_llm_chain,
                 allow_dangerous_requests=allow_dangerous_requests,
@@ -284,7 +284,7 @@ def _create_api_controller_agent(
     if "PUT" in allowed_operations:
         put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
         tools.append(
-            RequestsPutToolWithParsing(  # type: ignore[call-arg]
+            RequestsPutToolWithParsing(
                 requests_wrapper=requests_wrapper,
                 llm_chain=put_llm_chain,
                 allow_dangerous_requests=allow_dangerous_requests,
@@ -293,7 +293,7 @@ def _create_api_controller_agent(
     if "DELETE" in allowed_operations:
         delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
         tools.append(
-            RequestsDeleteToolWithParsing(  # type: ignore[call-arg]
+            RequestsDeleteToolWithParsing(
                 requests_wrapper=requests_wrapper,
                 llm_chain=delete_llm_chain,
                 allow_dangerous_requests=allow_dangerous_requests,
@@ -302,7 +302,7 @@ def _create_api_controller_agent(
     if "PATCH" in allowed_operations:
         patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
         tools.append(
-            RequestsPatchToolWithParsing(  # type: ignore[call-arg]
+            RequestsPatchToolWithParsing(
                 requests_wrapper=requests_wrapper,
                 llm_chain=patch_llm_chain,
                 allow_dangerous_requests=allow_dangerous_requests,
@@ -75,7 +75,7 @@ class PowerBIToolkit(BaseToolkit):
                 powerbi=self.powerbi,
                 examples=self.examples,
                 max_iterations=self.max_iterations,
-                output_token_limit=self.output_token_limit,  # type: ignore[arg-type]
+                output_token_limit=self.output_token_limit,
                 tiktoken_model_name=self.tiktoken_model_name,
             ),
             InfoPowerBITool(powerbi=self.powerbi),
@@ -289,7 +289,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
             name=name,
             instructions=instructions,
             tools=[_get_assistants_tool(tool) for tool in tools],
-            tool_resources=tool_resources,  # type: ignore[arg-type]
+            tool_resources=tool_resources,
             model=model,
             extra_body=extra_body,
             **model_kwargs,
@@ -431,7 +431,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
             name=name,
             instructions=instructions,
             tools=openai_tools,
-            tool_resources=tool_resources,  # type: ignore[arg-type]
+            tool_resources=tool_resources,
             model=model,
         )
         return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
@@ -579,7 +579,7 @@ class AsyncRedisCache(_RedisCacheBase):
         try:
             async with self.redis.pipeline() as pipe:
                 self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
-                await pipe.execute()  # type: ignore[attr-defined]
+                await pipe.execute()
         except Exception as e:
             logger.error(f"Redis async update failed: {e}")
 
@@ -378,7 +378,7 @@ def create_ernie_fn_chain(
     output_key: str = "function",
     output_parser: Optional[BaseLLMOutputParser] = None,
     **kwargs: Any,
-) -> LLMChain:  # type: ignore[valid-type]
+) -> LLMChain:
     """[Legacy] Create an LLM chain that uses Ernie functions.
 
     Args:
@@ -455,7 +455,7 @@ def create_ernie_fn_chain(
     }
     if len(ernie_functions) == 1:
         llm_kwargs["function_call"] = {"name": ernie_functions[0]["name"]}
-    llm_chain = LLMChain(  # type: ignore[misc]
+    llm_chain = LLMChain(
         llm=llm,
         prompt=prompt,
         output_parser=output_parser,
@@ -474,7 +474,7 @@ def create_structured_output_chain(
     output_key: str = "function",
     output_parser: Optional[BaseLLMOutputParser] = None,
     **kwargs: Any,
-) -> LLMChain:  # type: ignore[valid-type]
+) -> LLMChain:
     """[Legacy] Create an LLMChain that uses an Ernie function to get a structured output.
 
     Args:
@@ -324,7 +324,7 @@ class GraphCypherQAChain(Chain):
 
         cypher_generation_chain = LLMChain(
             llm=cypher_llm or llm,  # type: ignore[arg-type]
-            **use_cypher_llm_kwargs,  # type: ignore[arg-type]
+            **use_cypher_llm_kwargs,
         )
 
         if exclude_types and include_types:
@@ -235,7 +235,7 @@ class MemgraphQAChain(Chain):
         llm_to_use = cypher_llm if cypher_llm is not None else llm
 
         if prompt is not None and llm_to_use is not None:
-            cypher_generation_chain = prompt | llm_to_use | StrOutputParser()  # type: ignore[arg-type]
+            cypher_generation_chain = prompt | llm_to_use | StrOutputParser()
         else:
             raise ValueError(
                 "Missing required components for the cypher generation chain: "
@@ -181,7 +181,7 @@ class NeptuneSparqlQAChain(Chain):
         )
         sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
 
-        return cls(  # type: ignore[call-arg]
+        return cls(
             qa_chain=qa_chain,
             sparql_generation_chain=sparql_generation_chain,
            examples=examples,
@@ -28,7 +28,7 @@ class LLMRequestsChain(Chain):
     See https://python.langchain.com/docs/security for more information.
     """
 
-    llm_chain: LLMChain  # type: ignore[valid-type]
+    llm_chain: LLMChain
     requests_wrapper: TextRequestsWrapper = Field(
         default_factory=lambda: TextRequestsWrapper(headers=DEFAULT_HEADERS),
         exclude=True,
@@ -88,7 +88,7 @@ class LLMRequestsChain(Chain):
         # extract the text from the html
         soup = BeautifulSoup(res, "html.parser")
         other_keys[self.requests_key] = soup.get_text()[: self.text_length]
-        result = self.llm_chain.predict(  # type: ignore[attr-defined]
+        result = self.llm_chain.predict(
             callbacks=_run_manager.get_child(), **other_keys
         )
         return {self.output_key: result}
@@ -158,7 +158,7 @@ class IMessageChatLoader(BaseChatLoader):
                 continue
 
            results.append(
-                HumanMessage(  # type: ignore[call-arg]
+                HumanMessage(
                     role=sender,
                     content=content,
                     additional_kwargs={
@@ -52,7 +52,7 @@ class SlackChatLoader(BaseChatLoader):
                 )
             else:
                 results.append(
-                    HumanMessage(  # type: ignore[call-arg]
+                    HumanMessage(
                         role=sender,
                         content=text,
                         additional_kwargs={
@@ -78,7 +78,7 @@ def map_ai_messages_in_session(chat_sessions: ChatSession, sender: str) -> ChatS
             message = AIMessage(
                 content=message.content,
                 additional_kwargs=message.additional_kwargs.copy(),
-                example=getattr(message, "example", None),  # type: ignore[arg-type]
+                example=getattr(message, "example", None),
             )
             num_converted += 1
         messages.append(message)
@@ -73,7 +73,7 @@ class WhatsAppChatLoader(BaseChatLoader):
             timestamp, sender, text = result.groups()
             if not self._ignore_lines.match(text.strip()):
                 results.append(
-                    HumanMessage(  # type: ignore[call-arg]
+                    HumanMessage(
                         role=sender,
                         content=text,
                         additional_kwargs={
@@ -85,7 +85,7 @@ def create_message_model(table_name: str, DynamicBase: Any) -> Any:
     """
 
     # Model declared inside a function to have a dynamic table name.
    class Message(DynamicBase):  # noqa: removed-stale-ignore
-    class Message(DynamicBase):  # type: ignore[valid-type, misc]
+    class Message(DynamicBase):
        __tablename__ = table_name
        id = Column(Integer, primary_key=True)
        session_id = Column(Text)
@@ -167,7 +167,7 @@ class ChatAnyscale(ChatOpenAI):
             else:
                 values["openai_api_base"] = values["anyscale_api_base"]
                 values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
-                values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+                values["client"] = openai.ChatCompletion
         except AttributeError as exc:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "
@@ -227,7 +227,7 @@ class AzureChatOpenAI(ChatOpenAI):
                 **client_params
             ).chat.completions
         else:
-            values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+            values["client"] = openai.ChatCompletion
         return values
 
     @property
@@ -304,7 +304,7 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
             "http_client": None,
         }
 
-        client = openai.OpenAI(**client_params)  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
+        client = openai.OpenAI(**client_params)
         message_dicts = [
             CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
             for m in messages
@@ -312,30 +312,30 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
         params = {"stream": True, "stop": stop, "model": None, **kwargs}
 
         default_chunk_class = AIMessageChunk
-        for chunk in client.chat.completions.create(messages=message_dicts, **params):  # type: ignore[arg-type]
+        for chunk in client.chat.completions.create(messages=message_dicts, **params):
             if not isinstance(chunk, dict):
-                chunk = chunk.dict()  # type: ignore[attr-defined]
-            if len(chunk["choices"]) == 0:  # type: ignore[call-overload]
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
                 continue
-            choice = chunk["choices"][0]  # type: ignore[call-overload]
-            chunk = _convert_delta_to_message_chunk(  # type: ignore[assignment]
-                choice["delta"],  # type: ignore[arg-type, index]
-                default_chunk_class,  # type: ignore[arg-type, index]
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"],
+                default_chunk_class,
             )
             generation_info = {}
-            if finish_reason := choice.get("finish_reason"):  # type: ignore[union-attr]
+            if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")  # type: ignore[union-attr]
+            logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__  # type: ignore[assignment]
-            chunk = ChatGenerationChunk(  # type: ignore[assignment]
-                message=chunk,  # type: ignore[arg-type]
-                generation_info=generation_info or None,  # type: ignore[arg-type]
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk,
+                generation_info=generation_info or None,
             )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)  # type: ignore[attr-defined, arg-type]
-            yield chunk  # type: ignore[misc]
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
 
     async def _astream(
         self,
@@ -359,7 +359,7 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
             "http_client": None,
         }
 
-        async_client = openai.AsyncOpenAI(**client_params)  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
+        async_client = openai.AsyncOpenAI(**client_params)
         message_dicts = [
             CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
             for m in messages
@@ -367,9 +367,9 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
         params = {"stream": True, "stop": stop, "model": None, **kwargs}
 
         default_chunk_class = AIMessageChunk
-        async for chunk in await async_client.chat.completions.create(  # type: ignore[attr-defined]
-            messages=message_dicts,  # type: ignore[arg-type]
-            **params,  # type: ignore[arg-type]
+        async for chunk in await async_client.chat.completions.create(
+            messages=message_dicts,
+            **params,
         ):
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
@@ -128,7 +128,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
         return AIMessage(
             content=content,
             additional_kwargs=additional_kwargs,
-            tool_calls=tool_calls,  # type: ignore[arg-type]
+            tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
     elif role == "tool":
@@ -137,7 +137,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             additional_kwargs["name"] = _dict["name"]
         return ToolMessage(
             content=content,
-            tool_call_id=_dict.get("tool_call_id"),  # type: ignore[arg-type]
+            tool_call_id=_dict.get("tool_call_id"),
             additional_kwargs=additional_kwargs,
         )
     elif role == "system":
@@ -821,7 +821,7 @@ class QianfanChatEndpoint(BaseChatModel):
         if is_pydantic_schema:
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema],  # type: ignore[list-item]
-                first_tool_only=True,  # type: ignore[list-item]
+                first_tool_only=True,
             )
         else:
             key_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -213,7 +213,7 @@ class ChatCloudflareWorkersAI(BaseChatModel):
         if is_pydantic_schema:
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema],  # type: ignore[list-item]
-                first_tool_only=True,  # type: ignore[list-item]
+                first_tool_only=True,
             )
         else:
             output_parser = JsonOutputKeyToolsParser(
@@ -222,7 +222,7 @@ class ChatCloudflareWorkersAI(BaseChatModel):
         elif method == "json_mode":
             llm = self.bind(response_format={"type": "json_object"})
             output_parser = (
-                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
@@ -110,7 +110,7 @@ class ChatEverlyAI(ChatOpenAI):
                 "Please install it with `pip install openai`.",
             ) from e
         try:
-            values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+            values["client"] = openai.ChatCompletion
         except AttributeError as exc:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "
@@ -70,11 +70,11 @@ def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
         stop=stop_after_attempt(llm.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore[attr-defined]
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
@@ -234,7 +234,7 @@ class JinaChat(BaseChatModel):
                 "Please install it with `pip install openai`."
             )
         try:
-            values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+            values["client"] = openai.ChatCompletion
         except AttributeError:
             raise ValueError(
                 "`openai` has no `ChatCompletion` attribute, this is likely "
@@ -266,11 +266,11 @@ class JinaChat(BaseChatModel):
             stop=stop_after_attempt(self.max_retries),
             wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
             retry=(
-                retry_if_exception_type(openai.error.Timeout)  # type: ignore[attr-defined]
-                | retry_if_exception_type(openai.error.APIError)  # type: ignore[attr-defined]
-                | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore[attr-defined]
-                | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore[attr-defined]
-                | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore[attr-defined]
+                retry_if_exception_type(openai.error.Timeout)
+                | retry_if_exception_type(openai.error.APIError)
+                | retry_if_exception_type(openai.error.APIConnectionError)
+                | retry_if_exception_type(openai.error.RateLimitError)
+                | retry_if_exception_type(openai.error.ServiceUnavailableError)
             ),
             before_sleep=before_sleep_log(logger, logging.WARNING),
         )
@@ -42,7 +42,7 @@ DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
 logger = logging.getLogger(__name__)
 
 
-class ChatKonko(ChatOpenAI):  # type: ignore[override]
+class ChatKonko(ChatOpenAI):
     """`ChatKonko` Chat large language models API.
 
     To use, you should have the ``konko`` python package installed, and the
@@ -664,7 +664,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             additional_kwargs=additional_kwargs,
             name=name,
             id=id_,
-            tool_calls=tool_calls,  # type: ignore[arg-type]
+            tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
     elif role == "system":
@@ -777,7 +777,7 @@ class MiniMaxChat(BaseChatModel):
         if is_pydantic_schema:
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema],  # type: ignore[list-item]
-                first_tool_only=True,  # type: ignore[list-item]
+                first_tool_only=True,
             )
         else:
             key_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -12,7 +12,7 @@ from langchain_community.chat_models import ChatOpenAI
 from langchain_community.llms.moonshot import MOONSHOT_SERVICE_URL_BASE, MoonshotCommon
 
 
-class MoonshotChat(MoonshotCommon, ChatOpenAI):  # type: ignore[misc, override, override]
+class MoonshotChat(MoonshotCommon, ChatOpenAI):  # type: ignore[misc]
     """Moonshot chat model integration.
 
     Setup:
@@ -587,7 +587,7 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
         if method == "json_mode":
             llm = self.bind(response_format={"type": "json_object"})
             output_parser = (
-                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
@@ -725,7 +725,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
         elif method == "json_mode":
             llm = self.bind(response_format={"type": "json_object"})
             output_parser = (
-                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
@@ -98,7 +98,7 @@ class ChatOctoAI(ChatOpenAI):
             else:
                 values["openai_api_base"] = values["octoai_api_base"]
                 values["openai_api_key"] = values["octoai_api_token"].get_secret_value()
-                values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+                values["client"] = openai.ChatCompletion
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -88,11 +88,11 @@ def _create_retry_decorator(
     import openai
 
     errors = [
-        openai.error.Timeout,  # type: ignore[attr-defined]
-        openai.error.APIError,  # type: ignore[attr-defined]
-        openai.error.APIConnectionError,  # type: ignore[attr-defined]
-        openai.error.RateLimitError,  # type: ignore[attr-defined]
-        openai.error.ServiceUnavailableError,  # type: ignore[attr-defined]
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
     ]
     return create_base_retry_decorator(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
@@ -358,7 +358,7 @@ class ChatOpenAI(BaseChatModel):
                 **client_params
             ).chat.completions
         elif not values.get("client"):
-            values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+            values["client"] = openai.ChatCompletion
         else:
             pass
         return values
@@ -595,7 +595,7 @@ class ChatOpenAI(BaseChatModel):
         if self.openai_proxy:
             import openai
 
-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[attr-defined]
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
         return {**self._default_params, **openai_creds}
 
     def _get_invocation_params(
@@ -486,7 +486,7 @@ class ChatPerplexity(BaseChatModel):
         if is_pydantic_schema and hasattr(
             schema, "model_json_schema"
         ):  # accounting for pydantic v1 and v2
-            response_format = schema.model_json_schema()  # type: ignore[union-attr]
+            response_format = schema.model_json_schema()
         elif is_pydantic_schema:
             response_format = schema.schema()  # type: ignore[union-attr]
         elif isinstance(schema, dict):
@@ -636,7 +636,7 @@ class ChatSambaNovaCloud(BaseChatModel):
         if is_pydantic_schema:
             output_parser: OutputParserLike[Any] = PydanticToolsParser(
                 tools=[schema],  # type: ignore[list-item]
-                first_tool_only=True,  # type: ignore[list-item]
+                first_tool_only=True,
             )
         else:
             output_parser = JsonOutputKeyToolsParser(
@@ -648,7 +648,7 @@ class ChatSambaNovaCloud(BaseChatModel):
             # llm = self.bind(response_format={"type": "json_object"})
             if is_pydantic_schema:
                 schema = cast(Type[BaseModel], schema)
-                output_parser = PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                output_parser = PydanticOutputParser(pydantic_object=schema)
             else:
                 output_parser = JsonOutputParser()
 
@@ -666,7 +666,7 @@ class ChatSambaNovaCloud(BaseChatModel):
             # )
             if is_pydantic_schema:
                 schema = cast(Type[BaseModel], schema)
-                output_parser = PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
+                output_parser = PydanticOutputParser(pydantic_object=schema)
             else:
                 output_parser = JsonOutputParser()
         else:
@@ -13,7 +13,7 @@ from langchain_community.llms.solar import SOLAR_SERVICE_URL_BASE, SolarCommon
 @deprecated(  # type: ignore[arg-type]
     since="0.0.34", removal="1.0", alternative_import="langchain_upstage.ChatUpstage"
 )
-class SolarChat(SolarCommon, ChatOpenAI):  # type: ignore[override, override]
+class SolarChat(SolarCommon, ChatOpenAI):
     """Wrapper around Solar large language models.
     To use, you should have the ``openai`` python package installed, and the
     environment variable ``SOLAR_API_KEY`` set with your API key.
@@ -176,7 +176,7 @@ class ChatNebula(BaseChatModel):
         json_payload = json.dumps(payload)
 
         async with ClientSession() as session:
-            async with session.post(  # type: ignore[call-arg]
+            async with session.post(  # type: ignore[call-arg,unused-ignore]
                url, data=json_payload, headers=headers, stream=True
            ) as response:
                response.raise_for_status()
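The ChatNebula hunk above is one of the few places where an ignore is widened instead of deleted: "unused-ignore" is mypy's error code for the warning that warn_unused_ignores itself emits, so listing it keeps the comment from being flagged in environments where the suppressed error does not fire. A hypothetical sketch of the mechanism (the function and names are illustrative, not from this commit):

    from typing import Any


    def post(url: str, **kwargs: Any) -> str:
        return url


    # This call raises no call-arg error, so a bare "# type: ignore[call-arg]"
    # would itself be reported as unused under warn_unused_ignores; adding
    # "unused-ignore" to the list suppresses that follow-on report as well.
    response = post("https://example.com")  # type: ignore[call-arg,unused-ignore]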
@@ -140,7 +140,7 @@ def convert_dict_to_message(
         else AIMessage(
             content=content,
             additional_kwargs=additional_kwargs,
-            tool_calls=tool_calls,  # type: ignore[arg-type]
+            tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
     )
@@ -163,7 +163,7 @@ def convert_dict_to_message(
         if is_chunk
         else ToolMessage(
             content=_dict.get("content", ""),
-            tool_call_id=_dict.get("tool_call_id"),  # type: ignore[arg-type]
+            tool_call_id=_dict.get("tool_call_id"),
             additional_kwargs=additional_kwargs,
         )
     )
@@ -894,7 +894,7 @@ class ChatTongyi(BaseChatModel):
         if is_pydantic_schema:
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema],  # type: ignore[list-item]
-                first_tool_only=True,  # type: ignore[list-item]
+                first_tool_only=True,
             )
         else:
             key_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -209,7 +209,7 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage:
     removal="1.0",
     alternative_import="langchain_google_vertexai.ChatVertexAI",
 )
-class ChatVertexAI(_VertexAICommon, BaseChatModel):  # type: ignore[override]
+class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""
 
     model_name: str = "chat-bison"
@@ -162,7 +162,7 @@ def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
             additional_kwargs["name"] = dct["name"]
         return ToolMessage(
             content=content,
-            tool_call_id=dct.get("tool_call_id"),  # type: ignore[arg-type]
+            tool_call_id=dct.get("tool_call_id"),
             additional_kwargs=additional_kwargs,
         )
     return ChatMessage(role=role, content=content)  # type: ignore[arg-type]
@@ -861,7 +861,7 @@ class ChatZhipuAI(BaseChatModel):
         if is_pydantic_schema:
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema],  # type: ignore[list-item]
-                first_tool_only=True,  # type: ignore[list-item]
+                first_tool_only=True,
             )
         else:
             output_parser = JsonOutputKeyToolsParser(
@@ -226,8 +226,8 @@ class AsyncHtmlLoader(BaseLoader):
                 # in a separate loop, in a separate thread.
                 with ThreadPoolExecutor(max_workers=1) as executor:
                     future: Future[List[str]] = executor.submit(
-                        asyncio.run,  # type: ignore[arg-type]
-                        self.fetch_all(self.web_paths),  # type: ignore[arg-type]
+                        asyncio.run,
+                        self.fetch_all(self.web_paths),
                     )
                     results = future.result()
             except RuntimeError:
@@ -224,7 +224,7 @@ class CloudBlobLoader(BlobLoader):
             yield self.path
             return
 
-        paths = self.path.glob(self.glob)  # type: ignore[attr-defined]
+        paths = self.path.glob(self.glob)
         for path in paths:
             if self.exclude:
                 if any(path.match(glob) for glob in self.exclude):
@@ -24,9 +24,9 @@ class ConcurrentLoader(GenericLoader):
 
     def __init__(
         self,
-        blob_loader: BlobLoader,  # type: ignore[valid-type]
+        blob_loader: BlobLoader,
         blob_parser: BaseBlobParser,
-        num_workers: int = 4,  # type: ignore[valid-type]
+        num_workers: int = 4,
     ) -> None:
         super().__init__(blob_loader, blob_parser)
         self.num_workers = num_workers
@@ -40,7 +40,7 @@ class ConcurrentLoader(GenericLoader):
         ) as executor:
             futures = {
                 executor.submit(self.blob_parser.lazy_parse, blob)
-                for blob in self.blob_loader.yield_blobs()  # type: ignore[attr-defined]
+                for blob in self.blob_loader.yield_blobs()
             }
             for future in concurrent.futures.as_completed(futures):
                 yield from future.result()
@@ -72,7 +72,7 @@ class ConcurrentLoader(GenericLoader):
             num_workers: Max number of concurrent workers to use.
             parser_kwargs: Keyword arguments to pass to the parser.
         """
-        blob_loader = FileSystemBlobLoader(  # type: ignore[attr-defined, misc]
+        blob_loader = FileSystemBlobLoader(
             path,
             glob=glob,
             exclude=exclude,
@@ -428,7 +428,7 @@ class ConfluenceLoader(BaseLoader):
                 self.number_of_retries  # type: ignore[arg-type]
             ),
             wait=wait_exponential(
-                multiplier=1,  # type: ignore[arg-type]
+                multiplier=1,
                 min=self.min_retry_seconds,  # type: ignore[arg-type]
                 max=self.max_retry_seconds,  # type: ignore[arg-type]
             ),
@@ -223,4 +223,4 @@ class UnstructuredCSVLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.csv import partition_csv
 
-        return partition_csv(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_csv(filename=self.file_path, **self.unstructured_kwargs)
@@ -101,7 +101,7 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader):
         self.url_path = url_path
         self.bytes_source = bytes_source
 
-        self.parser = AzureAIDocumentIntelligenceParser(  # type: ignore[misc]
+        self.parser = AzureAIDocumentIntelligenceParser(
            api_endpoint=api_endpoint,
            api_key=api_key,
            api_version=api_version,
@@ -116,10 +116,10 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader):
     ) -> Iterator[Document]:
         """Lazy load the document as pages."""
         if self.file_path is not None:
-            blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+            blob = Blob.from_path(self.file_path)
             yield from self.parser.parse(blob)
         elif self.url_path is not None:
-            yield from self.parser.parse_url(self.url_path)  # type: ignore[arg-type]
+            yield from self.parser.parse_url(self.url_path)
         elif self.bytes_source is not None:
             yield from self.parser.parse_bytes(self.bytes_source)
         else:
@@ -60,16 +60,16 @@ class UnstructuredEmailLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.file_utils.filetype import FileType, detect_filetype
 
-        filetype = detect_filetype(self.file_path)  # type: ignore[arg-type]
+        filetype = detect_filetype(self.file_path)
 
         if filetype == FileType.EML:
             from unstructured.partition.email import partition_email
 
-            return partition_email(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+            return partition_email(filename=self.file_path, **self.unstructured_kwargs)
         elif satisfies_min_unstructured_version("0.5.8") and filetype == FileType.MSG:
             from unstructured.partition.msg import partition_msg
 
-            return partition_msg(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+            return partition_msg(filename=self.file_path, **self.unstructured_kwargs)
         else:
             raise ValueError(
                 f"Filetype {filetype} is not supported in UnstructuredEmailLoader."
@@ -52,4 +52,4 @@ class UnstructuredEPubLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.epub import partition_epub
 
-        return partition_epub(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_epub(filename=self.file_path, **self.unstructured_kwargs)
@@ -49,4 +49,4 @@ class UnstructuredExcelLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.xlsx import partition_xlsx
 
-        return partition_xlsx(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_xlsx(filename=self.file_path, **self.unstructured_kwargs)
@@ -96,7 +96,7 @@ class GenericLoader(BaseLoader):
 
     def __init__(
         self,
-        blob_loader: BlobLoader,  # type: ignore[valid-type]
+        blob_loader: BlobLoader,
         blob_parser: BaseBlobParser,
     ) -> None:
         """A generic document loader.
@@ -112,7 +112,7 @@ class GenericLoader(BaseLoader):
         self,
     ) -> Iterator[Document]:
         """Load documents lazily. Use this when working at a large scale."""
-        for blob in self.blob_loader.yield_blobs():  # type: ignore[attr-defined]
+        for blob in self.blob_loader.yield_blobs():
             yield from self.blob_parser.lazy_parse(blob)
 
     def load_and_split(
@@ -159,7 +159,7 @@ class GenericLoader(BaseLoader):
         Returns:
             A generic document loader.
         """
-        blob_loader = FileSystemBlobLoader(  # type: ignore[attr-defined, misc]
+        blob_loader = FileSystemBlobLoader(
             path,
             glob=glob,
             exclude=exclude,
@@ -74,7 +74,7 @@ class GitLoader(BaseLoader):
 
             file_path = os.path.join(self.repo_path, item.path)
 
-            ignored_files = repo.ignored([file_path])  # type: ignore[arg-type]
+            ignored_files = repo.ignored([file_path])
             if len(ignored_files):
                 continue
 
@@ -48,4 +48,4 @@ class UnstructuredHTMLLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.html import partition_html
 
-        return partition_html(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_html(filename=self.file_path, **self.unstructured_kwargs)
@@ -48,4 +48,4 @@ class UnstructuredImageLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.image import partition_image
 
-        return partition_image(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_image(filename=self.file_path, **self.unstructured_kwargs)
@@ -76,13 +76,13 @@ class ImageCaptionLoader(BaseLoader):
 
         try:
             if isinstance(image, bytes):
-                image = Image.open(BytesIO(image)).convert("RGB")  # type: ignore[assignment]
+                image = Image.open(BytesIO(image)).convert("RGB")
             elif isinstance(image, str) and (
                 image.startswith("http://") or image.startswith("https://")
             ):
-                image = Image.open(requests.get(image, stream=True).raw).convert("RGB")  # type: ignore[assignment, arg-type]
+                image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
             else:
-                image = Image.open(image).convert("RGB")  # type: ignore[assignment]
+                image = Image.open(image).convert("RGB")
         except Exception:
             if isinstance(image_source, bytes):
                 msg = "Could not get image data from bytes"
@ -93,4 +93,4 @@ class UnstructuredMarkdownLoader(UnstructuredFileLoader):
|
|||||||
def _get_elements(self) -> List:
|
def _get_elements(self) -> List:
|
||||||
from unstructured.partition.md import partition_md
|
from unstructured.partition.md import partition_md
|
||||||
|
|
||||||
return partition_md(filename=self.file_path, **self.unstructured_kwargs) # type: ignore[arg-type]
|
return partition_md(filename=self.file_path, **self.unstructured_kwargs)
|
||||||
|
@ -52,4 +52,4 @@ class UnstructuredODTLoader(UnstructuredFileLoader):
|
|||||||
def _get_elements(self) -> List:
|
def _get_elements(self) -> List:
|
||||||
from unstructured.partition.odt import partition_odt
|
from unstructured.partition.odt import partition_odt
|
||||||
|
|
||||||
return partition_odt(filename=self.file_path, **self.unstructured_kwargs) # type: ignore[arg-type]
|
return partition_odt(filename=self.file_path, **self.unstructured_kwargs)
|
||||||
|
@ -52,4 +52,4 @@ class UnstructuredOrgModeLoader(UnstructuredFileLoader):
|
|||||||
def _get_elements(self) -> List:
|
def _get_elements(self) -> List:
|
||||||
from unstructured.partition.org import partition_org
|
from unstructured.partition.org import partition_org
|
||||||
|
|
||||||
return partition_org(filename=self.file_path, **self.unstructured_kwargs) # type: ignore[arg-type]
|
return partition_org(filename=self.file_path, **self.unstructured_kwargs)
|
||||||
|
@@ -322,7 +322,7 @@ class OpenAIWhisperParser(BaseBlobParser):
                         model=self.model, file=file_obj, **self._create_params
                     )
                 else:
-                    transcript = openai.Audio.transcribe(self.model, file_obj)  # type: ignore[attr-defined]
+                    transcript = openai.Audio.transcribe(self.model, file_obj)
                 break
             except Exception as e:
                 attempts += 1
@@ -9,7 +9,7 @@ from langchain_community.document_loaders.blob_loaders import Blob
 class MsWordParser(BaseBlobParser):
     """Parse the Microsoft Word documents from a blob."""

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """Parse a Microsoft Word document into the Document iterator.

         Args:
@@ -33,13 +33,13 @@ class MsWordParser(BaseBlobParser):
                 partition_docx
             ),
         }
-        if blob.mimetype not in (  # type: ignore[attr-defined]
+        if blob.mimetype not in (
             "application/msword",
             "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
         ):
             raise ValueError("This blob type is not supported for this parser.")
-        with blob.as_bytes_io() as word_document:  # type: ignore[attr-defined]
-            elements = mime_type_parser[blob.mimetype](file=word_document)  # type: ignore[attr-defined] # type: ignore[operator] # type: ignore[operator] # type: ignore[operator] # type: ignore[operator] # type: ignore[operator] # type: ignore[operator]
+        with blob.as_bytes_io() as word_document:
+            elements = mime_type_parser[blob.mimetype](file=word_document)
             text = "\n\n".join([str(el) for el in elements])
-            metadata = {"source": blob.source}  # type: ignore[attr-defined]
+            metadata = {"source": blob.source}
             yield Document(page_content=text, metadata=metadata)
@@ -340,7 +340,7 @@ class PyPDFParser(BaseBlobParser):
         self.extraction_mode = extraction_mode
         self.extraction_kwargs = extraction_kwargs or {}

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """
         Lazily parse the blob.
         Insert image, if possible, between two paragraphs.
@@ -380,7 +380,7 @@ class PyPDFParser(BaseBlobParser):
                 **self.extraction_kwargs,
             )

-        with blob.as_bytes_io() as pdf_file_obj:  # type: ignore[attr-defined]
+        with blob.as_bytes_io() as pdf_file_obj:
             pdf_reader = pypdf.PdfReader(pdf_file_obj, password=self.password)

             doc_metadata = _purge_metadata(
@@ -434,7 +434,7 @@ class PyPDFParser(BaseBlobParser):
         if "/XObject" not in cast(dict, page["/Resources"]).keys():
             return ""

-        xObject = page["/Resources"]["/XObject"].get_object()  # type: ignore[index]
+        xObject = page["/Resources"]["/XObject"].get_object()
         images = []
         for obj in xObject:
             np_image: Any = None
@@ -677,7 +677,7 @@ class PDFMinerParser(BaseBlobParser):

         return metadata

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """
         Lazily parse the blob.
         Insert image, if possible, between two paragraphs.
@@ -919,7 +919,7 @@ class PyMuPDFParser(BaseBlobParser):
         self.extract_tables = extract_tables
         self.extract_tables_settings = extract_tables_settings

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         return self._lazy_parse(
             blob,
         )
@@ -930,7 +930,7 @@ class PyMuPDFParser(BaseBlobParser):
         # text-kwargs is present for backwards compatibility.
         # Users should not use it directly.
         text_kwargs: Optional[dict[str, Any]] = None,
-    ) -> Iterator[Document]:  # type: ignore[valid-type]
+    ) -> Iterator[Document]:
         """Lazily parse the blob.
         Insert image, if possible, between two paragraphs.
         In this way, a paragraph can be continued on the next page.
@@ -990,8 +990,8 @@ class PyMuPDFParser(BaseBlobParser):
             )

         with PyMuPDFParser._lock:
-            with blob.as_bytes_io() as file_path:  # type: ignore[attr-defined]
-                if blob.data is None:  # type: ignore[attr-defined]
+            with blob.as_bytes_io() as file_path:
+                if blob.data is None:
                     doc = pymupdf.open(file_path)
                 else:
                     doc = pymupdf.open(stream=file_path, filetype="pdf")
@@ -1066,8 +1066,8 @@ class PyMuPDFParser(BaseBlobParser):
                     "producer": "PyMuPDF",
                     "creator": "PyMuPDF",
                     "creationdate": "",
-                    "source": blob.source,  # type: ignore[attr-defined]
-                    "file_path": blob.source,  # type: ignore[attr-defined]
+                    "source": blob.source,
+                    "file_path": blob.source,
                     "total_pages": len(doc),
                 },
                 **{
@@ -1273,7 +1273,7 @@ class PyPDFium2Parser(BaseBlobParser):
         self.mode = mode
         self.pages_delimiter = pages_delimiter

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """
         Lazily parse the blob.
         Insert image, if possible, between two paragraphs.
@@ -1299,7 +1299,7 @@ class PyPDFium2Parser(BaseBlobParser):
        # pypdfium2 is really finicky with respect to closing things,
        # if done incorrectly creates seg faults.
        with PyPDFium2Parser._lock:
-            with blob.as_bytes_io() as file_path:  # type: ignore[attr-defined]
+            with blob.as_bytes_io() as file_path:
                pdf_reader = None
                try:
                    pdf_reader = pypdfium2.PdfDocument(
@@ -1410,11 +1410,11 @@ class PDFPlumberParser(BaseBlobParser):
         self.dedupe = dedupe
         self.extract_images = extract_images

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""
         import pdfplumber

-        with blob.as_bytes_io() as file_path:  # type: ignore[attr-defined]
+        with blob.as_bytes_io() as file_path:
             doc = pdfplumber.open(file_path)  # open document

             yield from [
@@ -1424,8 +1424,8 @@ class PDFPlumberParser(BaseBlobParser):
                     + self._extract_images_from_page(page),
                     metadata=dict(
                         {
-                            "source": blob.source,  # type: ignore[attr-defined]
-                            "file_path": blob.source,  # type: ignore[attr-defined]
+                            "source": blob.source,
+                            "file_path": blob.source,
                             "page": page.page_number - 1,
                             "total_pages": len(doc.pages),
                         },
@@ -1593,14 +1593,14 @@ class AmazonTextractPDFParser(BaseBlobParser):
         else:
             self.boto3_textract_client = client

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """Iterates over the Blob pages and returns an Iterator with a Document
         for each page, like the other parsers If multi-page document, blob.path
         has to be set to the S3 URI and for single page docs
         the blob.data is taken
         """

-        url_parse_result = urlparse(str(blob.path)) if blob.path else None  # type: ignore[attr-defined]
+        url_parse_result = urlparse(str(blob.path)) if blob.path else None
         # Either call with S3 path (multi-page) or with bytes (single-page)
         if (
             url_parse_result
@@ -1608,13 +1608,13 @@ class AmazonTextractPDFParser(BaseBlobParser):
             and url_parse_result.netloc
         ):
             textract_response_json = self.tc.call_textract(
-                input_document=str(blob.path),  # type: ignore[attr-defined]
+                input_document=str(blob.path),
                 features=self.textract_features,
                 boto3_textract_client=self.boto3_textract_client,
             )
         else:
             textract_response_json = self.tc.call_textract(
-                input_document=blob.as_bytes(),  # type: ignore[attr-defined]
+                input_document=blob.as_bytes(),
                 features=self.textract_features,
                 call_mode=self.tc.Textract_Call_Mode.FORCE_SYNC,
                 boto3_textract_client=self.boto3_textract_client,
@@ -1625,7 +1625,7 @@ class AmazonTextractPDFParser(BaseBlobParser):
         for idx, page in enumerate(document.pages):
             yield Document(
                 page_content=page.get_text(config=self.linearization_config),
-                metadata={"source": blob.source, "page": idx + 1},  # type: ignore[attr-defined]
+                metadata={"source": blob.source, "page": idx + 1},
             )


@@ -1645,23 +1645,23 @@ class DocumentIntelligenceParser(BaseBlobParser):
         self.client = client
         self.model = model

-    def _generate_docs(self, blob: Blob, result: Any) -> Iterator[Document]:  # type: ignore[valid-type]
+    def _generate_docs(self, blob: Blob, result: Any) -> Iterator[Document]:
         for p in result.pages:
             content = " ".join([line.content for line in p.lines])

             d = Document(
                 page_content=content,
                 metadata={
-                    "source": blob.source,  # type: ignore[attr-defined]
+                    "source": blob.source,
                     "page": p.page_number,
                 },
             )
             yield d

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""

-        with blob.as_bytes_io() as file_obj:  # type: ignore[attr-defined]
+        with blob.as_bytes_io() as file_obj:
             poller = self.client.begin_analyze_document(self.model, file_obj)
             result = poller.result()

@@ -11,6 +11,6 @@ from langchain_community.document_loaders.blob_loaders import Blob
 class TextParser(BaseBlobParser):
     """Parser for text blobs."""

-    def lazy_parse(self, blob: Blob) -> Iterator[Document]:  # type: ignore[valid-type]
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""
-        yield Document(page_content=blob.as_string(), metadata={"source": blob.source})  # type: ignore[attr-defined]
+        yield Document(page_content=blob.as_string(), metadata={"source": blob.source})
@@ -91,7 +91,7 @@ class UnstructuredPDFLoader(UnstructuredFileLoader):
     def _get_elements(self) -> list:
         from unstructured.partition.pdf import partition_pdf

-        return partition_pdf(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_pdf(filename=self.file_path, **self.unstructured_kwargs)


 class BasePDFLoader(BaseLoader, ABC):
@@ -299,11 +299,9 @@ class PyPDFLoader(BasePDFLoader):
         In this way, a paragraph can be continued on the next page.
         """
         if self.web_path:
-            blob = Blob.from_data(  # type: ignore[attr-defined]
-                open(self.file_path, "rb").read(), path=self.web_path
-            )
+            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
         else:
-            blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+            blob = Blob.from_path(self.file_path)
         yield from self.parser.lazy_parse(blob)


@@ -415,11 +413,9 @@ class PyPDFium2Loader(BasePDFLoader):
         In this way, a paragraph can be continued on the next page.
         """
         if self.web_path:
-            blob = Blob.from_data(  # type: ignore[attr-defined]
-                open(self.file_path, "rb").read(), path=self.web_path
-            )
+            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
         else:
-            blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+            blob = Blob.from_path(self.file_path)
         yield from self.parser.parse(blob)


@@ -674,11 +670,9 @@ class PDFMinerLoader(BasePDFLoader):
         In this way, a paragraph can be continued on the next page.
         """
         if self.web_path:
-            blob = Blob.from_data(  # type: ignore[attr-defined]
-                open(self.file_path, "rb").read(), path=self.web_path
-            )
+            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
         else:
-            blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+            blob = Blob.from_path(self.file_path)
         yield from self.parser.lazy_parse(blob)


@@ -850,9 +844,9 @@ class PyMuPDFLoader(BasePDFLoader):
             )
         parser = self.parser
         if self.web_path:
-            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)  # type: ignore[attr-defined]
+            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
         else:
-            blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+            blob = Blob.from_path(self.file_path)
         yield from parser._lazy_parse(blob, text_kwargs=kwargs)

     def load(self, **kwargs: Any) -> list[Document]:
@@ -1046,9 +1040,9 @@ class PDFPlumberLoader(BasePDFLoader):
             extract_images=self.extract_images,
         )
         if self.web_path:
-            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)  # type: ignore[attr-defined]
+            blob = Blob.from_data(open(self.file_path, "rb").read(), path=self.web_path)
         else:
-            blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+            blob = Blob.from_path(self.file_path)
         return parser.parse(blob)


@@ -1163,7 +1157,7 @@ class AmazonTextractPDFLoader(BasePDFLoader):
         # raises ValueError when multipage and not on S3"""

         if self.web_path and self._is_s3_url(self.web_path):
-            blob = Blob(path=self.web_path)  # type: ignore[call-arg] # type: ignore[misc]
+            blob = Blob(path=self.web_path)
         else:
             blob = Blob.from_path(self.file_path)
         if AmazonTextractPDFLoader._get_number_of_pages(blob) > 1:
@@ -1176,7 +1170,7 @@ class AmazonTextractPDFLoader(BasePDFLoader):
             yield from self.parser.parse(blob)

     @staticmethod
-    def _get_number_of_pages(blob: Blob) -> int:  # type: ignore[valid-type]
+    def _get_number_of_pages(blob: Blob) -> int:
         try:
             import pypdf
             from PIL import Image, ImageSequence
@@ -1186,22 +1180,20 @@ class AmazonTextractPDFLoader(BasePDFLoader):
                 "Could not import pypdf or Pillow python package. "
                 "Please install it with `pip install pypdf Pillow`."
             )
-        if blob.mimetype == "application/pdf":  # type: ignore[attr-defined]
-            with blob.as_bytes_io() as input_pdf_file:  # type: ignore[attr-defined]
+        if blob.mimetype == "application/pdf":
+            with blob.as_bytes_io() as input_pdf_file:
                 pdf_reader = pypdf.PdfReader(input_pdf_file)
                 return len(pdf_reader.pages)
-        elif blob.mimetype == "image/tiff":  # type: ignore[attr-defined]
+        elif blob.mimetype == "image/tiff":
             num_pages = 0
-            img = Image.open(blob.as_bytes())  # type: ignore[attr-defined]
+            img = Image.open(blob.as_bytes())
             for _, _ in enumerate(ImageSequence.Iterator(img)):
                 num_pages += 1
             return num_pages
-        elif blob.mimetype in ["image/png", "image/jpeg"]:  # type: ignore[attr-defined]
+        elif blob.mimetype in ["image/png", "image/jpeg"]:
             return 1
         else:
-            raise ValueError(  # type: ignore[attr-defined]
-                f"unsupported mime type: {blob.mimetype}"
-            )
+            raise ValueError(f"unsupported mime type: {blob.mimetype}")


 class DedocPDFLoader(DedocBaseLoader):
@@ -1348,7 +1340,7 @@ class DocumentIntelligenceLoader(BasePDFLoader):
         self,
     ) -> Iterator[Document]:
         """Lazy load given path as pages."""
-        blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+        blob = Blob.from_path(self.file_path)
         yield from self.parser.parse(blob)


@@ -59,7 +59,7 @@ class UnstructuredPowerPointLoader(UnstructuredFileLoader):
         try:
             import magic  # noqa: F401

-            is_ppt = detect_filetype(self.file_path) == FileType.PPT  # type: ignore[arg-type]
+            is_ppt = detect_filetype(self.file_path) == FileType.PPT
         except ImportError:
             _, extension = os.path.splitext(str(self.file_path))
             is_ppt = extension == ".ppt"
@@ -70,8 +70,8 @@ class UnstructuredPowerPointLoader(UnstructuredFileLoader):
         if is_ppt:
             from unstructured.partition.ppt import partition_ppt

-            return partition_ppt(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+            return partition_ppt(filename=self.file_path, **self.unstructured_kwargs)
         else:
             from unstructured.partition.pptx import partition_pptx

-            return partition_pptx(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+            return partition_pptx(filename=self.file_path, **self.unstructured_kwargs)
@@ -56,4 +56,4 @@ class UnstructuredRSTLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.rst import partition_rst

-        return partition_rst(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_rst(filename=self.file_path, **self.unstructured_kwargs)
@@ -56,4 +56,4 @@ class UnstructuredRTFLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.rtf import partition_rtf

-        return partition_rtf(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_rtf(filename=self.file_path, **self.unstructured_kwargs)
@@ -39,4 +39,4 @@ class UnstructuredTSVLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.tsv import partition_tsv

-        return partition_tsv(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_tsv(filename=self.file_path, **self.unstructured_kwargs)
@@ -37,7 +37,7 @@ class VsdxLoader(BaseLoader, ABC):
         elif not os.path.isfile(self.file_path):
             raise ValueError("File path %s is not a valid file or url" % self.file_path)

-        self.parser = VsdxParser()  # type: ignore[misc]
+        self.parser = VsdxParser()

     def __del__(self) -> None:
         if hasattr(self, "temp_file"):
@@ -50,5 +50,5 @@ class VsdxLoader(BaseLoader, ABC):
         return bool(parsed.netloc) and bool(parsed.scheme)

     def load(self) -> List[Document]:
-        blob = Blob.from_path(self.file_path)  # type: ignore[attr-defined]
+        blob = Blob.from_path(self.file_path)
         return list(self.parser.parse(blob))
@@ -33,7 +33,7 @@ class WeatherDataLoader(BaseLoader):
     def from_params(
         cls, places: Sequence[str], *, openweathermap_api_key: Optional[str] = None
     ) -> WeatherDataLoader:
-        client = OpenWeatherMapAPIWrapper(openweathermap_api_key=openweathermap_api_key)  # type: ignore[call-arg]
+        client = OpenWeatherMapAPIWrapper(openweathermap_api_key=openweathermap_api_key)
         return cls(client, places)

     def lazy_load(
@@ -121,7 +121,7 @@ class UnstructuredWordDocumentLoader(UnstructuredFileLoader):
         try:
             import magic  # noqa: F401

-            is_doc = detect_filetype(self.file_path) == FileType.DOC  # type: ignore[arg-type]
+            is_doc = detect_filetype(self.file_path) == FileType.DOC
         except ImportError:
             _, extension = os.path.splitext(str(self.file_path))
             is_doc = extension == ".doc"
@@ -132,8 +132,8 @@ class UnstructuredWordDocumentLoader(UnstructuredFileLoader):
         if is_doc:
             from unstructured.partition.doc import partition_doc

-            return partition_doc(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+            return partition_doc(filename=self.file_path, **self.unstructured_kwargs)
         else:
             from unstructured.partition.docx import partition_docx

-            return partition_docx(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+            return partition_docx(filename=self.file_path, **self.unstructured_kwargs)
@@ -46,4 +46,4 @@ class UnstructuredXMLLoader(UnstructuredFileLoader):
     def _get_elements(self) -> List:
         from unstructured.partition.xml import partition_xml

-        return partition_xml(filename=self.file_path, **self.unstructured_kwargs)  # type: ignore[arg-type]
+        return partition_xml(filename=self.file_path, **self.unstructured_kwargs)
@@ -33,7 +33,7 @@ class GoogleTranslateTransformer(BaseDocumentTransformer):
         """
         try:
             from google.api_core.client_options import ClientOptions
-            from google.cloud import translate  # type: ignore[attr-defined]
+            from google.cloud import translate
         except ImportError as exc:
             raise ImportError(
                 "Install Google Cloud Translate to use this parser."
@@ -76,7 +76,7 @@ class GoogleTranslateTransformer(BaseDocumentTransformer):
                 Options: `text/plain`, `text/html`
         """
         try:
-            from google.cloud import translate  # type: ignore[attr-defined]
+            from google.cloud import translate
         except ImportError as exc:
             raise ImportError(
                 "Install Google Cloud Translate to use this parser."
@@ -58,7 +58,7 @@ class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
         new_documents = []

         for document in documents:
-            extracted_metadata: Dict = self.tagging_chain.run(document.page_content)  # type: ignore[assignment]
+            extracted_metadata: Dict = self.tagging_chain.run(document.page_content)
             new_document = Document(
                 page_content=document.page_content,
                 metadata={**extracted_metadata, **document.metadata},
@@ -68,7 +68,7 @@ class AnyscaleEmbeddings(OpenAIEmbeddings):
             else:
                 values["openai_api_base"] = values["anyscale_api_base"]
                 values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
-                values["client"] = openai.Embedding  # type: ignore[attr-defined]
+                values["client"] = openai.Embedding
         return values

     @property
@@ -20,7 +20,7 @@ from langchain_community.utils.openai import is_openai_v1
     removal="1.0",
     alternative_import="langchain_openai.AzureOpenAIEmbeddings",
 )
-class AzureOpenAIEmbeddings(OpenAIEmbeddings):  # type: ignore[override]
+class AzureOpenAIEmbeddings(OpenAIEmbeddings):
     """`Azure OpenAI` Embeddings API."""

     azure_endpoint: Union[str, None] = None
@@ -170,16 +170,16 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):  # type: ignore[override]
                 "default_query": self.default_query,
                 "http_client": self.http_client,
             }
-            self.client = openai.AzureOpenAI(**client_params).embeddings  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
+            self.client = openai.AzureOpenAI(**client_params).embeddings

             if self.azure_ad_async_token_provider:
                 client_params["azure_ad_token_provider"] = (
                     self.azure_ad_async_token_provider
                 )

-            self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
+            self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
         else:
-            self.client = openai.Embedding  # type: ignore[attr-defined]
+            self.client = openai.Embedding
         return self

     @property
@@ -46,11 +46,11 @@ def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], An
         stop=stop_after_attempt(embeddings.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore[attr-defined]
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
@@ -68,11 +68,11 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
         stop=stop_after_attempt(embeddings.max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore[attr-defined]
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
@@ -93,7 +93,7 @@ def _check_response(response: dict) -> dict:
     if any(len(d["embedding"]) == 1 for d in response["data"]):
         import openai

-        raise openai.error.APIError("LocalAI API returned an empty embedding")  # type: ignore[attr-defined]
+        raise openai.error.APIError("LocalAI API returned an empty embedding")
     return response


@@ -230,7 +230,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
         try:
             import openai

-            values["client"] = openai.Embedding  # type: ignore[attr-defined]
+            values["client"] = openai.Embedding
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -253,10 +253,10 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
         if self.openai_proxy:
             import openai

-            openai.proxy = {  # type: ignore[attr-defined]
+            openai.proxy = {
                 "http": self.openai_proxy,
                 "https": self.openai_proxy,
-            }  # type: ignore[assignment]
+            }
         return openai_args

     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
@@ -72,7 +72,7 @@ class MlflowEmbeddings(Embeddings, BaseModel):
         for txt in _chunk(texts, 20):
             resp = self._client.predict(
                 endpoint=self.endpoint,
-                inputs={"input": txt, **params},  # type: ignore[arg-type]
+                inputs={"input": txt, **params},
             )
             embeddings.extend(r["embedding"] for r in resp["data"])
         return embeddings
@@ -74,8 +74,8 @@ class OctoAIEmbeddings(OpenAIEmbeddings):
             else:
                 values["openai_api_base"] = values["endpoint_url"]
                 values["openai_api_key"] = values["octoai_api_token"].get_secret_value()
-                values["client"] = openai.Embedding  # type: ignore[attr-defined]
-                values["async_client"] = openai.Embedding  # type: ignore[attr-defined]
+                values["client"] = openai.Embedding
+                values["async_client"] = openai.Embedding

         except ImportError:
             raise ImportError(
@@ -58,11 +58,11 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
             max=embeddings.retry_max_seconds,
         ),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore[attr-defined]
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
@@ -85,11 +85,11 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
             max=embeddings.retry_max_seconds,
         ),
         retry=(
-            retry_if_exception_type(openai.error.Timeout)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.APIConnectionError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.RateLimitError)  # type: ignore[attr-defined]
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)  # type: ignore[attr-defined]
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
         ),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
@@ -110,7 +110,7 @@ def _check_response(response: dict, skip_empty: bool = False) -> dict:
     if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
         import openai

-        raise openai.error.APIError("OpenAI API returned an empty embedding")  # type: ignore[attr-defined]
+        raise openai.error.APIError("OpenAI API returned an empty embedding")
     return response


@@ -357,7 +357,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 **client_params
             ).embeddings
         elif not values.get("client"):
-            values["client"] = openai.Embedding  # type: ignore[attr-defined]
+            values["client"] = openai.Embedding
         else:
             pass
         return values
@@ -390,10 +390,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                     "Please install it with `pip install openai`."
                 )

-            openai.proxy = {  # type: ignore[attr-defined]
+            openai.proxy = {
                 "http": self.openai_proxy,
                 "https": self.openai_proxy,
-            }  # type: ignore[assignment]
+            }
         return openai_args

     # please refer to
@@ -54,7 +54,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
             # Try to load the spaCy model
             import spacy

-            values["nlp"] = spacy.load(model_name)  # type: ignore[arg-type]
+            values["nlp"] = spacy.load(model_name)
         except OSError:
             # If the model is not found, raise a ValueError
             raise ValueError(
@@ -25,7 +25,7 @@ _MIN_BATCH_SIZE = 5
     removal="1.0",
     alternative_import="langchain_google_vertexai.VertexAIEmbeddings",
 )
-class VertexAIEmbeddings(_VertexAICommon, Embeddings):  # type: ignore[override]
+class VertexAIEmbeddings(_VertexAICommon, Embeddings):
     """Google Cloud VertexAI embedding models."""

     # Instance context
@@ -163,8 +163,8 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):  # type: ignore[override]
             DeadlineExceeded,
         ]
         retry_decorator = create_base_retry_decorator(
-            error_types=errors,  # type: ignore[arg-type]
-            max_retries=self.max_retries,  # type: ignore[arg-type]
+            error_types=errors,
+            max_retries=self.max_retries,
         )

         @retry_decorator
@@ -205,7 +205,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs): # ty
     for text in texts:
         request = TextEmbeddingRequest(model_uri=model_uri, text=text)
         stub = EmbeddingsServiceStub(channel)
-        res = stub.TextEmbedding(request, metadata=self.grpc_metadata)  # type: ignore[attr-defined]
+        res = stub.TextEmbedding(request, metadata=self.grpc_metadata)
         result.append(list(res.embedding))
         time.sleep(self.sleep_interval)

@@ -481,20 +481,20 @@ class Neo4jGraph(GraphStore):
                             or e.code
                             == "Neo.DatabaseError.Transaction.TransactionStartFailed"
                         )
-                        and "in an implicit transaction" in e.message  # type: ignore[operator]
+                        and "in an implicit transaction" in e.message
                     )
                     or (  # isPeriodicCommitError
                         e.code == "Neo.ClientError.Statement.SemanticError"
                         and (
-                            "in an open transaction is not possible" in e.message  # type: ignore[operator]
-                            or "tried to execute in an explicit transaction" in e.message  # type: ignore[operator]
+                            "in an open transaction is not possible" in e.message
+                            or "tried to execute in an explicit transaction" in e.message
                         )
                     )
                 ):
                     raise
             # fallback to allow implicit transactions
             with self._driver.session(database=self._database) as session:
-                data = session.run(Query(text=query, timeout=self.timeout), params)  # type: ignore[assignment]
+                data = session.run(Query(text=query, timeout=self.timeout), params)
                 json_data = [r.data() for r in data]
                 if self.sanitize:
                     json_data = [value_sanitize(el) for el in json_data]
@@ -304,7 +304,7 @@ class SQLRecordManager(RecordManager):
                 # Note: uses SQLite insert to make on_conflict_do_update work.
                 # This code needs to be generalized a bit to work with more dialects.
                 insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
-                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
+                stmt = insert_stmt.on_conflict_do_update(
                     [UpsertionRecord.key, UpsertionRecord.namespace],
                     set_=dict(
                         # attr-defined type ignore
@@ -318,7 +318,7 @@ class SQLRecordManager(RecordManager):
                 # Note: uses SQLite insert to make on_conflict_do_update work.
                 # This code needs to be generalized a bit to work with more dialects.
                 insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)  # type: ignore[assignment]
-                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
+                stmt = insert_stmt.on_conflict_do_update(
                     "uix_key_namespace",  # Name of constraint
                     set_=dict(
                         # attr-defined type ignore
@@ -379,7 +379,7 @@ class SQLRecordManager(RecordManager):
                 # Note: uses SQLite insert to make on_conflict_do_update work.
                 # This code needs to be generalized a bit to work with more dialects.
                 insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
-                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
+                stmt = insert_stmt.on_conflict_do_update(
                     [UpsertionRecord.key, UpsertionRecord.namespace],
                     set_=dict(
                         # attr-defined type ignore
@@ -393,7 +393,7 @@ class SQLRecordManager(RecordManager):
                 # Note: uses SQLite insert to make on_conflict_do_update work.
                 # This code needs to be generalized a bit to work with more dialects.
                 insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)  # type: ignore[assignment]
-                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
+                stmt = insert_stmt.on_conflict_do_update(
                     "uix_key_namespace",  # Name of constraint
                     set_=dict(
                         # attr-defined type ignore
@@ -412,7 +412,7 @@ class SQLRecordManager(RecordManager):
         with self._make_session() as session:
             records = (
                 # mypy does not recognize .all()
                session.query(UpsertionRecord.key)
-                session.query(UpsertionRecord.key)  # type: ignore[attr-defined]
+                session.query(UpsertionRecord.key)
                 .filter(
                     and_(
                         UpsertionRecord.key.in_(keys),
@@ -460,21 +460,15 @@ class SQLRecordManager(RecordManager):

             # mypy does not recognize .all() or .filter()
             if after:
-                query = query.filter(  # type: ignore[attr-defined]
-                    UpsertionRecord.updated_at > after
-                )
+                query = query.filter(UpsertionRecord.updated_at > after)
             if before:
-                query = query.filter(  # type: ignore[attr-defined]
-                    UpsertionRecord.updated_at < before
-                )
+                query = query.filter(UpsertionRecord.updated_at < before)
             if group_ids:
-                query = query.filter(  # type: ignore[attr-defined]
-                    UpsertionRecord.group_id.in_(group_ids)
-                )
+                query = query.filter(UpsertionRecord.group_id.in_(group_ids))

             if limit:
-                query = query.limit(limit)  # type: ignore[attr-defined]
-            records = query.all()  # type: ignore[attr-defined]
+                query = query.limit(limit)
+            records = query.all()
             return [r.key for r in records]  # type: ignore[misc]

     async def alist_keys(
@@ -493,20 +487,14 @@ class SQLRecordManager(RecordManager):

             # mypy does not recognize .all() or .filter()
             if after:
-                query = query.filter(  # type: ignore[attr-defined]
-                    UpsertionRecord.updated_at > after
-                )
+                query = query.filter(UpsertionRecord.updated_at > after)
             if before:
-                query = query.filter(  # type: ignore[attr-defined]
-                    UpsertionRecord.updated_at < before
-                )
+                query = query.filter(UpsertionRecord.updated_at < before)
             if group_ids:
-                query = query.filter(  # type: ignore[attr-defined]
-                    UpsertionRecord.group_id.in_(group_ids)
-                )
+                query = query.filter(UpsertionRecord.group_id.in_(group_ids))

             if limit:
-                query = query.limit(limit)  # type: ignore[attr-defined]
+                query = query.limit(limit)
             records = (await session.execute(query)).scalars().all()
             return list(records)

@@ -519,7 +507,7 @@ class SQLRecordManager(RecordManager):
                         UpsertionRecord.key.in_(keys),
                         UpsertionRecord.namespace == self.namespace,
                     )
-                ).delete()  # type: ignore[attr-defined]
+                ).delete()
                 session.commit()

     async def adelete_keys(self, keys: Sequence[str]) -> None:
@@ -281,6 +281,6 @@ class AlephAlpha(LLM):


 if __name__ == "__main__":
-    aa = AlephAlpha()  # type: ignore[call-arg]
+    aa = AlephAlpha()

     print(aa.invoke("How are you?"))  # noqa: T201
@@ -62,7 +62,7 @@ def create_llm_result(
     return LLMResult(generations=generations, llm_output=llm_output)


-class Anyscale(BaseOpenAI):  # type: ignore[override]
+class Anyscale(BaseOpenAI):
     """Anyscale large language models.

     To use, you should have the environment variable ``ANYSCALE_API_KEY`` set with your
@@ -136,7 +136,7 @@ class Anyscale(BaseOpenAI):  # type: ignore[override]
             else:
                 values["openai_api_base"] = values["anyscale_api_base"]
                 values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
-                values["client"] = openai.Completion  # type: ignore[attr-defined]
+                values["client"] = openai.Completion
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -194,7 +194,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
         try:
             choice = json.loads(output)[0]["0"]
         except (KeyError, IndexError, TypeError) as e:
-            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
         return Generation(text=choice)


@@ -239,7 +239,7 @@ class HFContentFormatter(ContentFormatterBase):
         try:
             choice = json.loads(output)[0]["0"]["generated_text"]
         except (KeyError, IndexError, TypeError) as e:
-            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
         return Generation(text=choice)


@@ -268,7 +268,7 @@ class DollyContentFormatter(ContentFormatterBase):
         try:
             choice = json.loads(output)[0]
         except (KeyError, IndexError, TypeError) as e:
-            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
         return Generation(text=choice)


@@ -315,7 +315,7 @@ class CustomOpenAIContentFormatter(ContentFormatterBase):
             try:
                 choice = json.loads(output)[0]["0"]
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return Generation(text=choice)
         if api_type == AzureMLEndpointApiType.serverless:
             try:
@@ -327,7 +327,7 @@ class CustomOpenAIContentFormatter(ContentFormatterBase):
                     "received."
                 )
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return Generation(
                 text=choice["text"].strip(),
                 generation_info=dict(
@ -19,7 +19,7 @@ DEFAULT_NUM_TRIES = 10
DEFAULT_SLEEP_TIME = 4


-class Beam(LLM):  # type: ignore[override, override, override, override]
+class Beam(LLM):
    """Beam API for gpt2 large language model.

    To use, you should have the ``beam-sdk`` python package installed,
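The four stacked `override` codes on `Beam` read like a shotgun attempt to silence inherited-signature complaints; once the underlying errors went away, the whole comment became unused. For reference, this is the kind of error the `override` code covers, reported at the offending method (classes are illustrative):

```python
class Base:
    def run(self, x: int) -> int:
        return x

class Child(Base):
    # Without the ignore, mypy reports:
    #   Argument 1 of "run" is incompatible with supertype "Base"  [override]
    def run(self, x: str) -> int:  # type: ignore[override]
        return len(x)
```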
@ -97,8 +97,8 @@ class GooseAI(LLM):
            import openai

            openai.api_key = gooseai_api_key.get_secret_value()
-            openai.api_base = "https://api.goose.ai/v1"  # type: ignore[attr-defined]
-            values["client"] = openai.Completion  # type: ignore[attr-defined]
+            openai.api_base = "https://api.goose.ai/v1"
+            values["client"] = openai.Completion
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
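The GooseAI hunk shows the legacy (pre-1.0) `openai` SDK style of configuring the client through module-level attributes; the `attr-defined` ignores were only needed while the SDK's type information didn't declare those attributes. A sketch of that legacy pattern, assuming `openai<1.0` and an illustrative model name:

```python
import openai  # assumes the legacy SDK: pip install "openai<1.0"

openai.api_key = "sk-..."                    # credentials live on the module
openai.api_base = "https://api.goose.ai/v1"  # point the SDK at a compatible host

# The module attribute itself serves as the "client":
completion = openai.Completion.create(model="gpt-neo-20b", prompt="Hello")
```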
@ -12,7 +12,7 @@ from pydantic import BaseModel

# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class Params(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class Params(BaseModel, extra="allow"):
    """Parameters for the Javelin AI Gateway LLM."""

    temperature: float = 0.0
@ -10,7 +10,7 @@ from pydantic import BaseModel

# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class Params(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class Params(BaseModel, extra="allow"):
    """Parameters for the MLflow AI Gateway LLM."""

    temperature: float = 0.0
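Both gateway `Params` hunks (and the retriever and Kendra hunks further down) are the same case: mypy once rejected the `extra="allow"` class keyword as an unexpected `__init_subclass__` argument, and with a pydantic version whose mypy plugin understands it, the `call-arg` ignore went stale. A minimal runnable sketch with pydantic v2:

```python
from pydantic import BaseModel

class Params(BaseModel, extra="allow"):
    """Undeclared keyword arguments are kept instead of rejected."""

    temperature: float = 0.0

p = Params(temperature=0.5, top_p=0.9)  # top_p is not a declared field...
print(p.model_dump())  # ...but survives: {'temperature': 0.5, 'top_p': 0.9}
```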
@ -10,7 +10,7 @@ DEFAULT_BASE_URL = "https://text.octoai.run/v1/"
DEFAULT_MODEL = "codellama-7b-instruct"


-class OctoAIEndpoint(BaseOpenAI):  # type: ignore[override]
+class OctoAIEndpoint(BaseOpenAI):
    """OctoAI LLM Endpoints - OpenAI compatible.

    OctoAIEndpoint is a class to interact with OctoAI Compute Service large
@ -102,7 +102,7 @@ class OctoAIEndpoint(BaseOpenAI): # type: ignore[override]
            else:
                values["openai_api_base"] = values["octoai_api_base"]
                values["openai_api_key"] = values["octoai_api_token"].get_secret_value()
-                values["client"] = openai.Completion  # type: ignore[attr-defined]
+                values["client"] = openai.Completion
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
@ -319,9 +319,9 @@ class _OllamaCommon(BaseLanguageModel):
                "Content-Type": "application/json",
                **(self.headers if isinstance(self.headers, dict) else {}),
            },
-            auth=self.auth,  # type: ignore[arg-type]
+            auth=self.auth,  # type: ignore[arg-type,unused-ignore]
            json=request_payload,
-            timeout=self.timeout,  # type: ignore[arg-type]
+            timeout=self.timeout,  # type: ignore[arg-type,unused-ignore]
        ) as response:
            if response.status != 200:
                if response.status == 404:
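The Ollama hunk is the one place where an ignore is kept rather than deleted: whether `auth` and `timeout` need the `arg-type` ignore depends on the installed aiohttp/typeshed versions. Adding `unused-ignore` to the bracket list tells mypy not to flag the comment on configurations where it turns out to be unnecessary. A sketch of the idiom (the function is illustrative):

```python
import aiohttp

async def post_generate(session: aiohttp.ClientSession, auth: object) -> None:
    # On some dependency versions `auth` fails arg-type checking, on others
    # it passes; listing unused-ignore alongside arg-type keeps
    # warn_unused_ignores quiet either way.
    async with session.post(
        "http://localhost:11434/api/generate",
        auth=auth,  # type: ignore[arg-type,unused-ignore]
    ) as response:
        response.raise_for_status()
```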
@ -100,11 +100,11 @@ def _create_retry_decorator(
    import openai

    errors = [
-        openai.error.Timeout,  # type: ignore[attr-defined]
-        openai.error.APIError,  # type: ignore[attr-defined]
-        openai.error.APIConnectionError,  # type: ignore[attr-defined]
-        openai.error.RateLimitError,  # type: ignore[attr-defined]
-        openai.error.ServiceUnavailableError,  # type: ignore[attr-defined]
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
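`create_base_retry_decorator` wraps those transient error types in an exponential-backoff retry. Roughly, and ignoring the callback-manager plumbing, such a factory can be sketched with tenacity (parameter values are illustrative, not LangChain's exact ones):

```python
from typing import Callable, Type

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

def make_retry_decorator(
    error_types: list[Type[BaseException]], max_retries: int
) -> Callable:
    """Retry on any of the listed exceptions with exponential backoff."""
    return retry(
        reraise=True,  # surface the original exception after the last attempt
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(tuple(error_types)),
    )
```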
@ -323,7 +323,7 @@ class BaseOpenAI(BaseLLM):
            if not values.get("async_client"):
                values["async_client"] = openai.AsyncOpenAI(**client_params).completions
            elif not values.get("client"):
-                values["client"] = openai.Completion  # type: ignore[attr-defined]
+                values["client"] = openai.Completion
            else:
                pass
@ -607,7 +607,7 @@ class BaseOpenAI(BaseLLM):
        if self.openai_proxy:
            import openai

-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined]
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
        return {**openai_creds, **self._default_params}

    @property
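The removed line above had accumulated seven ignore comments; only one `# type: ignore` per line is meaningful to mypy, so the trailing repeats were dead text even before this PR. When a single line genuinely triggers several error codes, the supported form is one comment with comma-separated codes:

```python
import openai  # assuming the legacy (<1.0) SDK layout

# One comment, several codes -- same intent, without the pile-up:
openai.proxy = {"http": "http://127.0.0.1:8080"}  # type: ignore[assignment, attr-defined]
```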
@ -943,7 +943,7 @@ class AzureOpenAI(BaseOpenAI):
            ).completions

        else:
-            values["client"] = openai.Completion  # type: ignore[attr-defined]
+            values["client"] = openai.Completion

        return values
@ -1068,18 +1068,18 @@ class OpenAIChat(BaseLLM):

            openai.api_key = openai_api_key
            if openai_api_base:
-                openai.api_base = openai_api_base  # type: ignore[attr-defined]
+                openai.api_base = openai_api_base
            if openai_organization:
                openai.organization = openai_organization
            if openai_proxy:
-                openai.proxy = {"http": openai_proxy, "https": openai_proxy}  # type: ignore[assignment] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined] # type: ignore[attr-defined]
+                openai.proxy = {"http": openai_proxy, "https": openai_proxy}
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
-            values["client"] = openai.ChatCompletion  # type: ignore[attr-defined]
+            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
@ -49,7 +49,7 @@ def is_gemini_model(model_name: str) -> bool:
    return model_name is not None and "gemini" in model_name


-def completion_with_retry(  # type: ignore[no-redef]
+def completion_with_retry(
    llm: VertexAI,
    prompt: List[Union[str, "Image"]],
    stream: bool = False,
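`no-redef` fires when a module defines the same name twice, which typically happens with conditional or fallback definitions; once the duplicate `completion_with_retry` definition was gone, the ignore had nothing to suppress. An illustrative reproduction:

```python
def completion_with_retry(llm: object) -> str:  # first definition
    return "v1"

# Without the ignore, mypy reports:
#   error: Name "completion_with_retry" already defined  [no-redef]
def completion_with_retry(llm: object, stream: bool = False) -> str:  # type: ignore[no-redef]
    return "v2"
```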
@ -124,7 +124,7 @@ class _VertexAIBase(BaseModel):
        return cls.task_executor


-class _VertexAICommon(_VertexAIBase):  # type: ignore[override]
+class _VertexAICommon(_VertexAIBase):
    client: "_LanguageModel" = None  #: :meta private:
    client_preview: "_LanguageModel" = None  #: :meta private:
    model_name: str
@ -208,7 +208,7 @@ class _VertexAICommon(_VertexAIBase): # type: ignore[override]
    removal="1.0",
    alternative_import="langchain_google_vertexai.VertexAI",
)
-class VertexAI(_VertexAICommon, BaseLLM):  # type: ignore[override]
+class VertexAI(_VertexAICommon, BaseLLM):
    """Google Vertex AI large language models."""

    model_name: str = "text-bison"
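For context, the closing arguments at the top of this hunk belong to LangChain's deprecation decorator, which steers users toward the partner package. Roughly (the `since` value is illustrative):

```python
from langchain_core._api.deprecation import deprecated

@deprecated(
    since="0.0.12",  # illustrative version
    removal="1.0",
    alternative_import="langchain_google_vertexai.VertexAI",
)
class VertexAI:
    """Instantiating this emits a LangChainDeprecationWarning."""
```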
@ -332,7 +332,7 @@ class VertexAI(_VertexAICommon, BaseLLM): # type: ignore[override]
                    generation += chunk
                generations.append([generation])
            else:
-                res = completion_with_retry(  # type: ignore[misc]
+                res = completion_with_retry(
                    self,
                    [prompt],
                    stream=should_stream,
@ -375,7 +375,7 @@ class VertexAI(_VertexAICommon, BaseLLM): # type: ignore[override]
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in completion_with_retry(  # type: ignore[misc]
+        for stream_resp in completion_with_retry(
            self,
            [prompt],
            stream=True,
@ -448,7 +448,7 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
    @property
    def endpoint_path(self) -> str:
        return self.client.endpoint_path(
-            project=self.project,  # type: ignore[arg-type]
+            project=self.project,
            location=self.location,
            endpoint=self.endpoint_id,
        )
@ -235,7 +235,7 @@ def _make_request(
        messages=[Message(role="user", text=prompt)],
    )
    stub = TextGenerationServiceStub(channel)
-    res = stub.Completion(request, metadata=self.grpc_metadata)  # type: ignore[attr-defined]
+    res = stub.Completion(request, metadata=self.grpc_metadata)
    return list(res)[0].alternatives[0].message.text
@ -291,7 +291,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
        messages=[Message(role="user", text=prompt)],
    )
    stub = TextGenerationAsyncServiceStub(channel)
-    operation = await stub.Completion(request, metadata=self.grpc_metadata)  # type: ignore[attr-defined]
+    operation = await stub.Completion(request, metadata=self.grpc_metadata)
    async with grpc.aio.secure_channel(
        operation_api_url, channel_credentials
    ) as operation_channel:
@ -301,7 +301,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
            operation_request = GetOperationRequest(operation_id=operation.id)
            operation = await operation_stub.Get(
                operation_request,
-                metadata=self.grpc_metadata,  # type: ignore[attr-defined]
+                metadata=self.grpc_metadata,
            )

        completion_response = CompletionResponse()
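Both Yandex hunks drop `attr-defined` ignores on `self.grpc_metadata`, presumably stale once the attribute gained a declared type. The async path above is the standard gRPC long-running-operation dance; schematically, with the channel setup elided and assuming Yandex Cloud's generated stubs are installed:

```python
import asyncio

# Generated stub types from the yandexcloud package (assumed importable):
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest

async def wait_for_operation(operation, operation_stub, grpc_metadata):
    """Poll the operations service until the long-running request finishes."""
    while not operation.done:
        await asyncio.sleep(1)
        operation = await operation_stub.Get(
            GetOperationRequest(operation_id=operation.id),
            metadata=grpc_metadata,
        )
    return operation  # operation.response then unpacks into CompletionResponse
```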
@ -8,7 +8,7 @@ try:
    from langchain.memory import ConversationBufferMemory
    from zep_cloud import MemoryGetRequestMemoryType

-    class ZepCloudMemory(ConversationBufferMemory):  # type: ignore[override]
+    class ZepCloudMemory(ConversationBufferMemory):
        """Persist your chain history to the Zep MemoryStore.

        Documentation: https://help.getzep.com
@ -7,7 +7,7 @@ from langchain_community.chat_message_histories import ZepChatMessageHistory
try:
    from langchain.memory import ConversationBufferMemory

-    class ZepMemory(ConversationBufferMemory):  # type: ignore[override]
+    class ZepMemory(ConversationBufferMemory):
        """Persist your chain history to the Zep MemoryStore.

        The number of messages returned by Zep and when the Zep server summarizes chat
@ -7,13 +7,13 @@ from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, model_validator


-class VectorSearchConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class VectorSearchConfig(BaseModel, extra="allow"):
    """Configuration for vector search."""

    numberOfResults: int = 4


-class RetrievalConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class RetrievalConfig(BaseModel, extra="allow"):
    """Configuration for retrieval."""

    vectorSearchConfiguration: VectorSearchConfig
@ -68,7 +68,7 @@ Dates are also represented as str.


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class Highlight(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class Highlight(BaseModel, extra="allow"):
    """Information that highlights the keywords in the excerpt."""

    BeginOffset: int
@ -82,7 +82,7 @@ class Highlight(BaseModel, extra="allow"): # type: ignore[call-arg]


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class TextWithHighLights(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class TextWithHighLights(BaseModel, extra="allow"):
    """Text with highlights."""

    Text: str
@ -92,9 +92,7 @@ class TextWithHighLights(BaseModel, extra="allow"): # type: ignore[call-arg]


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class AdditionalResultAttributeValue(  # type: ignore[call-arg]
-    BaseModel, extra="allow"
-):
+class AdditionalResultAttributeValue(BaseModel, extra="allow"):
    """Value of an additional result attribute."""

    TextWithHighlightsValue: TextWithHighLights
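Note the line counts in this hunk header (`-92,9 +92,7`): folding the three-line class statement onto one line shrinks the file by two lines, which is why the old and new line numbers in the remaining hunks of this file drift apart (`-102,7 +100,7` and so on). The shape of the change, as a sketch:

```python
from pydantic import BaseModel

# Before, the ignore sat on a three-line class statement:
#
# class AdditionalResultAttributeValue(  # type: ignore[call-arg]
#     BaseModel, extra="allow"
# ):
class AdditionalResultAttributeValue(BaseModel, extra="allow"):
    """Value of an additional result attribute."""
```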
@ -102,7 +100,7 @@ class AdditionalResultAttributeValue( # type: ignore[call-arg]


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class AdditionalResultAttribute(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class AdditionalResultAttribute(BaseModel, extra="allow"):
    """Additional result attribute."""

    Key: str
@ -117,7 +115,7 @@ class AdditionalResultAttribute(BaseModel, extra="allow"): # type: ignore[call-


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class DocumentAttributeValue(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class DocumentAttributeValue(BaseModel, extra="allow"):
    """Value of a document attribute."""

    DateValue: Optional[str] = None
@ -148,7 +146,7 @@ class DocumentAttributeValue(BaseModel, extra="allow"): # type: ignore[call-arg


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class DocumentAttribute(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class DocumentAttribute(BaseModel, extra="allow"):
    """Document attribute."""

    Key: str
@ -158,7 +156,7 @@ class DocumentAttribute(BaseModel, extra="allow"): # type: ignore[call-arg]


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class ResultItem(BaseModel, ABC, extra="allow"):  # type: ignore[call-arg]
+class ResultItem(BaseModel, ABC, extra="allow"):
    """Base class of a result item."""

    Id: Optional[str]
@ -288,7 +286,7 @@ class RetrieveResultItem(ResultItem):


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class QueryResult(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class QueryResult(BaseModel, extra="allow"):
    """`Amazon Kendra Query API` search result.

    It is composed of:
@ -302,7 +300,7 @@ class QueryResult(BaseModel, extra="allow"): # type: ignore[call-arg]


# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
-class RetrieveResult(BaseModel, extra="allow"):  # type: ignore[call-arg]
+class RetrieveResult(BaseModel, extra="allow"):
    """`Amazon Kendra Retrieve API` search result.

    It is composed of:
Some files were not shown because too many files have changed in this diff.