fix comments

CjhHa1 2024-04-17 11:27:51 +08:00
parent 2396d58343
commit 402e9918df
2 changed files with 17 additions and 7 deletions

View File

@@ -1,9 +1,8 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-import dataclasses
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, Optional, Union
 import torch
@@ -215,7 +214,7 @@ class InferenceConfig:
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        attrs = [attr.name for attr in fields(cls)]
         inference_config_args = {}
         for attr in attrs:
             if attr in config_dict:
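For context, this hunk only swaps the module-qualified `dataclasses.fields(cls)` for the directly imported `fields(cls)`; the behavior is unchanged. A minimal, self-contained sketch of the same "filter the dict by declared dataclass fields" pattern, using made-up field names rather than the real `InferenceConfig` attributes:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ToyConfig:
    # Illustrative fields only; not the actual InferenceConfig attributes.
    max_batch_size: int = 8
    dtype: str = "fp16"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "ToyConfig":
        # Keep only keys that correspond to declared dataclass fields,
        # silently dropping anything the config class does not know about.
        attrs = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in config_dict.items() if k in attrs}
        return cls(**kwargs)


print(ToyConfig.from_dict({"max_batch_size": 4, "unknown_key": 1}))
```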

View File

@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
 logger = logging.getLogger("colossalai-inference")
-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen! "
     try:
         try:
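The renamed annotation reflects that the request tracker class is now called `Tracer`. The function itself follows the usual asyncio pattern of attaching a done-callback to the background engine task so a crash is surfaced instead of silently swallowed. A hedged sketch of that pattern, where `ToyTracer.propagate_exception` is a hypothetical stand-in for the real tracker API:

```python
import asyncio
from functools import partial


class ToyTracer:
    # Hypothetical stand-in for the real tracer; it only receives the error here.
    def propagate_exception(self, exc: BaseException) -> None:
        print(f"background loop died: {exc!r}")


def _on_finish(task: asyncio.Task, tracer: ToyTracer) -> None:
    # Surface the task's exception instead of letting it vanish, handing it
    # to the tracer first so waiting requests can be failed promptly.
    try:
        task.result()
    except asyncio.CancelledError:
        return
    except Exception as exc:
        tracer.propagate_exception(exc)
        raise  # asyncio will additionally log this via its exception handler


async def _boom() -> None:
    raise RuntimeError("engine loop failed")


async def main() -> None:
    tracer = ToyTracer()
    task = asyncio.get_running_loop().create_task(_boom())
    task.add_done_callback(partial(_on_finish, tracer=tracer))
    await asyncio.sleep(0.05)


asyncio.run(main())
```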
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of Output for a request that can be iterated over asynchronously.
+    Attributes: 1. request_id: The id of the request.
+                2. _future: A future that will be set when the request is finished.
+    Methods: set_result and get_result; the result is set exactly once, when the request
+    finishes, and `self._future` is then marked as done.
+    """
     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
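A small illustration of the future-per-request idea the new docstring describes; `ToyRequestStream` is an assumption-level stand-in, not the actual `RequstStream` class:

```python
import asyncio


class ToyRequestStream:
    # Minimal stand-in: one asyncio.Future per request id.
    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future: asyncio.Future = asyncio.Future()

    def set_result(self, result) -> None:
        # The result is set exactly once; afterwards the future reports done().
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        # Awaiting the future suspends the caller until set_result() is called.
        return await self._future


async def main() -> None:
    stream = ToyRequestStream(request_id=0)
    # Simulate the engine finishing the request a moment later.
    asyncio.get_running_loop().call_later(0.01, stream.set_result, "generated text")
    print(await stream.get_result())


asyncio.run(main())
```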
@@ -51,6 +57,10 @@ class RequstStream:
 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes: 1. _request_streams: We create one stream per request to trace its output.
+                2. _finished_requests: A queue that stores the finished requests.
+                3. _new_requests: New requests are stored in this queue first, before being sent to the engine.
+                4. new_requests_event: An event that notifies the engine that there are new requests.
     """
     def __init__(self) -> None:
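The four attributes listed above suggest the following shape. This is a hypothetical sketch of the bookkeeping only; `add_request` and `process_finished` are illustrative helper names, not the real `Tracer` methods:

```python
import asyncio
from typing import Dict


class ToyTracer:
    # Hypothetical sketch of the bookkeeping the docstring describes; not the real Tracer.
    def __init__(self) -> None:
        self._request_streams: Dict[int, asyncio.Future] = {}     # one stream (future) per request id
        self._finished_requests: asyncio.Queue = asyncio.Queue()  # ids of requests that have finished
        self._new_requests: asyncio.Queue = asyncio.Queue()       # requests waiting to be handed to the engine
        self.new_requests_event = asyncio.Event()                 # wakes the engine when new work arrives

    def add_request(self, request_id: int, prompt: str) -> asyncio.Future:
        # Register a stream, queue the request for the engine, and signal the event.
        stream: asyncio.Future = asyncio.Future()
        self._request_streams[request_id] = stream
        self._new_requests.put_nowait((request_id, prompt))
        self.new_requests_event.set()
        return stream

    def process_finished(self, request_id: int, output: str) -> None:
        # Deliver the output to the waiting caller and record the finished id.
        self._request_streams.pop(request_id).set_result(output)
        self._finished_requests.put_nowait(request_id)


async def main() -> None:
    tracer = ToyTracer()
    stream = tracer.add_request(0, "hello")
    tracer.process_finished(0, "hello, world")
    print(await stream)


asyncio.run(main())
```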
@@ -133,7 +143,8 @@ class Tracer:
 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
+    Methods: 1. async_step: The async version of Engine.step()
     """
     async def async_step(self) -> List[str]:
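This hunk only touches the docstring, so the body of `async_step` is not shown. As a rough sketch of what an async wrapper around a blocking `step()` can look like; this is an assumption for illustration, not ColossalAI's actual implementation:

```python
import asyncio
from typing import List


class ToyAsyncEngine:
    # Hypothetical sketch: run one blocking engine step per iteration
    # without stalling the event loop.
    def __init__(self) -> None:
        self._pending = ["a prompt"]

    def step(self) -> List[str]:
        # Stand-in for the synchronous InferenceEngine.step(): one decoding iteration.
        return [f"output for {p}" for p in self._pending]

    async def async_step(self) -> List[str]:
        # Offload the blocking step() to the default thread pool so other
        # coroutines (e.g. newly arriving requests) keep making progress.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.step)


async def main() -> None:
    engine = ToyAsyncEngine()
    print(await engine.async_step())


asyncio.run(main())
```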