mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 19:49:09 +00:00
PowerBI fix for table names with spaces (#4170)
Small fix to make sure that a table name containing spaces is passed correctly to the API for the schema lookup.
This commit is contained in:
parent
b1e2e29222
commit
3095546851
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
@ -12,8 +13,6 @@ from aiohttp import ServerTimeoutError
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from requests.exceptions import Timeout
|
||||
|
||||
from langchain.tools.powerbi.prompt import SCHEMA_ERROR_RESPONSE, UNAUTHORIZED_RESPONSE
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
BASE_URL = os.getenv("POWERBI_BASE_URL", "https://api.powerbi.com/v1.0/myorg")
|
||||
@ -63,27 +62,29 @@ class PowerBIDataset(BaseModel):
|
||||
@property
|
||||
def headers(self) -> Dict[str, str]:
|
||||
"""Get the token."""
|
||||
from azure.core.exceptions import ClientAuthenticationError
|
||||
|
||||
token = None
|
||||
if self.token:
|
||||
token = self.token
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.token,
|
||||
}
|
||||
from azure.core.exceptions import ( # pylint: disable=import-outside-toplevel
|
||||
ClientAuthenticationError,
|
||||
)
|
||||
|
||||
if self.credential:
|
||||
try:
|
||||
token = self.credential.get_token(
|
||||
"https://analysis.windows.net/powerbi/api/.default"
|
||||
).token
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + token,
|
||||
}
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
raise ClientAuthenticationError(
|
||||
"Could not get a token from the supplied credentials."
|
||||
) from exc
|
||||
if not token:
|
||||
raise ClientAuthenticationError("No credential or token supplied.")
|
||||
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + token,
|
||||
}
|
||||
raise ClientAuthenticationError("No credential or token supplied.")
|
||||
|
||||
def get_table_names(self) -> Iterable[str]:
|
||||
"""Get names of tables available."""
|
||||
@ -116,10 +117,12 @@ class PowerBIDataset(BaseModel):
|
||||
return self.table_names
|
||||
|
||||
def _get_tables_todo(self, tables_todo: List[str]) -> List[str]:
|
||||
for table in tables_todo:
|
||||
"""Get the tables that still need to be queried."""
|
||||
todo = deepcopy(tables_todo)
|
||||
for table in todo:
|
||||
if table in self.schemas:
|
||||
tables_todo.remove(table)
|
||||
return tables_todo
|
||||
todo.remove(table)
|
||||
return todo
|
||||
|
||||
def _get_schema_for_tables(self, table_names: List[str]) -> str:
|
||||
"""Create a string of the table schemas for the supplied tables."""
|
||||
@ -135,19 +138,20 @@ class PowerBIDataset(BaseModel):
|
||||
tables_requested = self._get_tables_to_query(table_names)
|
||||
tables_todo = self._get_tables_todo(tables_requested)
|
||||
for table in tables_todo:
|
||||
if " " in table and not table.startswith("'") and not table.endswith("'"):
|
||||
table = f"'{table}'"
|
||||
try:
|
||||
result = self.run(
|
||||
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
||||
)
|
||||
except Timeout:
|
||||
_LOGGER.warning("Timeout while getting table info for %s", table)
|
||||
self.schemas[table] = "unknown"
|
||||
continue
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
if "bad request" in str(exc).lower():
|
||||
return SCHEMA_ERROR_RESPONSE
|
||||
if "unauthorized" in str(exc).lower():
|
||||
return UNAUTHORIZED_RESPONSE
|
||||
return str(exc)
|
||||
_LOGGER.warning("Error while getting table info for %s: %s", table, exc)
|
||||
self.schemas[table] = "unknown"
|
||||
continue
|
||||
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
||||
return self._get_schema_for_tables(tables_requested)
|
||||
|
||||
@ -158,19 +162,20 @@ class PowerBIDataset(BaseModel):
|
||||
tables_requested = self._get_tables_to_query(table_names)
|
||||
tables_todo = self._get_tables_todo(tables_requested)
|
||||
for table in tables_todo:
|
||||
if " " in table and not table.startswith("'") and not table.endswith("'"):
|
||||
table = f"'{table}'"
|
||||
try:
|
||||
result = await self.arun(
|
||||
f"EVALUATE TOPN({self.sample_rows_in_table_info}, {table})"
|
||||
)
|
||||
except ServerTimeoutError:
|
||||
_LOGGER.warning("Timeout while getting table info for %s", table)
|
||||
self.schemas[table] = "unknown"
|
||||
continue
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
if "bad request" in str(exc).lower():
|
||||
return SCHEMA_ERROR_RESPONSE
|
||||
if "unauthorized" in str(exc).lower():
|
||||
return UNAUTHORIZED_RESPONSE
|
||||
return str(exc)
|
||||
_LOGGER.warning("Error while getting table info for %s: %s", table, exc)
|
||||
self.schemas[table] = "unknown"
|
||||
continue
|
||||
self.schemas[table] = json_to_md(result["results"][0]["tables"][0]["rows"])
|
||||
return self._get_schema_for_tables(tables_requested)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user