mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 19:09:57 +00:00
Tweak type hints to match dependency's behavior. (#11355)
Needs #11353 to merge first, and a new `langchain` to be published with those changes.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Vector SQL Database Chain Retriever"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
@@ -76,10 +76,8 @@ class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser):
|
||||
return super().parse(text)
|
||||
|
||||
|
||||
def get_result_from_sqldb(
|
||||
db: SQLDatabase, cmd: str
|
||||
) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]:
|
||||
result = db._execute(cmd, fetch="all") # type: ignore
|
||||
def get_result_from_sqldb(db: SQLDatabase, cmd: str) -> Sequence[Dict[str, Any]]:
|
||||
result = db._execute(cmd, fetch="all")
|
||||
return result
|
||||
|
||||
|
||||
@@ -179,8 +177,9 @@ class VectorSQLDatabaseChain(SQLDatabaseChain):
|
||||
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
||||
_run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
|
||||
# If return direct, we just set the final result equal to
|
||||
# the result of the sql query result, otherwise try to get a human readable
|
||||
# final answer
|
||||
# the result of the sql query result (`Sequence[Dict[str, Any]]`),
|
||||
# otherwise try to get a human readable final answer (`str`).
|
||||
final_result: Union[str, Sequence[Dict[str, Any]]]
|
||||
if self.return_direct:
|
||||
final_result = result
|
||||
else:
|
||||
|
Reference in New Issue
Block a user