mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
PowerBI fix for table names with spaces (#4170)
small fix to make sure a table name with spaces is passed correctly to the API for the schema lookup.
This commit is contained in:
parent
b1e2e29222
commit
3095546851
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -12,8 +13,6 @@ from aiohttp import ServerTimeoutError
|
|||||||
from pydantic import BaseModel, Field, root_validator
|
from pydantic import BaseModel, Field, root_validator
|
||||||
from requests.exceptions import Timeout
|
from requests.exceptions import Timeout
|
||||||
|
|
||||||
from langchain.tools.powerbi.prompt import SCHEMA_ERROR_RESPONSE, UNAUTHORIZED_RESPONSE
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg")
|
BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg")
|
||||||
@ -63,27 +62,29 @@ class PowerBIDataset(BaseModel):
|
|||||||
@property
|
@property
|
||||||
def headers(self) -> Dict[str, str]:
|
def headers(self) -> Dict[str, str]:
|
||||||
"""Get the token."""
|
"""Get the token."""
|
||||||
from azure.core.exceptions import ClientAuthenticationError
|
|
||||||
|
|
||||||
token = None
|
|
||||||
if self.token:
|
if self.token:
|
||||||
token = self.token
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": "Bearer " + self.token,
|
||||||
|
}
|
||||||
|
from azure.core.exceptions import ( # pylint: disable=import-outside-toplevel
|
||||||
|
ClientAuthenticationError,
|
||||||
|
)
|
||||||
|
|
||||||
if self.credential:
|
if self.credential:
|
||||||
try:
|
try:
|
||||||
token = self.credential.get_token(
|
token = self.credential.get_token(
|
||||||
"https://analysis.windows.net/powerbi/api/.default"
|
"https://analysis.windows.net/powerbi/api/.default"
|
||||||
).token
|
).token
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": "Bearer " + token,
|
||||||
|
}
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||||
raise ClientAuthenticationError(
|
raise ClientAuthenticationError(
|
||||||
"Could not get a token from the supplied credentials."
|
"Could not get a token from the supplied credentials."
|
||||||
) from exc
|
) from exc
|
||||||
if not token:
|
raise ClientAuthenticationError("No credential or token supplied.")
|
||||||
raise ClientAuthenticationError("No credential or token supplied.")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": "Bearer " + token,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_table_names(self) -> Iterable[str]:
|
def get_table_names(self) -> Iterable[str]:
|
||||||
"""Get names of tables available."""
|
"""Get names of tables available."""
|
||||||
@ -116,10 +117,12 @@ class PowerBIDataset(BaseModel):
|
|||||||
return self.table_names
|
return self.table_names
|
||||||
|
|
||||||
def _get_tables_todo(self, tables_todo: List[str]) -> List[str]:
|
def _get_tables_todo(self, tables_todo: List[str]) -> List[str]:
|
||||||
for table in tables_todo:
|
"""Get the tables that still need to be queried."""
|
||||||
|
todo = deepcopy(tables_todo)
|
||||||
|
for table in todo:
|
||||||
if table in self.schemas:
|
if table in self.schemas:
|
||||||
tables_todo.remove(table)
|
todo.remove(table)
|
||||||
return tables_todo
|
return todo
|
||||||
|
|
||||||
def _get_schema_for_tables(self, table_names: List[str]) -> str:
|
def _get_schema_for_tables(self, table_names: List[str]) -> str:
|
||||||
"""Create a string of the table schemas for the supplied tables."""
|
"""Create a string of the table schemas for the supplied tables."""
|
||||||
@ -135,19 +138,20 @@ class PowerBIDataset(BaseModel):
|
|||||||
tables_requested = self._get_tables_to_query(table_names)
|
tables_requested = self._get_tables_to_query(table_names)
|
||||||
tables_todo = self._get_tables_todo(tables_requested)
|
tables_todo = self._get_tables_todo(tables_requested)
|
||||||
for table in tables_todo:
|
for table in tables_todo:
|
||||||
|
if " " in table and not table.startswith("'") and not table.endswith("'"):
|
||||||
|
table = f"'{table}'"
|
||||||
try:
|
try:
|
||||||
result = self.run(
|
result = self.run(
|
||||||
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
||||||
)
|
)
|
||||||
except Timeout:
|
except Timeout:
|
||||||
_LOGGER.warning("Timeout while getting table info for %s", table)
|
_LOGGER.warning("Timeout while getting table info for %s", table)
|
||||||
|
self.schemas[table] = "unknown"
|
||||||
continue
|
continue
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||||
if "bad request" in str(exc).lower():
|
_LOGGER.warning("Error while getting table info for %s: %s", table, exc)
|
||||||
return SCHEMA_ERROR_RESPONSE
|
self.schemas[table] = "unknown"
|
||||||
if "unauthorized" in str(exc).lower():
|
continue
|
||||||
return UNAUTHORIZED_RESPONSE
|
|
||||||
return str(exc)
|
|
||||||
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
||||||
return self._get_schema_for_tables(tables_requested)
|
return self._get_schema_for_tables(tables_requested)
|
||||||
|
|
||||||
@ -158,19 +162,20 @@ class PowerBIDataset(BaseModel):
|
|||||||
tables_requested = self._get_tables_to_query(table_names)
|
tables_requested = self._get_tables_to_query(table_names)
|
||||||
tables_todo = self._get_tables_todo(tables_requested)
|
tables_todo = self._get_tables_todo(tables_requested)
|
||||||
for table in tables_todo:
|
for table in tables_todo:
|
||||||
|
if " " in table and not table.startswith("'") and not table.endswith("'"):
|
||||||
|
table = f"'{table}'"
|
||||||
try:
|
try:
|
||||||
result = await self.arun(
|
result = await self.arun(
|
||||||
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
||||||
)
|
)
|
||||||
except ServerTimeoutError:
|
except ServerTimeoutError:
|
||||||
_LOGGER.warning("Timeout while getting table info for %s", table)
|
_LOGGER.warning("Timeout while getting table info for %s", table)
|
||||||
|
self.schemas[table] = "unknown"
|
||||||
continue
|
continue
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||||
if "bad request" in str(exc).lower():
|
_LOGGER.warning("Error while getting table info for %s: %s", table, exc)
|
||||||
return SCHEMA_ERROR_RESPONSE
|
self.schemas[table] = "unknown"
|
||||||
if "unauthorized" in str(exc).lower():
|
continue
|
||||||
return UNAUTHORIZED_RESPONSE
|
|
||||||
return str(exc)
|
|
||||||
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
||||||
return self._get_schema_for_tables(tables_requested)
|
return self._get_schema_for_tables(tables_requested)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user