mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-07 20:39:48 +00:00
fix comments
This commit is contained in:
parent
2396d58343
commit
402e9918df
@ -1,9 +1,8 @@
|
||||
"""
|
||||
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
|
||||
"""
|
||||
import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -215,7 +214,7 @@ class InferenceConfig:
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
|
||||
# Get the list of attributes of this dataclass.
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
attrs = [attr.name for attr in fields(cls)]
|
||||
inference_config_args = {}
|
||||
for attr in attrs:
|
||||
if attr in config_dict:
|
||||
|
@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
|
||||
logger = logging.getLogger("colossalai-inference")
|
||||
|
||||
|
||||
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
|
||||
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
|
||||
msg = "Task finished unexpectedly. This should never happen! "
|
||||
try:
|
||||
try:
|
||||
@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
|
||||
|
||||
|
||||
class RequstStream:
|
||||
"""A stream of Output for a request that can be
|
||||
iterated over asynchronously."""
|
||||
"""
|
||||
A stream of Output for a request that can be iterated over asynchronously.
|
||||
Attributes: 1.request_id: The id of the request.
|
||||
2._future: A future that will be set when the request is finished.
|
||||
Methods: set_result and get_result, results will be set when finished, for once, and
|
||||
the `self.future` will be set to done.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: int) -> None:
|
||||
self.request_id = request_id
|
||||
@ -51,6 +57,10 @@ class RequstStream:
|
||||
class Tracer:
|
||||
"""
|
||||
Recording new requests and finished requests.
|
||||
Attributes: 1._request_streams: We create one stream for each request to trace the output.
|
||||
2._finished_requests: A queue to store the finished requests.
|
||||
3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
|
||||
4.new_requests_event: An event to notify the engine that there are new requests.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -133,7 +143,8 @@ class Tracer:
|
||||
|
||||
class _AsyncInferenceEngine(InferenceEngine):
|
||||
"""
|
||||
Async methods for Inference Engine.
|
||||
Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
|
||||
Methods: 1. async_step: The async version of Engine.step()
|
||||
"""
|
||||
|
||||
async def async_step(self) -> List[str]:
|
||||
|
Loading…
Reference in New Issue
Block a user