docstrings cleanup (#8993)

Added/Updated docstrings

 @baskaryan
Leonid Ganeline 2023-08-09 15:49:06 -07:00 committed by GitHub
parent c72da53c10
commit 5454591b0a
28 changed files with 269 additions and 12 deletions

View File

@ -55,6 +55,16 @@ def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
@dataclass(frozen=True)
class ReducedOpenAPISpec:
"""A reduced OpenAPI spec.
This is a quick and dirty representation for OpenAPI specs.
Attributes:
servers: The servers in the spec.
description: The description of the spec.
endpoints: The endpoints in the spec.
"""
servers: List[dict]
description: str
endpoints: List[Tuple[str, str, dict]]
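
A minimal construction sketch for the dataclass documented above (the import path and field values are assumptions for illustration):

from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec  # assumed import path

spec = ReducedOpenAPISpec(
    servers=[{"url": "https://api.example.com"}],    # illustrative server entry
    description="Example API",                       # illustrative description
    endpoints=[("GET /pets", "List all pets", {})],  # (name, description, operation) tuples
)
print(spec.description)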

View File

@ -10,6 +10,8 @@ from langchain.tools.base import BaseTool
class XMLAgentOutputParser(AgentOutputParser):
"""Output parser for XMLAgent."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if "</tool>" in text:
tool, tool_input = text.split("</tool>")

View File

@ -49,6 +49,8 @@ class ElementInViewPort(TypedDict):
class Crawler:
"""A crawler for web pages."""
def __init__(self) -> None:
try:
from playwright.sync_api import sync_playwright

View File

@ -9,6 +9,7 @@ try:
except ImportError:
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore
"""Dummy decorator for when lark is not installed."""
return lambda _: None
Transformer = object # type: ignore

View File

@ -51,6 +51,8 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
"""Base class for chat models."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)

View File

@ -19,6 +19,17 @@ class AirbyteCDKLoader(BaseLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
source_class: The source connector class.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage
from airbyte_cdk.sources.embedded.base_integration import (
BaseEmbeddedIntegration,
@ -26,6 +37,8 @@ class AirbyteCDKLoader(BaseLoader):
from airbyte_cdk.sources.embedded.runner import CDKRunner
class CDKIntegration(BaseEmbeddedIntegration):
"""A wrapper around the CDK integration."""
def _handle_record(
self, record: AirbyteRecordMessage, id: Optional[str]
) -> Document:
@ -50,6 +63,8 @@ class AirbyteCDKLoader(BaseLoader):
class AirbyteHubspotLoader(AirbyteCDKLoader):
"""Loads records from Hubspot using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -57,6 +72,16 @@ class AirbyteHubspotLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_hubspot", pip_name="airbyte-source-hubspot"
).SourceHubspot
@ -70,6 +95,8 @@ class AirbyteHubspotLoader(AirbyteCDKLoader):
class AirbyteStripeLoader(AirbyteCDKLoader):
"""Loads records from Stripe using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -77,6 +104,16 @@ class AirbyteStripeLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_stripe", pip_name="airbyte-source-stripe"
).SourceStripe
@ -90,6 +127,8 @@ class AirbyteStripeLoader(AirbyteCDKLoader):
class AirbyteTypeformLoader(AirbyteCDKLoader):
"""Loads records from Typeform using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -97,6 +136,16 @@ class AirbyteTypeformLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_typeform", pip_name="airbyte-source-typeform"
).SourceTypeform
@ -110,6 +159,8 @@ class AirbyteTypeformLoader(AirbyteCDKLoader):
class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
"""Loads records from Zendesk Support using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -117,6 +168,16 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_zendesk_support", pip_name="airbyte-source-zendesk-support"
).SourceZendeskSupport
@ -130,6 +191,8 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
class AirbyteShopifyLoader(AirbyteCDKLoader):
"""Loads records from Shopify using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -137,6 +200,16 @@ class AirbyteShopifyLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_shopify", pip_name="airbyte-source-shopify"
).SourceShopify
@ -150,6 +223,8 @@ class AirbyteShopifyLoader(AirbyteCDKLoader):
class AirbyteSalesforceLoader(AirbyteCDKLoader):
"""Loads records from Salesforce using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -157,6 +232,16 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_salesforce", pip_name="airbyte-source-salesforce"
).SourceSalesforce
@ -170,6 +255,8 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader):
class AirbyteGongLoader(AirbyteCDKLoader):
"""Loads records from Gong using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -177,6 +264,16 @@ class AirbyteGongLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_gong", pip_name="airbyte-source-gong"
).SourceGong
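
A minimal usage sketch for these loaders, following the arguments documented above (module path, config keys, and stream name are illustrative assumptions; the matching Airbyte source package must be installed):

from langchain.document_loaders.airbyte import AirbyteStripeLoader  # assumed import path

loader = AirbyteStripeLoader(
    config={"client_secret": "sk_test_...", "start_date": "2023-01-01T00:00:00Z"},  # illustrative
    stream_name="invoices",
    # record_handler is left as None, so each record becomes a Document as-is.
)
docs = loader.load()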

View File

@ -79,8 +79,10 @@ class OpenAIWhisperParser(BaseBlobParser):
class OpenAIWhisperParserLocal(BaseBlobParser):
"""Transcribe and parse audio files with OpenAI Whisper model.
Audio transcription with OpenAI Whisper model locally from transformers.
Parameters:
device - device to use
NOTE: By default uses the gpu if available,
@ -105,6 +107,15 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
lang_model: Optional[str] = None,
forced_decoder_ids: Optional[Tuple[Dict]] = None,
):
"""Initialize the parser.
Args:
device: device to use.
lang_model: whisper model to use, for example "openai/whisper-medium".
Defaults to None.
forced_decoder_ids: id states for decoder in a multilanguage model.
Defaults to None.
"""
try:
from transformers import pipeline
except ImportError:
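
A minimal construction sketch using the arguments documented above (import path assumed; requires the transformers package):

from langchain.document_loaders.parsers.audio import OpenAIWhisperParserLocal  # assumed import path

parser = OpenAIWhisperParserLocal(
    device="cpu",                        # uses the GPU by default if one is available
    lang_model="openai/whisper-medium",  # illustrative Whisper checkpoint
)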

View File

@ -11,7 +11,7 @@ class NucliaTextTransformer(BaseDocumentTransformer):
""" """
The Nuclia Understanding API splits into paragraphs and sentences, The Nuclia Understanding API splits into paragraphs and sentences,
identifies entities, provides a summary of the text and generates identifies entities, provides a summary of the text and generates
embeddings for all the sentences. embeddings for all sentences.
""" """
def __init__(self, nua: NucliaUnderstandingAPI): def __init__(self, nua: NucliaUnderstandingAPI):

View File

@ -6,6 +6,14 @@ from langchain.embeddings.base import Embeddings
class AwaEmbeddings(BaseModel, Embeddings):
"""Embedding documents and queries with Awa DB.
Attributes:
client: The AwaEmbedding client.
model: The name of the model used for embedding.
Default is "all-mpnet-base-v2".
"""
client: Any #: :meta private:
model: str = "all-mpnet-base-v2"

View File

@ -13,7 +13,7 @@ EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"
class EmbaasEmbeddingsPayload(TypedDict):
"""Payload for the Embaas embeddings API."""
model: str
texts: List[str]

View File

@ -24,7 +24,7 @@ from langchain.schema.language_model import BaseLanguageModel
def load_dataset(uri: str) -> List[Dict]:
"""Load a dataset from the `LangChainDatasets on HuggingFace <https://huggingface.co/LangChainDatasets>`_.
Args:
uri: The uri of the dataset to load.
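
A short sketch of calling this helper (import path assumed; the dataset name is illustrative and the datasets package must be installed):

from langchain.evaluation.loading import load_dataset  # assumed import path

dataset = load_dataset("llm-math")  # illustrative dataset name from the LangChainDatasets org
print(len(dataset), dataset[0])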

View File

@ -15,6 +15,13 @@ TIMEOUT = 60
@dataclasses.dataclass
class AviaryBackend:
"""Aviary backend.
Attributes:
backend_url: The URL for the Aviary backend.
bearer: The bearer token for the Aviary backend.
"""
backend_url: str
bearer: str
@ -89,6 +96,14 @@ class Aviary(LLM):
AVIARY_URL and AVIARY_TOKEN environment variables must be set.
Attributes:
model: The name of the model to use. Defaults to "amazon/LightGPT".
aviary_url: The URL for the Aviary backend. Defaults to None.
aviary_token: The bearer token for the Aviary backend. Defaults to None.
use_prompt_format: If True, the prompt template for the model will be ignored.
Defaults to True.
version: API version to use for Aviary. Defaults to None.
Example:
.. code-block:: python

View File

@ -56,6 +56,8 @@ class FakeListLLM(LLM):
class FakeStreamingListLLM(FakeListLLM):
"""Fake streaming list LLM for testing purposes."""
def stream(
self,
input: LanguageModelInput,
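
A small sketch of using this fake LLM in a test (import path assumed; stream is assumed to yield the canned response in small pieces):

from langchain.llms.fake import FakeStreamingListLLM  # assumed import path

llm = FakeStreamingListLLM(responses=["Hello world"])
for chunk in llm.stream("any prompt"):  # emits the first canned response incrementally
    print(chunk, end="")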

View File

@ -8,6 +8,8 @@ from langchain.schema.output import Generation, LLMResult
class VLLM(BaseLLM):
"""VLLM language model."""
model: str = ""
"""The name or path of a HuggingFace Transformers model."""

View File

@ -88,6 +88,8 @@ class BaseMessage(Serializable):
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
@ -145,6 +147,8 @@ class HumanMessage(BaseMessage):
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
pass
@ -163,6 +167,8 @@ class AIMessage(BaseMessage):
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
pass
@ -178,6 +184,8 @@ class SystemMessage(BaseMessage):
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
pass
@ -194,6 +202,8 @@ class FunctionMessage(BaseMessage):
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
pass
@ -210,6 +220,8 @@ class ChatMessage(BaseMessage):
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
pass
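
A small sketch of the chunk concatenation these classes exist for (import path assumed):

from langchain.schema.messages import AIMessageChunk  # assumed import path

merged = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world!")
print(merged.content)  # "Hello, world!"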

View File

@ -29,6 +29,8 @@ class Generation(Serializable):
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
@ -62,6 +64,13 @@ class ChatGeneration(Generation):
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
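
A small sketch of the __add__ behavior defined above (import path assumed):

from langchain.schema.output import GenerationChunk  # assumed import path

combined = GenerationChunk(text="Hello,") + GenerationChunk(text=" world")
print(combined.text)  # "Hello, world"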

View File

@ -56,6 +56,8 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:

View File

@ -48,6 +48,8 @@ async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> li
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
@ -74,6 +76,9 @@ Other = TypeVar("Other")
class Runnable(Generic[Input, Output], ABC):
"""A Runnable is a unit of work that can be invoked, batched, streamed, or
transformed."""
def __or__(
self,
other: Union[
@ -1325,6 +1330,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
class RouterInput(TypedDict):
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str
input: Any
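
A minimal sketch of composing Runnables with the __or__ operator documented above, plus a RouterInput literal (RunnableLambda and the import path are assumptions):

from langchain.schema.runnable import RouterInput, RunnableConfig, RunnableLambda  # assumed imports

chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
config: RunnableConfig = {"tags": ["demo"]}  # optional per-call configuration
print(chain.invoke(3, config=config))        # -> 8

router_input: RouterInput = {"key": "math", "input": "What is 2 + 2?"}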

View File

@ -31,6 +31,14 @@ class CreateSessionSchema(BaseModel):
class MultionCreateSession(BaseTool):
"""Tool that creates a new Multion Browser Window with provided fields.
Attributes:
name: The name of the tool. Default: "create_multion_session"
description: The description of the tool.
args_schema: The schema for the tool's arguments.
"""
name: str = "create_multion_session" name: str = "create_multion_session"
description: str = """Use this tool to create a new Multion Browser Window \ description: str = """Use this tool to create a new Multion Browser Window \
with provided fields.Always the first step to run \ with provided fields.Always the first step to run \

View File

@ -34,6 +34,14 @@ class UpdateSessionSchema(BaseModel):
class MultionUpdateSession(BaseTool):
"""Tool that updates an existing Multion Browser Window with provided fields.
Attributes:
name: The name of the tool. Default: "update_multion_session"
description: The description of the tool.
args_schema: The schema for the tool's arguments. Default: UpdateSessionSchema
"""
name: str = "update_multion_session" name: str = "update_multion_session"
description: str = """Use this tool to update \ description: str = """Use this tool to update \
a existing corresponding \ a existing corresponding \

View File

@ -28,6 +28,15 @@ logger = logging.getLogger(__name__)
class NUASchema(BaseModel):
"""Input for Nuclia Understanding API.
Attributes:
action: Action to perform. Either `push` or `pull`.
id: ID of the file to push or pull.
path: Path to the file to push (needed only for `push` action).
text: Text content to process (needed only for `push` action).
"""
action: str = Field(
...,
description="Action to perform. Either `push` or `pull`.",

View File

@ -4,6 +4,13 @@ from typing import Dict, Optional
class Portkey:
"""Portkey configuration.
Attributes:
base: The base URL for the Portkey API.
Default: "https://api.portkey.ai/v1/proxy"
"""
base = "https://api.portkey.ai/v1/proxy" base = "https://api.portkey.ai/v1/proxy"
@staticmethod @staticmethod

View File

@ -7,6 +7,8 @@ if TYPE_CHECKING:
class SparkSQL:
"""SparkSQL is a utility class for interacting with Spark SQL."""
def __init__(
self,
spark_session: Optional[SparkSession] = None,
@ -16,10 +18,26 @@ class SparkSQL:
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
):
"""Initialize a SparkSQL object.
Args:
spark_session: A SparkSession object.
If not provided, one will be created.
catalog: The catalog to use.
If not provided, the default catalog will be used.
schema: The schema to use.
If not provided, the default schema will be used.
ignore_tables: A list of tables to ignore.
If not provided, all tables will be used.
include_tables: A list of tables to include.
If not provided, all tables will be used.
sample_rows_in_table_info: The number of rows to include in the table info.
Defaults to 3.
"""
try:
from pyspark.sql import SparkSession
except ImportError:
raise ImportError(
"pyspark is not installed. Please install it with `pip install pyspark`"
)
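
A short usage sketch based on the constructor arguments documented above (import path assumed; requires pyspark):

from pyspark.sql import SparkSession

from langchain.utilities.spark_sql import SparkSQL  # assumed import path

spark = SparkSession.builder.appName("demo").getOrCreate()
db = SparkSQL(spark_session=spark, schema="default", sample_rows_in_table_info=3)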

View File

@ -141,7 +141,13 @@ def build_extra_kwargs(
values: Dict[str, Any],
all_required_field_names: Set[str],
) -> Dict[str, Any]:
"""Build extra kwargs from values and extra_kwargs.
Args:
extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class.
"""
for field_name in list(values):
if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.")
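
A small illustration of the behavior documented above (import path assumed; values are made up):

from langchain.utils import build_extra_kwargs  # assumed import path

extra = build_extra_kwargs(
    extra_kwargs={},
    values={"temperature": 0.2, "best_of": 2},
    all_required_field_names={"temperature"},
)
# "best_of" is not a declared field, so it is moved into the returned extra kwargs;
# supplying the same key in both extra_kwargs and values raises a ValueError.
print(extra)  # {'best_of': 2}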

View File

@ -12,7 +12,8 @@ logger = logging.getLogger()
class AlibabaCloudOpenSearchSettings:
"""Alibaba Cloud Opensearch Client Configuration.
Attribute:
endpoint (str) : The endpoint of opensearch instance, You can find it
from the console of Alibaba Cloud OpenSearch.

View File

@ -16,7 +16,17 @@ _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_pg_embedding"
class HologresWrapper:
"""Wrapper around Hologres service."""
def __init__(self, connection_string: str, ndims: int, table_name: str) -> None:
"""Initialize the wrapper.
Args:
connection_string: Hologres connection string.
ndims: Number of dimensions of the embedding output.
table_name: Name of the table to store embeddings and data.
"""
import psycopg2
self.table_name = table_name
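
A short construction sketch using the arguments documented above (import path and connection values are assumptions; requires psycopg2 and a reachable Hologres instance):

from langchain.vectorstores.hologres import HologresWrapper  # assumed import path

wrapper = HologresWrapper(
    connection_string="postgresql://user:password@host:5432/db",  # illustrative DSN
    ndims=1536,                                                    # embedding dimensionality
    table_name="langchain_pg_embedding",
)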

View File

@ -87,6 +87,8 @@ class EmbeddingStore(BaseModel):
class QueryResult:
"""QueryResult is a result from a query."""
EmbeddingStore: EmbeddingStore
distance: float

View File

@ -18,6 +18,7 @@ from langchain.vectorstores.utils import DistanceStrategy
def normalize(x: np.ndarray) -> np.ndarray:
"""Normalize vectors to unit length."""
x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
return x
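
A quick numeric check of the helper above (assuming normalize is in scope):

import numpy as np

v = np.array([[3.0, 4.0]])
print(normalize(v))  # [[0.6 0.8]] -- unit length along the last axis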