feat(core): Support higher-order operators (#1984)

Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
This commit is contained in:
Fangyin Cheng
2024-09-09 10:15:37 +08:00
committed by GitHub
parent f6d5fc4595
commit 65c875db20
62 changed files with 6281 additions and 386 deletions

View File

@@ -1,5 +1,6 @@
import errno
import socket
from typing import Set, Tuple
def _get_ip_address(address: str = "10.254.254.254:1") -> str:
@@ -22,3 +23,34 @@ def _get_ip_address(address: str = "10.254.254.254:1") -> str:
finally:
s.close()
return curr_address
async def _async_get_free_port(
    port_range: Tuple[int, int], timeout: int, used_ports: Set[int]
):
    """Await a free port without blocking the event loop.

    The blocking ``_get_free_port`` call is handed to the running loop's
    default executor (``run_in_executor(None, ...)``) and its result is
    awaited.

    Args:
        port_range: Forwarded unchanged to ``_get_free_port``.
        timeout: Forwarded unchanged to ``_get_free_port``.
        used_ports: Forwarded unchanged to ``_get_free_port``.

    Returns:
        Whatever ``_get_free_port`` returns for the same arguments.
    """
    import asyncio

    pending = asyncio.get_running_loop().run_in_executor(
        None, _get_free_port, port_range, timeout, used_ports
    )
    return await pending
def _get_free_port(port_range: Tuple[int, int], timeout: int, used_ports: Set[int]):
    """Pick a random bindable TCP port from an inclusive range and record it.

    Candidate ports are tried in random order; the first one that can be bound
    on all interfaces (``""``) is added to ``used_ports`` and returned. The
    probing socket is closed before returning, so the port is only "reserved"
    through the ``used_ports`` bookkeeping.

    Args:
        port_range: ``(low, high)`` bounds of the search, both inclusive.
        timeout: Accepted for interface compatibility; this function does not
            consult it.
        used_ports: Ports that must not be returned. The chosen port is added
            to this set in place.

    Returns:
        The selected port number.

    Raises:
        RuntimeError: If the range (minus ``used_ports``) is empty or none of
            the remaining ports can be bound.
    """
    import random

    candidates = list(set(range(port_range[0], port_range[1] + 1)) - used_ports)
    # Shuffle once instead of rebuilding a list from the set on every retry:
    # the visiting order is still uniformly random, but a wide range with many
    # busy ports no longer costs O(n) per failed attempt.
    random.shuffle(candidates)

    for port in candidates:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", port))
        except OSError:
            # Port is taken (or otherwise not bindable); try the next one.
            continue
        used_ports.add(port)
        return port

    # Reached both when there were no candidates and when all of them failed.
    raise RuntimeError("No available ports in the specified range")

View File

@@ -15,3 +15,29 @@ class PaginationResult(BaseModel, Generic[T]):
total_pages: int = Field(..., description="total number of pages")
page: int = Field(..., description="Current page number")
page_size: int = Field(..., description="Number of items per page")
@classmethod
def build_from_all(
    cls, all_items: List[T], page: int, page_size: int
) -> "PaginationResult[T]":
    """Cut a single page out of a fully materialised list of items.

    ``page`` and ``page_size`` below 1 are raised to 1. A page number past the
    end is clamped to the last page. For an empty ``all_items`` the result has
    ``total_pages == 0`` and ``page == 0``.

    Args:
        all_items: Every item of the collection being paginated.
        page: Requested 1-based page number.
        page_size: Requested number of items per page.

    Returns:
        An instance of ``cls`` built with the selected ``items`` plus the
        ``total_count``, ``total_pages``, effective ``page`` and ``page_size``.
    """
    page = max(page, 1)
    page_size = max(page_size, 1)

    total_count = len(all_items)
    if total_count > 0:
        # Ceiling division: a partial trailing page still counts as a page.
        whole_pages, leftover = divmod(total_count, page_size)
        total_pages = whole_pages + 1 if leftover else whole_pages
        page = max(1, min(page, total_pages))
        offset = (page - 1) * page_size
    else:
        total_pages = 0
        page = 0
        offset = 0

    # Slicing stops at the end of the list, so the last page may be short.
    return cls(
        items=all_items[offset : offset + page_size],
        total_count=total_count,
        total_pages=total_pages,
        page=page,
        page_size=page_size,
    )

View File

@@ -0,0 +1,87 @@
import inspect
from io import StringIO
from typing import Any, Dict, Optional, TextIO
def check_serializable(
    obj: Any, obj_name: str = "Object", error_msg: str = "Object is not serializable"
):
    """Raise ``TypeError`` with a diagnostic report if cloudpickle cannot dump ``obj``.

    Args:
        obj: The object to try to serialize with ``cloudpickle.dumps``.
        obj_name: Label passed to ``inspect_serializability`` for the report.
        error_msg: First line of the raised error message; the inspection
            report follows on the next line.

    Raises:
        TypeError: If ``cloudpickle.dumps(obj)`` raises any ``Exception``; the
            original exception is chained as the cause.
    """
    import cloudpickle

    try:
        cloudpickle.dumps(obj)
    except Exception as exc:
        # Only pay for the detailed inspection once the quick check has failed.
        report = inspect_serializability(obj, obj_name)["report"]
        raise TypeError(f"{error_msg}\n{report}") from exc
class SerializabilityInspector:
    """Probe an object with cloudpickle and log which parts fail to serialize.

    Each call to :meth:`inspect` writes an indented trace to ``stream`` and
    records failures in ``failures`` (a mapping of component name to the error
    text). :meth:`get_report` returns the trace followed by a summary.
    """

    def __init__(self, stream: Optional[TextIO] = None):
        """Create an inspector.

        Args:
            stream: Where the trace is written. A fresh ``StringIO`` is used
                when omitted.
        """
        # Compare against None so a caller-supplied stream is always honoured,
        # even if the object happens to be falsy.
        self.stream = stream if stream is not None else StringIO()
        self.failures: Dict[str, str] = {}
        self.indent_level = 0
        # Mirror of every line passed to log(); lets get_report() work with
        # streams that cannot be read back (no ``getvalue``).
        self._log_lines = []

    def log(self, message: str):
        """Write ``message`` to the stream, indented by the current level."""
        indent = " " * self.indent_level
        line = f"{indent}{message}\n"
        self._log_lines.append(line)
        self.stream.write(line)

    def inspect(self, obj: Any, name: str, depth: int = 3) -> bool:
        """Check whether ``obj`` can be dumped by cloudpickle.

        On failure the error is recorded under ``name`` and, while ``depth``
        is positive, the object's parts are inspected with ``depth - 1``:
        closure/global variables for functions and methods, members for
        objects that have a ``__dict__``.

        Args:
            obj: The object to test.
            name: Label used in the trace and as the key in ``failures``.
            depth: How many more levels of nested inspection are allowed.

        Returns:
            ``True`` if ``cloudpickle.dumps(obj)`` succeeded, else ``False``.
        """
        import cloudpickle

        self.log(f"Inspecting '{name}'")
        self.indent_level += 1
        try:
            cloudpickle.dumps(obj)
        except Exception as e:
            self.failures[name] = str(e)
            self.log(f"Failure: {str(e)}")
            if depth > 0:
                if inspect.isfunction(obj) or inspect.ismethod(obj):
                    self._inspect_function(obj, depth - 1)
                elif hasattr(obj, "__dict__"):
                    self._inspect_object(obj, depth - 1)
            return False
        else:
            return True
        finally:
            # Restore the indentation even if nested inspection raises, so
            # later log lines are not permanently shifted.
            self.indent_level -= 1

    def _inspect_function(self, func, depth):
        """Inspect the nonlocal and global variables referenced by ``func``."""
        closure = inspect.getclosurevars(func)
        for name, value in closure.nonlocals.items():
            self.inspect(value, f"{func.__name__}.{name}", depth)
        for name, value in closure.globals.items():
            self.inspect(value, f"global:{name}", depth)

    def _inspect_object(self, obj, depth):
        """Inspect every member of ``obj`` whose name does not start with ``__``."""
        for name, value in inspect.getmembers(obj):
            if not name.startswith("__"):
                self.inspect(value, f"{type(obj).__name__}.{name}", depth)

    def get_report(self) -> str:
        """Return the logged trace followed by a summary of failures.

        The trace is read from the stream via ``getvalue()`` when the stream
        offers it (as ``StringIO`` does); otherwise the lines recorded by
        :meth:`log` are used, so write-only streams do not cause an error.
        """
        getvalue = getattr(self.stream, "getvalue", None)
        if callable(getvalue):
            trace = getvalue()
        else:
            trace = "".join(self._log_lines)

        parts = ["\nSummary of Serialization Failures:\n"]
        if not self.failures:
            parts.append("All components are serializable.\n")
        else:
            parts.extend(
                f" - {name}: {error}\n" for name, error in self.failures.items()
            )
        return trace + "".join(parts)
def inspect_serializability(
    obj: Any,
    name: Optional[str] = None,
    depth: int = 5,
    stream: Optional[TextIO] = None,
) -> Dict[str, Any]:
    """Run a ``SerializabilityInspector`` over ``obj`` and bundle the outcome.

    Args:
        obj: The object to inspect.
        name: Label for the object in the report; the object's type name is
            used when this is ``None`` or empty.
        depth: Nesting budget handed to ``SerializabilityInspector.inspect``.
        stream: Optional stream handed to the inspector for its trace output.

    Returns:
        A dict with ``"success"`` (the inspector's boolean result),
        ``"failures"`` (the inspector's failure mapping) and ``"report"``
        (the text from ``get_report()``).
    """
    label = name or type(obj).__name__
    inspector = SerializabilityInspector(stream)
    succeeded = inspector.inspect(obj, label, depth)

    outcome: Dict[str, Any] = {"success": succeeded}
    outcome["failures"] = inspector.failures
    outcome["report"] = inspector.get_report()
    return outcome

View File

@@ -0,0 +1,84 @@
from dbgpt.util.pagination_utils import PaginationResult
def test_build_from_all_normal_case():
    """A middle page of an evenly divisible list holds exactly its slice."""
    paged = PaginationResult.build_from_all(list(range(100)), page=2, page_size=20)

    expected_items = list(range(20, 40))
    assert len(paged.items) == 20
    assert paged.items == expected_items
    assert (paged.total_count, paged.total_pages) == (100, 5)
    assert (paged.page, paged.page_size) == (2, 20)
def test_build_from_all_empty_list():
    """With no items there are zero pages and the reported page is 0."""
    paged = PaginationResult.build_from_all([], page=1, page_size=5)

    assert paged.items == []
    assert (paged.total_count, paged.total_pages) == (0, 0)
    assert (paged.page, paged.page_size) == (0, 5)
def test_build_from_all_last_page():
    """The final page may be shorter than ``page_size``."""
    paged = PaginationResult.build_from_all(list(range(95)), page=5, page_size=20)

    expected_items = list(range(80, 95))
    assert len(paged.items) == 15
    assert paged.items == expected_items
    assert (paged.total_count, paged.total_pages) == (95, 5)
    assert (paged.page, paged.page_size) == (5, 20)
def test_build_from_all_page_out_of_range():
    """A page number beyond the end is clamped to the last page."""
    paged = PaginationResult.build_from_all(list(range(50)), page=10, page_size=10)

    expected_items = list(range(40, 50))
    assert len(paged.items) == 10
    assert paged.items == expected_items
    assert (paged.total_count, paged.total_pages) == (50, 5)
    assert (paged.page, paged.page_size) == (5, 10)
def test_build_from_all_page_zero():
    """Page 0 is treated as the first page."""
    paged = PaginationResult.build_from_all(list(range(50)), page=0, page_size=10)

    expected_items = list(range(0, 10))
    assert len(paged.items) == 10
    assert paged.items == expected_items
    assert (paged.total_count, paged.total_pages) == (50, 5)
    assert (paged.page, paged.page_size) == (1, 10)
def test_build_from_all_negative_page():
    """A negative page number is treated as the first page."""
    paged = PaginationResult.build_from_all(list(range(50)), page=-1, page_size=10)

    expected_items = list(range(0, 10))
    assert len(paged.items) == 10
    assert paged.items == expected_items
    assert (paged.total_count, paged.total_pages) == (50, 5)
    assert (paged.page, paged.page_size) == (1, 10)
def test_build_from_all_page_size_larger_than_total():
    """A page size above the item count yields one page with everything."""
    paged = PaginationResult.build_from_all(list(range(50)), page=1, page_size=100)

    expected_items = list(range(50))
    assert len(paged.items) == 50
    assert paged.items == expected_items
    assert (paged.total_count, paged.total_pages) == (50, 1)
    assert (paged.page, paged.page_size) == (1, 100)