From 8dd071ad08949a0e57322f5de462ef93e2272e62 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 9 Aug 2023 14:51:15 -0700 Subject: [PATCH 1/8] import airbyte loaders (#9009) --- .../langchain/document_loaders/__init__.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index b52c0927db8..a988744d249 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -16,6 +16,16 @@ """ from langchain.document_loaders.acreom import AcreomLoader +from langchain.document_loaders.airbyte import ( + AirbyteCDKLoader, + AirbyteGongLoader, + AirbyteHubspotLoader, + AirbyteSalesforceLoader, + AirbyteShopifyLoader, + AirbyteStripeLoader, + AirbyteTypeformLoader, + AirbyteZendeskSupportLoader, +) from langchain.document_loaders.airbyte_json import AirbyteJSONLoader from langchain.document_loaders.airtable import AirtableLoader from langchain.document_loaders.apify_dataset import ApifyDatasetLoader @@ -188,7 +198,15 @@ TelegramChatLoader = TelegramChatFileLoader __all__ = [ "AZLyricsLoader", "AcreomLoader", + "AirbyteCDKLoader", + "AirbyteGongLoader", "AirbyteJSONLoader", + "AirbyteHubspotLoader", + "AirbyteSalesforceLoader", + "AirbyteShopifyLoader", + "AirbyteStripeLoader", + "AirbyteTypeformLoader", + "AirbyteZendeskSupportLoader", "AirtableLoader", "AmazonTextractPDFLoader", "ApifyDatasetLoader", From c72da53c109277ec61b27a8a1600b977893b1818 Mon Sep 17 00:00:00 2001 From: Massimiliano Pronesti Date: Thu, 10 Aug 2023 00:48:29 +0200 Subject: [PATCH 2/8] Add logprobs to SamplingParameters in vllm (#9010) This PR aims at amending #8806 , that I opened a few days ago, adding the extra `logprobs` parameter that I accidentally forgot --- libs/langchain/langchain/llms/vllm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/langchain/langchain/llms/vllm.py b/libs/langchain/langchain/llms/vllm.py index d02b6ff02f2..46da858c775 100644 --- a/libs/langchain/langchain/llms/vllm.py +++ b/libs/langchain/langchain/llms/vllm.py @@ -54,6 +54,9 @@ class VLLM(BaseLLM): max_new_tokens: int = 512 """Maximum number of tokens to generate per output sequence.""" + logprobs: Optional[int] = None + """Number of log probabilities to return per output token.""" + client: Any #: :meta private: @root_validator() @@ -91,6 +94,7 @@ class VLLM(BaseLLM): "stop": self.stop, "ignore_eos": self.ignore_eos, "use_beam_search": self.use_beam_search, + "logprobs": self.logprobs, } def _generate( From 5454591b0af8a409ad07e73681d071be08ef4eab Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Wed, 9 Aug 2023 15:49:06 -0700 Subject: [PATCH 3/8] docstrings cleanup (#8993) Added/Updated docstrings @baskaryan --- .../agents/agent_toolkits/openapi/spec.py | 10 ++ libs/langchain/langchain/agents/xml/base.py | 2 + .../langchain/chains/natbot/crawler.py | 2 + .../chains/query_constructor/parser.py | 1 + libs/langchain/langchain/chat_models/base.py | 2 + .../langchain/document_loaders/airbyte.py | 97 +++++++++++++++++++ .../document_loaders/parsers/audio.py | 15 ++- .../nuclia_text_transform.py | 2 +- libs/langchain/langchain/embeddings/awa.py | 8 ++ libs/langchain/langchain/embeddings/embaas.py | 2 +- .../langchain/langchain/evaluation/loading.py | 2 +- libs/langchain/langchain/llms/aviary.py | 15 +++ libs/langchain/langchain/llms/fake.py | 2 + libs/langchain/langchain/llms/vllm.py | 2 + 
libs/langchain/langchain/schema/messages.py | 12 +++ libs/langchain/langchain/schema/output.py | 9 ++ .../langchain/schema/output_parser.py | 2 + libs/langchain/langchain/schema/runnable.py | 12 +++ .../langchain/tools/multion/create_session.py | 8 ++ .../langchain/tools/multion/update_session.py | 8 ++ libs/langchain/langchain/tools/nuclia/tool.py | 9 ++ libs/langchain/langchain/utilities/portkey.py | 7 ++ .../langchain/utilities/spark_sql.py | 20 +++- libs/langchain/langchain/utils/utils.py | 8 +- .../vectorstores/alibabacloud_opensearch.py | 11 ++- .../langchain/vectorstores/hologres.py | 10 ++ .../langchain/vectorstores/pgembedding.py | 2 + .../langchain/langchain/vectorstores/scann.py | 1 + 28 files changed, 269 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py index 1c717db7b94..fa26b3c5d0e 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py @@ -55,6 +55,16 @@ def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]: @dataclass(frozen=True) class ReducedOpenAPISpec: + """A reduced OpenAPI spec. + + This is a quick and dirty representation for OpenAPI specs. + + Attributes: + servers: The servers in the spec. + description: The description of the spec. + endpoints: The endpoints in the spec. + """ + servers: List[dict] description: str endpoints: List[Tuple[str, str, dict]] diff --git a/libs/langchain/langchain/agents/xml/base.py b/libs/langchain/langchain/agents/xml/base.py index 8e93b54fe32..3462ebe66d5 100644 --- a/libs/langchain/langchain/agents/xml/base.py +++ b/libs/langchain/langchain/agents/xml/base.py @@ -10,6 +10,8 @@ from langchain.tools.base import BaseTool class XMLAgentOutputParser(AgentOutputParser): + """Output parser for XMLAgent.""" + def parse(self, text: str) -> Union[AgentAction, AgentFinish]: if "" in text: tool, tool_input = text.split("") diff --git a/libs/langchain/langchain/chains/natbot/crawler.py b/libs/langchain/langchain/chains/natbot/crawler.py index 69fd51122b4..442781e551c 100644 --- a/libs/langchain/langchain/chains/natbot/crawler.py +++ b/libs/langchain/langchain/chains/natbot/crawler.py @@ -49,6 +49,8 @@ class ElementInViewPort(TypedDict): class Crawler: + """A crawler for web pages.""" + def __init__(self) -> None: try: from playwright.sync_api import sync_playwright diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index 8e685786d45..6cbc42b5e73 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -9,6 +9,7 @@ try: except ImportError: def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore + """Dummy decorator for when lark is not installed.""" return lambda _: None Transformer = object # type: ignore diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 0a39dff54ac..8ef1479e1f8 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -51,6 +51,8 @@ def _get_verbosity() -> bool: class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): + """Base class for chat models.""" + cache: Optional[bool] = None """Whether to cache the response.""" verbose: bool = Field(default_factory=_get_verbosity) diff --git 
a/libs/langchain/langchain/document_loaders/airbyte.py b/libs/langchain/langchain/document_loaders/airbyte.py index aa670704655..51411f01200 100644 --- a/libs/langchain/langchain/document_loaders/airbyte.py +++ b/libs/langchain/langchain/document_loaders/airbyte.py @@ -19,6 +19,17 @@ class AirbyteCDKLoader(BaseLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + source_class: The source connector class. + stream_name: The name of the stream to load. + record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage from airbyte_cdk.sources.embedded.base_integration import ( BaseEmbeddedIntegration, @@ -26,6 +37,8 @@ class AirbyteCDKLoader(BaseLoader): from airbyte_cdk.sources.embedded.runner import CDKRunner class CDKIntegration(BaseEmbeddedIntegration): + """A wrapper around the CDK integration.""" + def _handle_record( self, record: AirbyteRecordMessage, id: Optional[str] ) -> Document: @@ -50,6 +63,8 @@ class AirbyteCDKLoader(BaseLoader): class AirbyteHubspotLoader(AirbyteCDKLoader): + """Loads records from Hubspot using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -57,6 +72,16 @@ class AirbyteHubspotLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. + record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_hubspot", pip_name="airbyte-source-hubspot" ).SourceHubspot @@ -70,6 +95,8 @@ class AirbyteHubspotLoader(AirbyteCDKLoader): class AirbyteStripeLoader(AirbyteCDKLoader): + """Loads records from Stripe using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -77,6 +104,16 @@ class AirbyteStripeLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. + record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_stripe", pip_name="airbyte-source-stripe" ).SourceStripe @@ -90,6 +127,8 @@ class AirbyteStripeLoader(AirbyteCDKLoader): class AirbyteTypeformLoader(AirbyteCDKLoader): + """Loads records from Typeform using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -97,6 +136,16 @@ class AirbyteTypeformLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. 
+ record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_typeform", pip_name="airbyte-source-typeform" ).SourceTypeform @@ -110,6 +159,8 @@ class AirbyteTypeformLoader(AirbyteCDKLoader): class AirbyteZendeskSupportLoader(AirbyteCDKLoader): + """Loads records from Zendesk Support using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -117,6 +168,16 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. + record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_zendesk_support", pip_name="airbyte-source-zendesk-support" ).SourceZendeskSupport @@ -130,6 +191,8 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader): class AirbyteShopifyLoader(AirbyteCDKLoader): + """Loads records from Shopify using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -137,6 +200,16 @@ class AirbyteShopifyLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. + record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_shopify", pip_name="airbyte-source-shopify" ).SourceShopify @@ -150,6 +223,8 @@ class AirbyteShopifyLoader(AirbyteCDKLoader): class AirbyteSalesforceLoader(AirbyteCDKLoader): + """Loads records from Salesforce using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -157,6 +232,16 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. + record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_salesforce", pip_name="airbyte-source-salesforce" ).SourceSalesforce @@ -170,6 +255,8 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader): class AirbyteGongLoader(AirbyteCDKLoader): + """Loads records from Gong using an Airbyte source connector.""" + def __init__( self, config: Mapping[str, Any], @@ -177,6 +264,16 @@ class AirbyteGongLoader(AirbyteCDKLoader): record_handler: Optional[RecordHandler] = None, state: Optional[Any] = None, ) -> None: + """Initializes the loader. + + Args: + config: The config to pass to the source connector. + stream_name: The name of the stream to load. 
+ record_handler: A function that takes in a record and an optional id and + returns a Document. If None, the record will be used as the document. + Defaults to None. + state: The state to pass to the source connector. Defaults to None. + """ source_class = guard_import( "source_gong", pip_name="airbyte-source-gong" ).SourceGong diff --git a/libs/langchain/langchain/document_loaders/parsers/audio.py b/libs/langchain/langchain/document_loaders/parsers/audio.py index fe394570a0b..91c6870f7e2 100644 --- a/libs/langchain/langchain/document_loaders/parsers/audio.py +++ b/libs/langchain/langchain/document_loaders/parsers/audio.py @@ -79,8 +79,10 @@ class OpenAIWhisperParser(BaseBlobParser): class OpenAIWhisperParserLocal(BaseBlobParser): - """Transcribe and parse audio files. - Audio transcription with OpenAI Whisper model locally from transformers + """Transcribe and parse audio files with OpenAI Whisper model. + + Audio transcription with OpenAI Whisper model locally from transformers. + Parameters: device - device to use NOTE: By default uses the gpu if available, @@ -105,6 +107,15 @@ class OpenAIWhisperParserLocal(BaseBlobParser): lang_model: Optional[str] = None, forced_decoder_ids: Optional[Tuple[Dict]] = None, ): + """Initialize the parser. + + Args: + device: device to use. + lang_model: whisper model to use, for example "openai/whisper-medium". + Defaults to None. + forced_decoder_ids: id states for decoder in a multilanguage model. + Defaults to None. + """ try: from transformers import pipeline except ImportError: diff --git a/libs/langchain/langchain/document_transformers/nuclia_text_transform.py b/libs/langchain/langchain/document_transformers/nuclia_text_transform.py index 45454215746..387f33b81d5 100644 --- a/libs/langchain/langchain/document_transformers/nuclia_text_transform.py +++ b/libs/langchain/langchain/document_transformers/nuclia_text_transform.py @@ -11,7 +11,7 @@ class NucliaTextTransformer(BaseDocumentTransformer): """ The Nuclia Understanding API splits into paragraphs and sentences, identifies entities, provides a summary of the text and generates - embeddings for all the sentences. + embeddings for all sentences. """ def __init__(self, nua: NucliaUnderstandingAPI): diff --git a/libs/langchain/langchain/embeddings/awa.py b/libs/langchain/langchain/embeddings/awa.py index e2def631d87..d854eb2f63a 100644 --- a/libs/langchain/langchain/embeddings/awa.py +++ b/libs/langchain/langchain/embeddings/awa.py @@ -6,6 +6,14 @@ from langchain.embeddings.base import Embeddings class AwaEmbeddings(BaseModel, Embeddings): + """Embedding documents and queries with Awa DB. + + Attributes: + client: The AwaEmbedding client. + model: The name of the model used for embedding. + Default is "all-mpnet-base-v2". 
+ """ + client: Any #: :meta private: model: str = "all-mpnet-base-v2" diff --git a/libs/langchain/langchain/embeddings/embaas.py b/libs/langchain/langchain/embeddings/embaas.py index 945861dc832..d6985fc68c8 100644 --- a/libs/langchain/langchain/embeddings/embaas.py +++ b/libs/langchain/langchain/embeddings/embaas.py @@ -13,7 +13,7 @@ EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/" class EmbaasEmbeddingsPayload(TypedDict): - """Payload for the embaas embeddings API.""" + """Payload for the Embaas embeddings API.""" model: str texts: List[str] diff --git a/libs/langchain/langchain/evaluation/loading.py b/libs/langchain/langchain/evaluation/loading.py index 6a43ed297f1..c608c259a34 100644 --- a/libs/langchain/langchain/evaluation/loading.py +++ b/libs/langchain/langchain/evaluation/loading.py @@ -24,7 +24,7 @@ from langchain.schema.language_model import BaseLanguageModel def load_dataset(uri: str) -> List[Dict]: - """Load a dataset from the `LangChainDatasets HuggingFace org `_. + """Load a dataset from the `LangChainDatasets on HuggingFace `_. Args: uri: The uri of the dataset to load. diff --git a/libs/langchain/langchain/llms/aviary.py b/libs/langchain/langchain/llms/aviary.py index 9f9c0937d07..5e2a38cf714 100644 --- a/libs/langchain/langchain/llms/aviary.py +++ b/libs/langchain/langchain/llms/aviary.py @@ -15,6 +15,13 @@ TIMEOUT = 60 @dataclasses.dataclass class AviaryBackend: + """Aviary backend. + + Attributes: + backend_url: The URL for the Aviary backend. + bearer: The bearer token for the Aviary backend. + """ + backend_url: str bearer: str @@ -89,6 +96,14 @@ class Aviary(LLM): AVIARY_URL and AVIARY_TOKEN environment variables must be set. + Attributes: + model: The name of the model to use. Defaults to "amazon/LightGPT". + aviary_url: The URL for the Aviary backend. Defaults to None. + aviary_token: The bearer token for the Aviary backend. Defaults to None. + use_prompt_format: If True, the prompt template for the model will be ignored. + Defaults to True. + version: API version to use for Aviary. Defaults to None. + Example: .. 
code-block:: python diff --git a/libs/langchain/langchain/llms/fake.py b/libs/langchain/langchain/llms/fake.py index 8aa0fbea751..d8a5a7fd037 100644 --- a/libs/langchain/langchain/llms/fake.py +++ b/libs/langchain/langchain/llms/fake.py @@ -56,6 +56,8 @@ class FakeListLLM(LLM): class FakeStreamingListLLM(FakeListLLM): + """Fake streaming list LLM for testing purposes.""" + def stream( self, input: LanguageModelInput, diff --git a/libs/langchain/langchain/llms/vllm.py b/libs/langchain/langchain/llms/vllm.py index 46da858c775..0a6c4dbec04 100644 --- a/libs/langchain/langchain/llms/vllm.py +++ b/libs/langchain/langchain/llms/vllm.py @@ -8,6 +8,8 @@ from langchain.schema.output import Generation, LLMResult class VLLM(BaseLLM): + """VLLM language model.""" + model: str = "" """The name or path of a HuggingFace Transformers model.""" diff --git a/libs/langchain/langchain/schema/messages.py b/libs/langchain/langchain/schema/messages.py index 1722602be36..3d6219f7d07 100644 --- a/libs/langchain/langchain/schema/messages.py +++ b/libs/langchain/langchain/schema/messages.py @@ -88,6 +88,8 @@ class BaseMessage(Serializable): class BaseMessageChunk(BaseMessage): + """A Message chunk, which can be concatenated with other Message chunks.""" + def _merge_kwargs_dict( self, left: Dict[str, Any], right: Dict[str, Any] ) -> Dict[str, Any]: @@ -145,6 +147,8 @@ class HumanMessage(BaseMessage): class HumanMessageChunk(HumanMessage, BaseMessageChunk): + """A Human Message chunk.""" + pass @@ -163,6 +167,8 @@ class AIMessage(BaseMessage): class AIMessageChunk(AIMessage, BaseMessageChunk): + """A Message chunk from an AI.""" + pass @@ -178,6 +184,8 @@ class SystemMessage(BaseMessage): class SystemMessageChunk(SystemMessage, BaseMessageChunk): + """A System Message chunk.""" + pass @@ -194,6 +202,8 @@ class FunctionMessage(BaseMessage): class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): + """A Function Message chunk.""" + pass @@ -210,6 +220,8 @@ class ChatMessage(BaseMessage): class ChatMessageChunk(ChatMessage, BaseMessageChunk): + """A Chat Message chunk.""" + pass diff --git a/libs/langchain/langchain/schema/output.py b/libs/langchain/langchain/schema/output.py index 06d222ce889..10b9923bad6 100644 --- a/libs/langchain/langchain/schema/output.py +++ b/libs/langchain/langchain/schema/output.py @@ -29,6 +29,8 @@ class Generation(Serializable): class GenerationChunk(Generation): + """A Generation chunk, which can be concatenated with other Generation chunks.""" + def __add__(self, other: GenerationChunk) -> GenerationChunk: if isinstance(other, GenerationChunk): generation_info = ( @@ -62,6 +64,13 @@ class ChatGeneration(Generation): class ChatGenerationChunk(ChatGeneration): + """A ChatGeneration chunk, which can be concatenated with other + ChatGeneration chunks. + + Attributes: + message: The message chunk output by the chat model. 
+ """ + message: BaseMessageChunk def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index aeeda0880df..9290c4cedec 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -56,6 +56,8 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): class BaseGenerationOutputParser( BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] ): + """Base class to parse the output of an LLM call.""" + def invoke( self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None ) -> T: diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 84399a2c0b9..7ed4a739e1a 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -48,6 +48,8 @@ async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> li class RunnableConfig(TypedDict, total=False): + """Configuration for a Runnable.""" + tags: List[str] """ Tags for this call and any sub-calls (eg. a Chain calling an LLM). @@ -74,6 +76,9 @@ Other = TypeVar("Other") class Runnable(Generic[Input, Output], ABC): + """A Runnable is a unit of work that can be invoked, batched, streamed, or + transformed.""" + def __or__( self, other: Union[ @@ -1325,6 +1330,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): class RouterInput(TypedDict): + """A Router input. + + Attributes: + key: The key to route on. + input: The input to pass to the selected runnable. + """ + key: str input: Any diff --git a/libs/langchain/langchain/tools/multion/create_session.py b/libs/langchain/langchain/tools/multion/create_session.py index 9ae2332cf33..b0551e49cab 100644 --- a/libs/langchain/langchain/tools/multion/create_session.py +++ b/libs/langchain/langchain/tools/multion/create_session.py @@ -31,6 +31,14 @@ class CreateSessionSchema(BaseModel): class MultionCreateSession(BaseTool): + """Tool that creates a new Multion Browser Window with provided fields. + + Attributes: + name: The name of the tool. Default: "create_multion_session" + description: The description of the tool. + args_schema: The schema for the tool's arguments. + """ + name: str = "create_multion_session" description: str = """Use this tool to create a new Multion Browser Window \ with provided fields.Always the first step to run \ diff --git a/libs/langchain/langchain/tools/multion/update_session.py b/libs/langchain/langchain/tools/multion/update_session.py index 0e724726e93..1f20d70d65d 100644 --- a/libs/langchain/langchain/tools/multion/update_session.py +++ b/libs/langchain/langchain/tools/multion/update_session.py @@ -34,6 +34,14 @@ class UpdateSessionSchema(BaseModel): class MultionUpdateSession(BaseTool): + """Tool that updates an existing Multion Browser Window with provided fields. + + Attributes: + name: The name of the tool. Default: "update_multion_session" + description: The description of the tool. + args_schema: The schema for the tool's arguments. 
Default: UpdateSessionSchema + """ + name: str = "update_multion_session" description: str = """Use this tool to update \ a existing corresponding \ diff --git a/libs/langchain/langchain/tools/nuclia/tool.py b/libs/langchain/langchain/tools/nuclia/tool.py index 663b156a648..054a439dbb0 100644 --- a/libs/langchain/langchain/tools/nuclia/tool.py +++ b/libs/langchain/langchain/tools/nuclia/tool.py @@ -28,6 +28,15 @@ logger = logging.getLogger(__name__) class NUASchema(BaseModel): + """Input for Nuclia Understanding API. + + Attributes: + action: Action to perform. Either `push` or `pull`. + id: ID of the file to push or pull. + path: Path to the file to push (needed only for `push` action). + text: Text content to process (needed only for `push` action). + """ + action: str = Field( ..., description="Action to perform. Either `push` or `pull`.", diff --git a/libs/langchain/langchain/utilities/portkey.py b/libs/langchain/langchain/utilities/portkey.py index 4c07a59fb9b..bf9044c4f15 100644 --- a/libs/langchain/langchain/utilities/portkey.py +++ b/libs/langchain/langchain/utilities/portkey.py @@ -4,6 +4,13 @@ from typing import Dict, Optional class Portkey: + """Portkey configuration. + + Attributes: + base: The base URL for the Portkey API. + Default: "https://api.portkey.ai/v1/proxy" + """ + base = "https://api.portkey.ai/v1/proxy" @staticmethod diff --git a/libs/langchain/langchain/utilities/spark_sql.py b/libs/langchain/langchain/utilities/spark_sql.py index 12cfbd2d8aa..ffecbe511d4 100644 --- a/libs/langchain/langchain/utilities/spark_sql.py +++ b/libs/langchain/langchain/utilities/spark_sql.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: class SparkSQL: + """SparkSQL is a utility class for interacting with Spark SQL.""" + def __init__( self, spark_session: Optional[SparkSession] = None, @@ -16,10 +18,26 @@ class SparkSQL: include_tables: Optional[List[str]] = None, sample_rows_in_table_info: int = 3, ): + """Initialize a SparkSQL object. + + Args: + spark_session: A SparkSession object. + If not provided, one will be created. + catalog: The catalog to use. + If not provided, the default catalog will be used. + schema: The schema to use. + If not provided, the default schema will be used. + ignore_tables: A list of tables to ignore. + If not provided, all tables will be used. + include_tables: A list of tables to include. + If not provided, all tables will be used. + sample_rows_in_table_info: The number of rows to include in the table info. + Defaults to 3. + """ try: from pyspark.sql import SparkSession except ImportError: - raise ValueError( + raise ImportError( "pyspark is not installed. Please install it with `pip install pyspark`" ) diff --git a/libs/langchain/langchain/utils/utils.py b/libs/langchain/langchain/utils/utils.py index 6257ca330e9..77ccbf68914 100644 --- a/libs/langchain/langchain/utils/utils.py +++ b/libs/langchain/langchain/utils/utils.py @@ -141,7 +141,13 @@ def build_extra_kwargs( values: Dict[str, Any], all_required_field_names: Set[str], ) -> Dict[str, Any]: - """""" + """Build extra kwargs from values and extra_kwargs. + + Args: + extra_kwargs: Extra kwargs passed in by user. + values: Values passed in by user. + all_required_field_names: All required field names for the pydantic class. 
+ """ for field_name in list(values): if field_name in extra_kwargs: raise ValueError(f"Found {field_name} supplied twice.") diff --git a/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py b/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py index 10385e56805..f8cf664cc4d 100644 --- a/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py +++ b/libs/langchain/langchain/vectorstores/alibabacloud_opensearch.py @@ -12,19 +12,20 @@ logger = logging.getLogger() class AlibabaCloudOpenSearchSettings: - """Opensearch Client Configuration + """Alibaba Cloud Opensearch Client Configuration. + Attribute: endpoint (str) : The endpoint of opensearch instance, You can find it - from the console of Alibaba Cloud OpenSearch. + from the console of Alibaba Cloud OpenSearch. instance_id (str) : The identify of opensearch instance, You can find - it from the console of Alibaba Cloud OpenSearch. + it from the console of Alibaba Cloud OpenSearch. datasource_name (str): The name of the data source specified when creating it. username (str) : The username specified when purchasing the instance. password (str) : The password specified when purchasing the instance. embedding_index_name (str) : The name of the vector attribute specified - when configuring the instance attributes. + when configuring the instance attributes. field_name_mapping (Dict) : Using field name mapping between opensearch - vector store and opensearch instance configuration table field names: + vector store and opensearch instance configuration table field names: { 'id': 'The id field name map of index document.', 'document': 'The text field name map of index document.', diff --git a/libs/langchain/langchain/vectorstores/hologres.py b/libs/langchain/langchain/vectorstores/hologres.py index 23073708551..092dc24c364 100644 --- a/libs/langchain/langchain/vectorstores/hologres.py +++ b/libs/langchain/langchain/vectorstores/hologres.py @@ -16,7 +16,17 @@ _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_pg_embedding" class HologresWrapper: + """Wrapper around Hologres service.""" + def __init__(self, connection_string: str, ndims: int, table_name: str) -> None: + """Initialize the wrapper. + + Args: + connection_string: Hologres connection string. + ndims: Number of dimensions of the embedding output. + table_name: Name of the table to store embeddings and data. 
+ """ + import psycopg2 self.table_name = table_name diff --git a/libs/langchain/langchain/vectorstores/pgembedding.py b/libs/langchain/langchain/vectorstores/pgembedding.py index 02e1936a724..4c820636c62 100644 --- a/libs/langchain/langchain/vectorstores/pgembedding.py +++ b/libs/langchain/langchain/vectorstores/pgembedding.py @@ -87,6 +87,8 @@ class EmbeddingStore(BaseModel): class QueryResult: + """QueryResult is a result from a query.""" + EmbeddingStore: EmbeddingStore distance: float diff --git a/libs/langchain/langchain/vectorstores/scann.py b/libs/langchain/langchain/vectorstores/scann.py index b1eb9a9db15..a1ce4af4808 100644 --- a/libs/langchain/langchain/vectorstores/scann.py +++ b/libs/langchain/langchain/vectorstores/scann.py @@ -18,6 +18,7 @@ from langchain.vectorstores.utils import DistanceStrategy def normalize(x: np.ndarray) -> np.ndarray: + """Normalize vectors to unit length.""" x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None) return x From efa02ed76800dc640b3188544292c474a64591c9 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 9 Aug 2023 18:56:51 -0400 Subject: [PATCH 4/8] Suppress divide by zero wranings for cosine similarity (#9006) Suppress run time warnings for divide by zero as the downstream code handles the scenario (handling inf and nan) --- libs/langchain/langchain/utils/math.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/utils/math.py b/libs/langchain/langchain/utils/math.py index ae3c2919342..77784ba2a49 100644 --- a/libs/langchain/langchain/utils/math.py +++ b/libs/langchain/langchain/utils/math.py @@ -20,7 +20,9 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: X_norm = np.linalg.norm(X, axis=1) Y_norm = np.linalg.norm(Y, axis=1) - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + # Ignore divide by zero errors run time warnings as those are handled below. + with np.errstate(divide="ignore", invalid="ignore"): + similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity From d248481f130b2f6dcf0db80348f58a01ac69b18f Mon Sep 17 00:00:00 2001 From: IanRogers-101Ways <140076427+IanRogers-101Ways@users.noreply.github.com> Date: Thu, 10 Aug 2023 00:05:02 +0100 Subject: [PATCH 5/8] skip over empty google spreadsheets (#8974) - Description: Allow GoogleDriveLoader to handle empty spreadsheets - Issue: Currently GoogleDriveLoader will crash if it tries to load a spreadsheet with an empty sheet - Dependencies: n/a - Tag maintainer: @rlancemartin, @eyurtsev --- libs/langchain/langchain/document_loaders/googledrive.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/langchain/langchain/document_loaders/googledrive.py b/libs/langchain/langchain/document_loaders/googledrive.py index 4538b469d02..9a0290a35ef 100644 --- a/libs/langchain/langchain/document_loaders/googledrive.py +++ b/libs/langchain/langchain/document_loaders/googledrive.py @@ -169,6 +169,8 @@ class GoogleDriveLoader(BaseLoader, BaseModel): .execute() ) values = result.get("values", []) + if not values: + continue # empty sheet header = values[0] for i, row in enumerate(values[1:], start=1): From bbbd2b076f5b41b62134c2f62ae63fa6bbe3b7a7 Mon Sep 17 00:00:00 2001 From: Kaizen <45839100+amovfx@users.noreply.github.com> Date: Wed, 9 Aug 2023 16:05:16 -0700 Subject: [PATCH 6/8] DirectoryLoader slicing (#8994) DirectoryLoader can now return a random sample of files in a directory. 
Parameters added are: sample_size randomize_sample sample_seed @rlancemartin, @eyurtsev --------- Co-authored-by: Andrew Oseen Co-authored-by: Bagatur --- .../langchain/document_loaders/directory.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/libs/langchain/langchain/document_loaders/directory.py b/libs/langchain/langchain/document_loaders/directory.py index 95bc963021c..729c953236e 100644 --- a/libs/langchain/langchain/document_loaders/directory.py +++ b/libs/langchain/langchain/document_loaders/directory.py @@ -1,6 +1,7 @@ """Load documents from a directory.""" import concurrent import logging +import random from pathlib import Path from typing import Any, List, Optional, Type, Union @@ -39,6 +40,10 @@ class DirectoryLoader(BaseLoader): show_progress: bool = False, use_multithreading: bool = False, max_concurrency: int = 4, + *, + sample_size: int = 0, + randomize_sample: bool = False, + sample_seed: Union[int, None] = None, ): """Initialize with a path to directory and how to glob over it. @@ -55,6 +60,10 @@ class DirectoryLoader(BaseLoader): show_progress: Whether to show a progress bar. Defaults to False. use_multithreading: Whether to use multithreading. Defaults to False. max_concurrency: The maximum number of threads to use. Defaults to 4. + sample_size: The maximum number of files you would like to load from the + directory. + randomize_sample: Suffle the files to get a random sample. + sample_seed: set the seed of the random shuffle for reporoducibility. """ if loader_kwargs is None: loader_kwargs = {} @@ -68,6 +77,9 @@ class DirectoryLoader(BaseLoader): self.show_progress = show_progress self.use_multithreading = use_multithreading self.max_concurrency = max_concurrency + self.sample_size = sample_size + self.randomize_sample = randomize_sample + self.sample_seed = sample_seed def load_file( self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any] @@ -107,6 +119,14 @@ class DirectoryLoader(BaseLoader): docs: List[Document] = [] items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob)) + if self.sample_size > 0: + if self.randomize_sample: + randomizer = ( + random.Random(self.sample_seed) if self.sample_seed else random + ) + randomizer.shuffle(items) # type: ignore + items = items[: min(len(items), self.sample_size)] + pbar = None if self.show_progress: try: From 3b51817706c32fe0fe54e6af19243105a9cd61c6 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 9 Aug 2023 17:08:48 -0700 Subject: [PATCH 7/8] Updating port and ssl use in sample notebook (#8995) ## Description This PR updates the sample notebook to use the default port (8182) and the ssl for the Neptune database connection. 
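For reference, a minimal connection sketch with the updated defaults, assuming the `NeptuneGraph` import path used elsewhere in the notebook (the import cell is not part of this diff):

```python
from langchain.graphs import NeptuneGraph  # assumed import path; not shown in this diff

host = "<your-neptune-cluster-endpoint>"  # placeholder endpoint
port = 8182       # Neptune's default port, per this change
use_https = True  # connect to the endpoint over SSL

graph = NeptuneGraph(host=host, port=port, use_https=use_https)
```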
--- docs/extras/use_cases/graph/neptune_cypher_qa.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/extras/use_cases/graph/neptune_cypher_qa.ipynb b/docs/extras/use_cases/graph/neptune_cypher_qa.ipynb index 4b565e09333..68bb4e9e12a 100644 --- a/docs/extras/use_cases/graph/neptune_cypher_qa.ipynb +++ b/docs/extras/use_cases/graph/neptune_cypher_qa.ipynb @@ -18,8 +18,8 @@ "\n", "\n", "host = \"\"\n", - "port = 80\n", - "use_https = False\n", + "port = 8182\n", + "use_https = True\n", "\n", "graph = NeptuneGraph(host=host, port=port, use_https=use_https)" ] From f4a47ec7175925270e09e4161f9375b4305645ac Mon Sep 17 00:00:00 2001 From: colegottdank Date: Wed, 9 Aug 2023 17:34:00 -0700 Subject: [PATCH 8/8] Add optional model kwargs to ChatAnthropic to allow overrides (#9013) --------- Co-authored-by: Bagatur --- libs/langchain/langchain/llms/anthropic.py | 22 ++++++++++++--- libs/langchain/poetry.lock | 4 +-- libs/langchain/pyproject.toml | 1 + .../unit_tests/chat_models/test_anthropic.py | 27 +++++++++++++++++++ 4 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/chat_models/test_anthropic.py diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py index f32f581d1f7..5e5695762f2 100644 --- a/libs/langchain/langchain/llms/anthropic.py +++ b/libs/langchain/langchain/llms/anthropic.py @@ -2,7 +2,7 @@ import re import warnings from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional -from pydantic import root_validator +from pydantic import Field, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -11,7 +11,12 @@ from langchain.callbacks.manager import ( from langchain.llms.base import LLM from langchain.schema.language_model import BaseLanguageModel from langchain.schema.output import GenerationChunk -from langchain.utils import check_package_version, get_from_dict_or_env +from langchain.utils import ( + check_package_version, + get_from_dict_or_env, + get_pydantic_field_names, +) +from langchain.utils.utils import build_extra_kwargs class _AnthropicCommon(BaseLanguageModel): @@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel): HUMAN_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None count_tokens: Optional[Callable[[str], int]] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + @root_validator(pre=True) + def build_extra(cls, values: Dict) -> Dict: + extra = values.get("model_kwargs", {}) + all_required_field_names = get_pydantic_field_names(cls) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel): values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT values["AI_PROMPT"] = anthropic.AI_PROMPT values["count_tokens"] = values["client"].count_tokens + except ImportError: raise ImportError( "Could not import anthropic python package. 
" @@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel): d["top_k"] = self.top_k if self.top_p is not None: d["top_p"] = self.top_p - return d + return {**d, **self.model_kwargs} @property def _identifying_params(self) -> Mapping[str, Any]: diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index dc34e82d2e6..712a4406bff 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -13608,7 +13608,7 @@ clarifai = ["clarifai"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"] +extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"] javascript = ["esprima"] llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"] openai = ["openai", "tiktoken"] @@ -13619,4 +13619,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0" +content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 2d5aef61224..3fb30ee0d6c 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -373,6 +373,7 @@ extended_testing = [ "feedparser", "xata", "xmltodict", + "anthropic", ] scheduled_testing = [ diff --git a/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py new file mode 100644 index 00000000000..7447ec03e41 --- /dev/null +++ b/libs/langchain/tests/unit_tests/chat_models/test_anthropic.py @@ -0,0 +1,27 @@ +"""Test Anthropic Chat API wrapper.""" +import os + +import pytest + +from langchain.chat_models import ChatAnthropic + +os.environ["ANTHROPIC_API_KEY"] = "foo" + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_kwargs() -> None: + llm = ChatAnthropic(model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("anthropic") +def test_anthropic_invalid_model_kwargs() -> None: + with pytest.raises(ValueError): + ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5}) + + +@pytest.mark.requires("anthropic") +def test_anthropic_incorrect_field() -> None: + with pytest.warns(match="not default parameter"): + llm = ChatAnthropic(foo="bar") + assert llm.model_kwargs == {"foo": "bar"}