community[patch]: docstrings (#16810)

- added missing docstrings
- formatted docstrings to a consistent form
This commit is contained in:
Leonid Ganeline 2024-02-09 12:48:57 -08:00 committed by GitHub
parent ae66bcbc10
commit 932c52c333
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 66 additions and 18 deletions

View File

@ -40,25 +40,33 @@ async def aenumerate(
class IndexableBaseModel(BaseModel): class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing""" """Allows a BaseModel to return its fields by string variable indexing."""
def __getitem__(self, item: str) -> Any: def __getitem__(self, item: str) -> Any:
return getattr(self, item) return getattr(self, item)
class Choice(IndexableBaseModel): class Choice(IndexableBaseModel):
"""Choice."""
message: dict message: dict
class ChatCompletions(IndexableBaseModel): class ChatCompletions(IndexableBaseModel):
"""Chat completions."""
choices: List[Choice] choices: List[Choice]
class ChoiceChunk(IndexableBaseModel): class ChoiceChunk(IndexableBaseModel):
"""Choice chunk."""
delta: dict delta: dict
class ChatCompletionChunk(IndexableBaseModel): class ChatCompletionChunk(IndexableBaseModel):
"""Chat completion chunk."""
choices: List[ChoiceChunk] choices: List[ChoiceChunk]
@ -301,7 +309,7 @@ def convert_messages_for_finetuning(
class Completions: class Completions:
"""Completion.""" """Completions."""
@overload @overload
@staticmethod @staticmethod
@ -399,6 +407,8 @@ class Completions:
class Chat: class Chat:
"""Chat."""
def __init__(self) -> None: def __init__(self) -> None:
self.completions = Completions() self.completions = Completions()

View File

@ -191,7 +191,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool): class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
"""A tool that sends a DELETE request and parses the response.""" """Tool that sends a DELETE request and parses the response."""
name: str = "requests_delete" name: str = "requests_delete"
"""The name of the tool.""" """The name of the tool."""

View File

@ -39,6 +39,7 @@ def import_mlflow() -> Any:
def mlflow_callback_metrics() -> List[str]: def mlflow_callback_metrics() -> List[str]:
"""Get the metrics to log to MLFlow."""
return [ return [
"step", "step",
"starts", "starts",
@ -59,6 +60,7 @@ def mlflow_callback_metrics() -> List[str]:
def get_text_complexity_metrics() -> List[str]: def get_text_complexity_metrics() -> List[str]:
"""Get the text complexity metrics from textstat."""
return [ return [
"flesch_reading_ease", "flesch_reading_ease",
"flesch_kincaid_grade", "flesch_kincaid_grade",

View File

@ -225,7 +225,7 @@ class LLMThought:
class StreamlitCallbackHandler(BaseCallbackHandler): class StreamlitCallbackHandler(BaseCallbackHandler):
"""A callback handler that writes to a Streamlit app.""" """Callback handler that writes to a Streamlit app."""
def __init__( def __init__(
self, self,

View File

@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class BaseMessageConverter(ABC): class BaseMessageConverter(ABC):
"""The class responsible for converting BaseMessage to your SQLAlchemy model.""" """Class that converts BaseMessage to the SQLAlchemy model."""
@abstractmethod @abstractmethod
def from_sql_model(self, sql_message: Any) -> BaseMessage: def from_sql_model(self, sql_message: Any) -> BaseMessage:

View File

@ -20,6 +20,8 @@ from langchain_community.llms.azureml_endpoint import (
class LlamaContentFormatter(ContentFormatterBase): class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for `LLaMA`."""
def __init__(self) -> None: def __init__(self) -> None:
raise TypeError( raise TypeError(
"`LlamaContentFormatter` is deprecated for chat models. Use " "`LlamaContentFormatter` is deprecated for chat models. Use "
@ -34,7 +36,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
@staticmethod @staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict: def _convert_message_to_dict(message: BaseMessage) -> Dict:
"""Converts message to a dict according to role""" """Converts a message to a dict according to a role"""
content = cast(str, message.content) content = cast(str, message.content)
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
return { return {

View File

@ -58,6 +58,8 @@ logger = logging.getLogger(__name__)
class ChatDeepInfraException(Exception): class ChatDeepInfraException(Exception):
"""Exception raised when the DeepInfra API returns an error."""
pass pass
@ -67,7 +69,7 @@ def _create_retry_decorator(
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None, ] = None,
) -> Callable[[Any], Any]: ) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions""" """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
return create_base_retry_decorator( return create_base_retry_decorator(
error_types=[requests.exceptions.ConnectTimeout, ChatDeepInfraException], error_types=[requests.exceptions.ConnectTimeout, ChatDeepInfraException],
max_retries=llm.max_retries, max_retries=llm.max_retries,

View File

@ -53,6 +53,8 @@ class GPTRouterException(Exception):
class GPTRouterModel(BaseModel): class GPTRouterModel(BaseModel):
"""GPTRouter model."""
name: str name: str
provider_name: str provider_name: str

View File

@ -39,8 +39,8 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase): class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
"""Volc Engine Maas hosts a plethora of models.
"""volc engine maas hosts a plethora of models.
You can utilize these models through this class. You can utilize these models through this class.
To use, you should have the ``volcengine`` python package installed. To use, you should have the ``volcengine`` python package installed.

View File

@ -20,11 +20,15 @@ logger = logging.getLogger(__name__)
class ref(BaseModel): class ref(BaseModel):
"""Reference used in CharacterGLM."""
enable: bool = Field(True) enable: bool = Field(True)
search_query: str = Field("") search_query: str = Field("")
class meta(BaseModel): class meta(BaseModel):
"""Metadata used in CharacterGLM."""
user_info: str = Field("") user_info: str = Field("")
bot_info: str = Field("") bot_info: str = Field("")
bot_name: str = Field("") bot_name: str = Field("")

View File

@ -9,7 +9,7 @@ if TYPE_CHECKING:
class UnstructuredCHMLoader(UnstructuredFileLoader): class UnstructuredCHMLoader(UnstructuredFileLoader):
"""Load `CHM` files using `Unstructured`. """Load `CHM` files using `Unstructured`.
CHM mean Microsoft Compiled HTML Help. CHM means Microsoft Compiled HTML Help.
Examples Examples
-------- --------
@ -35,6 +35,8 @@ class UnstructuredCHMLoader(UnstructuredFileLoader):
class CHMParser(object): class CHMParser(object):
"""Microsoft Compiled HTML Help (CHM) Parser."""
path: str path: str
file: "chm.CHMFile" file: "chm.CHMFile"

View File

@ -11,6 +11,8 @@ from langchain_community.document_loaders.blob_loaders import Blob
class VsdxParser(BaseBlobParser, ABC): class VsdxParser(BaseBlobParser, ABC):
"""Parser for vsdx files."""
def parse(self, blob: Blob) -> Iterator[Document]: # type: ignore[override] def parse(self, blob: Blob) -> Iterator[Document]: # type: ignore[override]
"""Parse a vsdx file.""" """Parse a vsdx file."""
return self.lazy_parse(blob) return self.lazy_parse(blob)

View File

@ -141,6 +141,8 @@ def _parse_video_id(url: str) -> Optional[str]:
class TranscriptFormat(Enum): class TranscriptFormat(Enum):
"""Transcript format."""
TEXT = "text" TEXT = "text"
LINES = "lines" LINES = "lines"

View File

@ -13,7 +13,7 @@ def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
class MlflowEmbeddings(Embeddings, BaseModel): class MlflowEmbeddings(Embeddings, BaseModel):
"""Wrapper around embeddings LLMs in MLflow. """Embedding LLMs in MLflow.
To use, you should have the `mlflow[genai]` python package installed. To use, you should have the `mlflow[genai]` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html. For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
@ -85,5 +85,7 @@ class MlflowEmbeddings(Embeddings, BaseModel):
class MlflowCohereEmbeddings(MlflowEmbeddings): class MlflowCohereEmbeddings(MlflowEmbeddings):
"""Cohere embedding LLMs in MLflow."""
query_params: Dict[str, str] = {"input_type": "search_query"} query_params: Dict[str, str] = {"input_type": "search_query"}
documents_params: Dict[str, str] = {"input_type": "search_document"} documents_params: Dict[str, str] = {"input_type": "search_document"}

View File

@ -8,6 +8,8 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
class OCIAuthType(Enum): class OCIAuthType(Enum):
"""OCI authentication types as enumerator."""
API_KEY = 1 API_KEY = 1
SECURITY_TOKEN = 2 SECURITY_TOKEN = 2
INSTANCE_PRINCIPAL = 3 INSTANCE_PRINCIPAL = 3

View File

@ -32,7 +32,8 @@ RETURN {start: label, type: property, end: toString(other_node)} AS output
def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]: def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
""" """Sanitize the input dictionary.
Sanitizes the input dictionary by removing embedding-like values, Sanitizes the input dictionary by removing embedding-like values,
lists with more than 128 elements, that are mostly irrelevant for lists with more than 128 elements, that are mostly irrelevant for
generating answers in a LLM context. These properties, if left in generating answers in a LLM context. These properties, if left in
@ -63,7 +64,8 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
class Neo4jGraph(GraphStore): class Neo4jGraph(GraphStore):
"""Provides a connection to a Neo4j database for various graph operations. """Neo4j database wrapper for various graph operations.
Parameters: Parameters:
url (Optional[str]): The URL of the Neo4j database server. url (Optional[str]): The URL of the Neo4j database server.
username (Optional[str]): The username for database authentication. username (Optional[str]): The username for database authentication.

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
class NeptuneQueryException(Exception): class NeptuneQueryException(Exception):
"""A class to handle queries that fail to execute""" """Exception for the Neptune queries."""
def __init__(self, exception: Union[str, Dict]): def __init__(self, exception: Union[str, Dict]):
if isinstance(exception, dict): if isinstance(exception, dict):

View File

@ -15,6 +15,8 @@ VALID_PROVIDERS = ("cohere", "meta")
class OCIAuthType(Enum): class OCIAuthType(Enum):
"""OCI authentication types as enumerator."""
API_KEY = 1 API_KEY = 1
SECURITY_TOKEN = 2 SECURITY_TOKEN = 2
INSTANCE_PRINCIPAL = 3 INSTANCE_PRINCIPAL = 3

View File

@ -91,7 +91,9 @@ def stream_generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any:
async def astream_generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any: async def astream_generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any:
"""Because the dashscope SDK doesn't provide an async API, """Async version of `stream_generate_with_retry`.
Because the dashscope SDK doesn't provide an async API,
we wrap `stream_generate_with_retry` with an async generator.""" we wrap `stream_generate_with_retry` with an async generator."""
class _AioTongyiGenerator: class _AioTongyiGenerator:

View File

@ -10,7 +10,7 @@ from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
class ArceeRetriever(BaseRetriever): class ArceeRetriever(BaseRetriever):
"""Document retriever for Arcee's Domain Adapted Language Models (DALMs). """Retriever for Arcee's Domain Adapted Language Models (DALMs).
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key, To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
or pass ``arcee_api_key`` as a named parameter. or pass ``arcee_api_key`` as a named parameter.

View File

@ -25,6 +25,8 @@ V = TypeVar("V")
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC): class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""
def __init__( def __init__(
self, self,
collection_name: str, collection_name: str,
@ -79,6 +81,7 @@ class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
class AstraDBStore(AstraDBBaseStore[Any]): class AstraDBStore(AstraDBBaseStore[Any]):
"""BaseStore implementation using DataStax AstraDB as the underlying store. """BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps. The value type can be any type serializable by json.dumps.
Can be used to store embeddings with the CacheBackedEmbeddings. Can be used to store embeddings with the CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format Documents in the AstraDB collection will have the format
@ -97,6 +100,7 @@ class AstraDBStore(AstraDBBaseStore[Any]):
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore): class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
"""ByteStore implementation using DataStax AstraDB as the underlying store. """ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64 encoded strings The bytes values are converted to base64 encoded strings
Documents in the AstraDB collection will have the format Documents in the AstraDB collection will have the format
{ {

View File

@ -8,6 +8,8 @@ from langchain_community.utilities.polygon import PolygonAPIWrapper
class Inputs(BaseModel): class Inputs(BaseModel):
"""Inputs for Polygon's Last Quote API"""
query: str query: str

View File

@ -57,7 +57,7 @@ def init_vertexai(
location: Optional[str] = None, location: Optional[str] = None,
credentials: Optional["Credentials"] = None, credentials: Optional["Credentials"] = None,
) -> None: ) -> None:
"""Init vertexai. """Init Vertex AI.
Args: Args:
project: The default GCP project to use when making Vertex API calls. project: The default GCP project to use when making Vertex API calls.

View File

@ -16,6 +16,8 @@ logger = logging.getLogger(__name__)
class KDBAI(VectorStore): class KDBAI(VectorStore):
"""`KDB.AI` vector store. """`KDB.AI` vector store.
See [https://kdb.ai](https://kdb.ai)
To use, you should have the `kdbai_client` python package installed. To use, you should have the `kdbai_client` python package installed.
Args: Args:
@ -25,7 +27,7 @@ class KDBAI(VectorStore):
distance_strategy: One option from DistanceStrategy.EUCLIDEAN_DISTANCE, distance_strategy: One option from DistanceStrategy.EUCLIDEAN_DISTANCE,
DistanceStrategy.DOT_PRODUCT or DistanceStrategy.COSINE. DistanceStrategy.DOT_PRODUCT or DistanceStrategy.COSINE.
See the example https://github.com/KxSystems/langchain/blob/KDB.AI/docs/docs/integrations/vectorstores/kdbai.ipynb. See the example [notebook](https://github.com/KxSystems/langchain/blob/KDB.AI/docs/docs/integrations/vectorstores/kdbai.ipynb).
""" """
def __init__( def __init__(

View File

@ -47,12 +47,14 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]:
class BaseEmbeddingStore: class BaseEmbeddingStore:
"""Embedding store.""" """Base class for the Lantern embedding store."""
def get_embedding_store( def get_embedding_store(
distance_strategy: DistanceStrategy, collection_name: str distance_strategy: DistanceStrategy, collection_name: str
) -> Any: ) -> Any:
"""Get the embedding store class."""
embedding_type = None embedding_type = None
if distance_strategy == DistanceStrategy.HAMMING: if distance_strategy == DistanceStrategy.HAMMING: