chore: Add pylint for DB-GPT rag lib (#1267)

This commit is contained in:
Fangyin Cheng
2024-03-07 23:27:43 +08:00
committed by GitHub
parent aaaf34db17
commit 7446817340
70 changed files with 1135 additions and 587 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
@@ -499,7 +499,7 @@ class RDBMSDatabase(BaseConnect):
ans = cursor.fetchall()
return ans[0][1]
def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(

View File

@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
import sqlparse
from sqlalchemy import MetaData, text
@@ -145,7 +145,7 @@ class ClickhouseConnect(RDBMSDatabase):
for name, column_type, _, _, comment in fields[0]
]
def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self.client

View File

@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
@@ -68,7 +68,7 @@ class DorisConnect(RDBMSDatabase):
"""Get user info."""
return []
def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
cursor = self.get_session().execute(
text(

View File

@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
@@ -85,7 +85,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
print("postgresql get users error: ", e)
return []
def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(

View File

@@ -4,7 +4,7 @@
import logging
import os
import tempfile
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from sqlalchemy import create_engine, text
@@ -58,7 +58,7 @@ class SQLiteConnect(RDBMSDatabase):
ans = cursor.fetchall()
return ans[0][0]
def get_fields(self, table_name):
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
fields = cursor.fetchall()

View File

@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
@@ -68,7 +68,7 @@ class StarRocksConnect(RDBMSDatabase):
"""Get user info."""
return []
def get_fields(self, table_name, db_name="database()"):
def get_fields(self, table_name, db_name="database()") -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
if db_name != "database()":