diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 9266e6927..c563aa138 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -1,9 +1,8 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-import dataclasses
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, Optional, Union
 
 import torch
@@ -215,7 +214,7 @@ class InferenceConfig:
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        attrs = [attr.name for attr in fields(cls)]
         inference_config_args = {}
         for attr in attrs:
             if attr in config_dict:
diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py
index e23d0b90f..9c630177d 100644
--- a/colossalai/inference/core/async_engine.py
+++ b/colossalai/inference/core/async_engine.py
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
 logger = logging.getLogger("colossalai-inference")
 
 
-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen! "
     try:
         try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
 
 
 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of Output for a request that can be iterated over asynchronously.
+    Attributes: 1.request_id: The id of the request.
+                2._future: A future that will be set when the request is finished.
+    Methods: set_result and get_result, results will be set when finished, for once, and
+    the `self.future` will be set to done.
+
+    """
 
     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
@@ -51,6 +57,10 @@ class RequstStream:
 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes: 1._request_streams: We create one stream for each request to trace the output.
+                2._finished_requests: A queue to store the finished requests.
+                3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
+                4.new_requests_event: An event to notify the engine that there are new requests.
     """
 
     def __init__(self) -> None:
@@ -133,7 +143,8 @@ class Tracer:
 
 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
+    Methods: 1. async_step: The async version of Engine.step()
     """
 
     async def async_step(self) -> List[str]:
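
The async_engine.py changes above mostly document an existing pattern: each request gets a stream backed by a single asyncio.Future that is set exactly once, and a tracer queues new requests and wakes the engine loop through an event. The snippet below is a minimal, self-contained sketch of that pattern only, not ColossalAI's actual API; the names MiniStream, MiniTracer and fake_engine_loop are hypothetical stand-ins for RequstStream, Tracer and _AsyncInferenceEngine.async_step().

# Hedged sketch of the future-backed request-stream pattern described in the
# new docstrings. All names here are illustrative, not ColossalAI symbols.
import asyncio
from typing import Any, Dict


class MiniStream:
    """One stream per request; set_result() fulfils the future exactly once."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future: asyncio.Future = asyncio.Future()

    def set_result(self, result: Any) -> None:
        # Only the first result is kept; later calls are ignored.
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self) -> Any:
        return await self._future


class MiniTracer:
    """Keeps one MiniStream per request and signals the engine on new work."""

    def __init__(self) -> None:
        self._streams: Dict[int, MiniStream] = {}
        self._new_requests: asyncio.Queue = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: int, prompt: str) -> MiniStream:
        stream = MiniStream(request_id)
        self._streams[request_id] = stream
        self._new_requests.put_nowait((request_id, prompt))
        self.new_requests_event.set()
        return stream

    def finish_request(self, request_id: int, output: str) -> None:
        self._streams.pop(request_id).set_result(output)


async def fake_engine_loop(tracer: MiniTracer) -> None:
    # Stand-in for the async engine step: wait for work, "generate", then
    # publish results through the per-request streams.
    await tracer.new_requests_event.wait()
    tracer.new_requests_event.clear()
    while not tracer._new_requests.empty():
        request_id, prompt = tracer._new_requests.get_nowait()
        tracer.finish_request(request_id, f"generated text for: {prompt}")


async def main() -> None:
    tracer = MiniTracer()
    stream = tracer.add_request(0, "hello")
    engine_task = asyncio.create_task(fake_engine_loop(tracer))
    print(await stream.get_result())  # -> "generated text for: hello"
    await engine_task


if __name__ == "__main__":
    asyncio.run(main())

The config.py hunks are a mechanical cleanup: `dataclasses.fields(cls)` becomes `fields(cls)` via the widened `from dataclasses import dataclass, fields` import, with no behavioral change to `InferenceConfig.from_dict`.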