mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 17:45:31 +00:00
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import errno
|
||||
import socket
|
||||
from typing import Set, Tuple
|
||||
|
||||
|
||||
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
|
||||
@@ -22,3 +23,34 @@ def _get_ip_address(address: str = "10.254.254.254:1") -> str:
|
||||
finally:
|
||||
s.close()
|
||||
return curr_address
|
||||
|
||||
|
||||
async def _async_get_free_port(
|
||||
port_range: Tuple[int, int], timeout: int, used_ports: Set[int]
|
||||
):
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, _get_free_port, port_range, timeout, used_ports
|
||||
)
|
||||
|
||||
|
||||
def _get_free_port(port_range: Tuple[int, int], timeout: int, used_ports: Set[int]):
|
||||
import random
|
||||
|
||||
available_ports = set(range(port_range[0], port_range[1] + 1)) - used_ports
|
||||
if not available_ports:
|
||||
raise RuntimeError("No available ports in the specified range")
|
||||
|
||||
while available_ports:
|
||||
port = random.choice(list(available_ports))
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", port))
|
||||
used_ports.add(port)
|
||||
return port
|
||||
except OSError:
|
||||
available_ports.remove(port)
|
||||
|
||||
raise RuntimeError("No available ports in the specified range")
|
||||
|
@@ -15,3 +15,29 @@ class PaginationResult(BaseModel, Generic[T]):
|
||||
total_pages: int = Field(..., description="total number of pages")
|
||||
page: int = Field(..., description="Current page number")
|
||||
page_size: int = Field(..., description="Number of items per page")
|
||||
|
||||
@classmethod
|
||||
def build_from_all(
|
||||
cls, all_items: List[T], page: int, page_size: int
|
||||
) -> "PaginationResult[T]":
|
||||
"""Build a pagination result from all items"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
page_size = 1
|
||||
total_count = len(all_items)
|
||||
total_pages = (
|
||||
(total_count + page_size - 1) // page_size if total_count > 0 else 0
|
||||
)
|
||||
page = max(1, min(page, total_pages)) if total_pages > 0 else 0
|
||||
start_index = (page - 1) * page_size if page > 0 else 0
|
||||
end_index = min(start_index + page_size, total_count)
|
||||
items = all_items[start_index:end_index]
|
||||
|
||||
return cls(
|
||||
items=items,
|
||||
total_count=total_count,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
87
dbgpt/util/serialization/check.py
Normal file
87
dbgpt/util/serialization/check.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import inspect
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Optional, TextIO
|
||||
|
||||
|
||||
def check_serializable(
|
||||
obj: Any, obj_name: str = "Object", error_msg: str = "Object is not serializable"
|
||||
):
|
||||
import cloudpickle
|
||||
|
||||
try:
|
||||
cloudpickle.dumps(obj)
|
||||
except Exception as e:
|
||||
inspect_info = inspect_serializability(obj, obj_name)
|
||||
msg = f"{error_msg}\n{inspect_info['report']}"
|
||||
raise TypeError(msg) from e
|
||||
|
||||
|
||||
class SerializabilityInspector:
|
||||
def __init__(self, stream: Optional[TextIO] = None):
|
||||
self.stream = stream or StringIO()
|
||||
self.failures = {}
|
||||
self.indent_level = 0
|
||||
|
||||
def log(self, message: str):
|
||||
indent = " " * self.indent_level
|
||||
self.stream.write(f"{indent}{message}\n")
|
||||
|
||||
def inspect(self, obj: Any, name: str, depth: int = 3) -> bool:
|
||||
import cloudpickle
|
||||
|
||||
self.log(f"Inspecting '{name}'")
|
||||
self.indent_level += 1
|
||||
|
||||
try:
|
||||
cloudpickle.dumps(obj)
|
||||
self.indent_level -= 1
|
||||
return True
|
||||
except Exception as e:
|
||||
self.failures[name] = str(e)
|
||||
self.log(f"Failure: {str(e)}")
|
||||
|
||||
if depth > 0:
|
||||
if inspect.isfunction(obj) or inspect.ismethod(obj):
|
||||
self._inspect_function(obj, depth - 1)
|
||||
elif hasattr(obj, "__dict__"):
|
||||
self._inspect_object(obj, depth - 1)
|
||||
|
||||
self.indent_level -= 1
|
||||
return False
|
||||
|
||||
def _inspect_function(self, func, depth):
|
||||
closure = inspect.getclosurevars(func)
|
||||
for name, value in closure.nonlocals.items():
|
||||
self.inspect(value, f"{func.__name__}.{name}", depth)
|
||||
for name, value in closure.globals.items():
|
||||
self.inspect(value, f"global:{name}", depth)
|
||||
|
||||
def _inspect_object(self, obj, depth):
|
||||
for name, value in inspect.getmembers(obj):
|
||||
if not name.startswith("__"):
|
||||
self.inspect(value, f"{type(obj).__name__}.{name}", depth)
|
||||
|
||||
def get_report(self) -> str:
|
||||
summary = "\nSummary of Serialization Failures:\n"
|
||||
if not self.failures:
|
||||
summary += "All components are serializable.\n"
|
||||
else:
|
||||
for name, error in self.failures.items():
|
||||
summary += f" - {name}: {error}\n"
|
||||
|
||||
return self.stream.getvalue() + summary
|
||||
|
||||
|
||||
def inspect_serializability(
|
||||
obj: Any,
|
||||
name: Optional[str] = None,
|
||||
depth: int = 5,
|
||||
stream: Optional[TextIO] = None,
|
||||
) -> Dict[str, Any]:
|
||||
inspector = SerializabilityInspector(stream)
|
||||
success = inspector.inspect(obj, name or type(obj).__name__, depth)
|
||||
return {
|
||||
"success": success,
|
||||
"failures": inspector.failures,
|
||||
"report": inspector.get_report(),
|
||||
}
|
84
dbgpt/util/tests/test_pagination_utils.py
Normal file
84
dbgpt/util/tests/test_pagination_utils.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
|
||||
def test_build_from_all_normal_case():
|
||||
items = list(range(100))
|
||||
result = PaginationResult.build_from_all(items, page=2, page_size=20)
|
||||
|
||||
assert len(result.items) == 20
|
||||
assert result.items == list(range(20, 40))
|
||||
assert result.total_count == 100
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 2
|
||||
assert result.page_size == 20
|
||||
|
||||
|
||||
def test_build_from_all_empty_list():
|
||||
items = []
|
||||
result = PaginationResult.build_from_all(items, page=1, page_size=5)
|
||||
|
||||
assert result.items == []
|
||||
assert result.total_count == 0
|
||||
assert result.total_pages == 0
|
||||
assert result.page == 0
|
||||
assert result.page_size == 5
|
||||
|
||||
|
||||
def test_build_from_all_last_page():
|
||||
items = list(range(95))
|
||||
result = PaginationResult.build_from_all(items, page=5, page_size=20)
|
||||
|
||||
assert len(result.items) == 15
|
||||
assert result.items == list(range(80, 95))
|
||||
assert result.total_count == 95
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 5
|
||||
assert result.page_size == 20
|
||||
|
||||
|
||||
def test_build_from_all_page_out_of_range():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=10, page_size=10)
|
||||
|
||||
assert len(result.items) == 10
|
||||
assert result.items == list(range(40, 50))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 5
|
||||
assert result.page_size == 10
|
||||
|
||||
|
||||
def test_build_from_all_page_zero():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=0, page_size=10)
|
||||
|
||||
assert len(result.items) == 10
|
||||
assert result.items == list(range(0, 10))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 1
|
||||
assert result.page_size == 10
|
||||
|
||||
|
||||
def test_build_from_all_negative_page():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=-1, page_size=10)
|
||||
|
||||
assert len(result.items) == 10
|
||||
assert result.items == list(range(0, 10))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 5
|
||||
assert result.page == 1
|
||||
assert result.page_size == 10
|
||||
|
||||
|
||||
def test_build_from_all_page_size_larger_than_total():
|
||||
items = list(range(50))
|
||||
result = PaginationResult.build_from_all(items, page=1, page_size=100)
|
||||
|
||||
assert len(result.items) == 50
|
||||
assert result.items == list(range(50))
|
||||
assert result.total_count == 50
|
||||
assert result.total_pages == 1
|
||||
assert result.page == 1
|
||||
assert result.page_size == 100
|
Reference in New Issue
Block a user