Tweak type hints to match dependency's behavior. (#11355)

Needs #11353 to merge first, and a new `langchain` to be published with
those changes.
This commit is contained in:
Predrag Gruevski
2023-10-04 22:36:58 -04:00
committed by GitHub
parent 940b9ae30a
commit c9986bc3a9
3 changed files with 42 additions and 85 deletions

View File

@@ -1,7 +1,7 @@
"""Vector SQL Database Chain Retriever"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.llm import LLMChain
@@ -76,10 +76,8 @@ class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser):
return super().parse(text)
def get_result_from_sqldb(
db: SQLDatabase, cmd: str
) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]:
result = db._execute(cmd, fetch="all") # type: ignore
def get_result_from_sqldb(db: SQLDatabase, cmd: str) -> Sequence[Dict[str, Any]]:
result = db._execute(cmd, fetch="all")
return result
@@ -179,8 +177,9 @@ class VectorSQLDatabaseChain(SQLDatabaseChain):
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
_run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
# If return direct, we just set the final result equal to
# the result of the sql query result, otherwise try to get a human readable
# final answer
# the result of the sql query result (`Sequence[Dict[str, Any]]`),
# otherwise try to get a human readable final answer (`str`).
final_result: Union[str, Sequence[Dict[str, Any]]]
if self.return_direct:
final_result = result
else: