mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-07 12:29:09 +00:00
fix comments
This commit is contained in:
parent
2396d58343
commit
402e9918df
@ -1,9 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
|
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
|
||||||
"""
|
"""
|
||||||
import dataclasses
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, fields
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -215,7 +214,7 @@ class InferenceConfig:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
|
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
|
||||||
# Get the list of attributes of this dataclass.
|
# Get the list of attributes of this dataclass.
|
||||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
attrs = [attr.name for attr in fields(cls)]
|
||||||
inference_config_args = {}
|
inference_config_args = {}
|
||||||
for attr in attrs:
|
for attr in attrs:
|
||||||
if attr in config_dict:
|
if attr in config_dict:
|
||||||
|
@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
|
|||||||
logger = logging.getLogger("colossalai-inference")
|
logger = logging.getLogger("colossalai-inference")
|
||||||
|
|
||||||
|
|
||||||
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
|
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
|
||||||
msg = "Task finished unexpectedly. This should never happen! "
|
msg = "Task finished unexpectedly. This should never happen! "
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
|
|||||||
|
|
||||||
|
|
||||||
class RequstStream:
|
class RequstStream:
|
||||||
"""A stream of Output for a request that can be
|
"""
|
||||||
iterated over asynchronously."""
|
A stream of Output for a request that can be iterated over asynchronously.
|
||||||
|
Attributes: 1.request_id: The id of the request.
|
||||||
|
2._future: A future that will be set when the request is finished.
|
||||||
|
Methods: set_result and get_result, results will be set when finished, for once, and
|
||||||
|
the `self.future` will be set to done.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, request_id: int) -> None:
|
def __init__(self, request_id: int) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
@ -51,6 +57,10 @@ class RequstStream:
|
|||||||
class Tracer:
|
class Tracer:
|
||||||
"""
|
"""
|
||||||
Recording new requests and finished requests.
|
Recording new requests and finished requests.
|
||||||
|
Attributes: 1._request_streams: We create one stream for each request to trace the output.
|
||||||
|
2._finished_requests: A queue to store the finished requests.
|
||||||
|
3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
|
||||||
|
4.new_requests_event: An event to notify the engine that there are new requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -133,7 +143,8 @@ class Tracer:
|
|||||||
|
|
||||||
class _AsyncInferenceEngine(InferenceEngine):
|
class _AsyncInferenceEngine(InferenceEngine):
|
||||||
"""
|
"""
|
||||||
Async methods for Inference Engine.
|
Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
|
||||||
|
Methods: 1. async_step: The async version of Engine.step()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def async_step(self) -> List[str]:
|
async def async_step(self) -> List[str]:
|
||||||
|
Loading…
Reference in New Issue
Block a user