Merge branch 'master' into bagatur/locals_in_config

This commit is contained in:
Bagatur 2023-08-09 17:56:33 -07:00
commit f8ed93e7bd
37 changed files with 367 additions and 20 deletions

View File

@ -18,8 +18,8 @@
"\n", "\n",
"\n", "\n",
"host = \"<neptune-host>\"\n", "host = \"<neptune-host>\"\n",
"port = 80\n", "port = 8182\n",
"use_https = False\n", "use_https = True\n",
"\n", "\n",
"graph = NeptuneGraph(host=host, port=port, use_https=use_https)" "graph = NeptuneGraph(host=host, port=port, use_https=use_https)"
] ]

View File

@ -55,6 +55,16 @@ def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
@dataclass(frozen=True) @dataclass(frozen=True)
class ReducedOpenAPISpec: class ReducedOpenAPISpec:
"""A reduced OpenAPI spec.
This is a quick and dirty representation for OpenAPI specs.
Attributes:
servers: The servers in the spec.
description: The description of the spec.
endpoints: The endpoints in the spec.
"""
servers: List[dict] servers: List[dict]
description: str description: str
endpoints: List[Tuple[str, str, dict]] endpoints: List[Tuple[str, str, dict]]

View File

@ -10,6 +10,8 @@ from langchain.tools.base import BaseTool
class XMLAgentOutputParser(AgentOutputParser): class XMLAgentOutputParser(AgentOutputParser):
"""Output parser for XMLAgent."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if "</tool>" in text: if "</tool>" in text:
tool, tool_input = text.split("</tool>") tool, tool_input = text.split("</tool>")

View File

@ -49,6 +49,8 @@ class ElementInViewPort(TypedDict):
class Crawler: class Crawler:
"""A crawler for web pages."""
def __init__(self) -> None: def __init__(self) -> None:
try: try:
from playwright.sync_api import sync_playwright from playwright.sync_api import sync_playwright

View File

@ -9,6 +9,7 @@ try:
except ImportError: except ImportError:
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore
"""Dummy decorator for when lark is not installed."""
return lambda _: None return lambda _: None
Transformer = object # type: ignore Transformer = object # type: ignore

View File

@ -51,6 +51,8 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
"""Base class for chat models."""
cache: Optional[bool] = None cache: Optional[bool] = None
"""Whether to cache the response.""" """Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity) verbose: bool = Field(default_factory=_get_verbosity)

View File

@ -16,6 +16,16 @@
""" """
from langchain.document_loaders.acreom import AcreomLoader from langchain.document_loaders.acreom import AcreomLoader
from langchain.document_loaders.airbyte import (
AirbyteCDKLoader,
AirbyteGongLoader,
AirbyteHubspotLoader,
AirbyteSalesforceLoader,
AirbyteShopifyLoader,
AirbyteStripeLoader,
AirbyteTypeformLoader,
AirbyteZendeskSupportLoader,
)
from langchain.document_loaders.airbyte_json import AirbyteJSONLoader from langchain.document_loaders.airbyte_json import AirbyteJSONLoader
from langchain.document_loaders.airtable import AirtableLoader from langchain.document_loaders.airtable import AirtableLoader
from langchain.document_loaders.apify_dataset import ApifyDatasetLoader from langchain.document_loaders.apify_dataset import ApifyDatasetLoader
@ -188,7 +198,15 @@ TelegramChatLoader = TelegramChatFileLoader
__all__ = [ __all__ = [
"AZLyricsLoader", "AZLyricsLoader",
"AcreomLoader", "AcreomLoader",
"AirbyteCDKLoader",
"AirbyteGongLoader",
"AirbyteJSONLoader", "AirbyteJSONLoader",
"AirbyteHubspotLoader",
"AirbyteSalesforceLoader",
"AirbyteShopifyLoader",
"AirbyteStripeLoader",
"AirbyteTypeformLoader",
"AirbyteZendeskSupportLoader",
"AirtableLoader", "AirtableLoader",
"AmazonTextractPDFLoader", "AmazonTextractPDFLoader",
"ApifyDatasetLoader", "ApifyDatasetLoader",

View File

@ -19,6 +19,17 @@ class AirbyteCDKLoader(BaseLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
source_class: The source connector class.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage
from airbyte_cdk.sources.embedded.base_integration import ( from airbyte_cdk.sources.embedded.base_integration import (
BaseEmbeddedIntegration, BaseEmbeddedIntegration,
@ -26,6 +37,8 @@ class AirbyteCDKLoader(BaseLoader):
from airbyte_cdk.sources.embedded.runner import CDKRunner from airbyte_cdk.sources.embedded.runner import CDKRunner
class CDKIntegration(BaseEmbeddedIntegration): class CDKIntegration(BaseEmbeddedIntegration):
"""A wrapper around the CDK integration."""
def _handle_record( def _handle_record(
self, record: AirbyteRecordMessage, id: Optional[str] self, record: AirbyteRecordMessage, id: Optional[str]
) -> Document: ) -> Document:
@ -50,6 +63,8 @@ class AirbyteCDKLoader(BaseLoader):
class AirbyteHubspotLoader(AirbyteCDKLoader): class AirbyteHubspotLoader(AirbyteCDKLoader):
"""Loads records from Hubspot using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -57,6 +72,16 @@ class AirbyteHubspotLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_hubspot", pip_name="airbyte-source-hubspot" "source_hubspot", pip_name="airbyte-source-hubspot"
).SourceHubspot ).SourceHubspot
@ -70,6 +95,8 @@ class AirbyteHubspotLoader(AirbyteCDKLoader):
class AirbyteStripeLoader(AirbyteCDKLoader): class AirbyteStripeLoader(AirbyteCDKLoader):
"""Loads records from Stripe using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -77,6 +104,16 @@ class AirbyteStripeLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_stripe", pip_name="airbyte-source-stripe" "source_stripe", pip_name="airbyte-source-stripe"
).SourceStripe ).SourceStripe
@ -90,6 +127,8 @@ class AirbyteStripeLoader(AirbyteCDKLoader):
class AirbyteTypeformLoader(AirbyteCDKLoader): class AirbyteTypeformLoader(AirbyteCDKLoader):
"""Loads records from Typeform using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -97,6 +136,16 @@ class AirbyteTypeformLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_typeform", pip_name="airbyte-source-typeform" "source_typeform", pip_name="airbyte-source-typeform"
).SourceTypeform ).SourceTypeform
@ -110,6 +159,8 @@ class AirbyteTypeformLoader(AirbyteCDKLoader):
class AirbyteZendeskSupportLoader(AirbyteCDKLoader): class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
"""Loads records from Zendesk Support using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -117,6 +168,16 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_zendesk_support", pip_name="airbyte-source-zendesk-support" "source_zendesk_support", pip_name="airbyte-source-zendesk-support"
).SourceZendeskSupport ).SourceZendeskSupport
@ -130,6 +191,8 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
class AirbyteShopifyLoader(AirbyteCDKLoader): class AirbyteShopifyLoader(AirbyteCDKLoader):
"""Loads records from Shopify using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -137,6 +200,16 @@ class AirbyteShopifyLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_shopify", pip_name="airbyte-source-shopify" "source_shopify", pip_name="airbyte-source-shopify"
).SourceShopify ).SourceShopify
@ -150,6 +223,8 @@ class AirbyteShopifyLoader(AirbyteCDKLoader):
class AirbyteSalesforceLoader(AirbyteCDKLoader): class AirbyteSalesforceLoader(AirbyteCDKLoader):
"""Loads records from Salesforce using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -157,6 +232,16 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_salesforce", pip_name="airbyte-source-salesforce" "source_salesforce", pip_name="airbyte-source-salesforce"
).SourceSalesforce ).SourceSalesforce
@ -170,6 +255,8 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader):
class AirbyteGongLoader(AirbyteCDKLoader): class AirbyteGongLoader(AirbyteCDKLoader):
"""Loads records from Gong using an Airbyte source connector."""
def __init__( def __init__(
self, self,
config: Mapping[str, Any], config: Mapping[str, Any],
@ -177,6 +264,16 @@ class AirbyteGongLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None, record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> None: ) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import( source_class = guard_import(
"source_gong", pip_name="airbyte-source-gong" "source_gong", pip_name="airbyte-source-gong"
).SourceGong ).SourceGong

View File

@ -1,6 +1,7 @@
"""Load documents from a directory.""" """Load documents from a directory."""
import concurrent import concurrent
import logging import logging
import random
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
@ -39,6 +40,10 @@ class DirectoryLoader(BaseLoader):
show_progress: bool = False, show_progress: bool = False,
use_multithreading: bool = False, use_multithreading: bool = False,
max_concurrency: int = 4, max_concurrency: int = 4,
*,
sample_size: int = 0,
randomize_sample: bool = False,
sample_seed: Union[int, None] = None,
): ):
"""Initialize with a path to directory and how to glob over it. """Initialize with a path to directory and how to glob over it.
@ -55,6 +60,10 @@ class DirectoryLoader(BaseLoader):
show_progress: Whether to show a progress bar. Defaults to False. show_progress: Whether to show a progress bar. Defaults to False.
use_multithreading: Whether to use multithreading. Defaults to False. use_multithreading: Whether to use multithreading. Defaults to False.
max_concurrency: The maximum number of threads to use. Defaults to 4. max_concurrency: The maximum number of threads to use. Defaults to 4.
sample_size: The maximum number of files you would like to load from the
directory.
randomize_sample: Shuffle the files to get a random sample.
sample_seed: Set the seed of the random shuffle for reproducibility.
""" """
if loader_kwargs is None: if loader_kwargs is None:
loader_kwargs = {} loader_kwargs = {}
@ -68,6 +77,9 @@ class DirectoryLoader(BaseLoader):
self.show_progress = show_progress self.show_progress = show_progress
self.use_multithreading = use_multithreading self.use_multithreading = use_multithreading
self.max_concurrency = max_concurrency self.max_concurrency = max_concurrency
self.sample_size = sample_size
self.randomize_sample = randomize_sample
self.sample_seed = sample_seed
def load_file( def load_file(
self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any] self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any]
@ -107,6 +119,14 @@ class DirectoryLoader(BaseLoader):
docs: List[Document] = [] docs: List[Document] = []
items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob)) items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob))
if self.sample_size > 0:
if self.randomize_sample:
randomizer = (
random.Random(self.sample_seed) if self.sample_seed else random
)
randomizer.shuffle(items) # type: ignore
items = items[: min(len(items), self.sample_size)]
pbar = None pbar = None
if self.show_progress: if self.show_progress:
try: try:

View File

@ -169,6 +169,8 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
.execute() .execute()
) )
values = result.get("values", []) values = result.get("values", [])
if not values:
continue # empty sheet
header = values[0] header = values[0]
for i, row in enumerate(values[1:], start=1): for i, row in enumerate(values[1:], start=1):

View File

@ -79,8 +79,10 @@ class OpenAIWhisperParser(BaseBlobParser):
class OpenAIWhisperParserLocal(BaseBlobParser): class OpenAIWhisperParserLocal(BaseBlobParser):
"""Transcribe and parse audio files. """Transcribe and parse audio files with OpenAI Whisper model.
Audio transcription with OpenAI Whisper model locally from transformers
Audio transcription with OpenAI Whisper model locally from transformers.
Parameters: Parameters:
device - device to use device - device to use
NOTE: By default uses the gpu if available, NOTE: By default uses the gpu if available,
@ -105,6 +107,15 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
lang_model: Optional[str] = None, lang_model: Optional[str] = None,
forced_decoder_ids: Optional[Tuple[Dict]] = None, forced_decoder_ids: Optional[Tuple[Dict]] = None,
): ):
"""Initialize the parser.
Args:
device: device to use.
lang_model: whisper model to use, for example "openai/whisper-medium".
Defaults to None.
forced_decoder_ids: id states for decoder in a multilanguage model.
Defaults to None.
"""
try: try:
from transformers import pipeline from transformers import pipeline
except ImportError: except ImportError:

View File

@ -11,7 +11,7 @@ class NucliaTextTransformer(BaseDocumentTransformer):
""" """
The Nuclia Understanding API splits into paragraphs and sentences, The Nuclia Understanding API splits into paragraphs and sentences,
identifies entities, provides a summary of the text and generates identifies entities, provides a summary of the text and generates
embeddings for all the sentences. embeddings for all sentences.
""" """
def __init__(self, nua: NucliaUnderstandingAPI): def __init__(self, nua: NucliaUnderstandingAPI):

View File

@ -6,6 +6,14 @@ from langchain.embeddings.base import Embeddings
class AwaEmbeddings(BaseModel, Embeddings): class AwaEmbeddings(BaseModel, Embeddings):
"""Embedding documents and queries with Awa DB.
Attributes:
client: The AwaEmbedding client.
model: The name of the model used for embedding.
Default is "all-mpnet-base-v2".
"""
client: Any #: :meta private: client: Any #: :meta private:
model: str = "all-mpnet-base-v2" model: str = "all-mpnet-base-v2"

View File

@ -13,7 +13,7 @@ EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"
class EmbaasEmbeddingsPayload(TypedDict): class EmbaasEmbeddingsPayload(TypedDict):
"""Payload for the embaas embeddings API.""" """Payload for the Embaas embeddings API."""
model: str model: str
texts: List[str] texts: List[str]

View File

@ -24,7 +24,7 @@ from langchain.schema.language_model import BaseLanguageModel
def load_dataset(uri: str) -> List[Dict]: def load_dataset(uri: str) -> List[Dict]:
"""Load a dataset from the `LangChainDatasets HuggingFace org <https://huggingface.co/LangChainDatasets>`_. """Load a dataset from the `LangChainDatasets on HuggingFace <https://huggingface.co/LangChainDatasets>`_.
Args: Args:
uri: The uri of the dataset to load. uri: The uri of the dataset to load.

View File

@ -2,7 +2,7 @@ import re
import warnings import warnings
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
from pydantic import root_validator from pydantic import Field, root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.utils import check_package_version, get_from_dict_or_env from langchain.utils import (
check_package_version,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
class _AnthropicCommon(BaseLanguageModel): class _AnthropicCommon(BaseLanguageModel):
@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
# Root validator registered with pre=True, i.e. ahead of pydantic's field validation.
# It reads any caller-supplied "model_kwargs" (an empty dict when absent), hands it to
# build_extra_kwargs together with the incoming values and the class's pydantic field
# names, and stores the result back under "model_kwargs" before returning values.
# NOTE(review): build_extra_kwargs presumably moves keys that are not declared fields
# into model_kwargs -- confirm against langchain.utils.utils.
@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens values["count_tokens"] = values["client"].count_tokens
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import anthropic python package. " "Could not import anthropic python package. "
@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
d["top_k"] = self.top_k d["top_k"] = self.top_k
if self.top_p is not None: if self.top_p is not None:
d["top_p"] = self.top_p d["top_p"] = self.top_p
return d return {**d, **self.model_kwargs}
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:

View File

@ -15,6 +15,13 @@ TIMEOUT = 60
@dataclasses.dataclass @dataclasses.dataclass
class AviaryBackend: class AviaryBackend:
"""Aviary backend.
Attributes:
backend_url: The URL for the Aviary backend.
bearer: The bearer token for the Aviary backend.
"""
backend_url: str backend_url: str
bearer: str bearer: str
@ -89,6 +96,14 @@ class Aviary(LLM):
AVIARY_URL and AVIARY_TOKEN environment variables must be set. AVIARY_URL and AVIARY_TOKEN environment variables must be set.
Attributes:
model: The name of the model to use. Defaults to "amazon/LightGPT".
aviary_url: The URL for the Aviary backend. Defaults to None.
aviary_token: The bearer token for the Aviary backend. Defaults to None.
use_prompt_format: If True, the prompt template for the model will be ignored.
Defaults to True.
version: API version to use for Aviary. Defaults to None.
Example: Example:
.. code-block:: python .. code-block:: python

View File

@ -56,6 +56,8 @@ class FakeListLLM(LLM):
class FakeStreamingListLLM(FakeListLLM): class FakeStreamingListLLM(FakeListLLM):
"""Fake streaming list LLM for testing purposes."""
def stream( def stream(
self, self,
input: LanguageModelInput, input: LanguageModelInput,

View File

@ -8,6 +8,8 @@ from langchain.schema.output import Generation, LLMResult
class VLLM(BaseLLM): class VLLM(BaseLLM):
"""VLLM language model."""
model: str = "" model: str = ""
"""The name or path of a HuggingFace Transformers model.""" """The name or path of a HuggingFace Transformers model."""
@ -54,6 +56,9 @@ class VLLM(BaseLLM):
max_new_tokens: int = 512 max_new_tokens: int = 512
"""Maximum number of tokens to generate per output sequence.""" """Maximum number of tokens to generate per output sequence."""
logprobs: Optional[int] = None
"""Number of log probabilities to return per output token."""
client: Any #: :meta private: client: Any #: :meta private:
@root_validator() @root_validator()
@ -91,6 +96,7 @@ class VLLM(BaseLLM):
"stop": self.stop, "stop": self.stop,
"ignore_eos": self.ignore_eos, "ignore_eos": self.ignore_eos,
"use_beam_search": self.use_beam_search, "use_beam_search": self.use_beam_search,
"logprobs": self.logprobs,
} }
def _generate( def _generate(

View File

@ -88,6 +88,8 @@ class BaseMessage(Serializable):
class BaseMessageChunk(BaseMessage): class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict( def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any] self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -145,6 +147,8 @@ class HumanMessage(BaseMessage):
class HumanMessageChunk(HumanMessage, BaseMessageChunk): class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
pass pass
@ -163,6 +167,8 @@ class AIMessage(BaseMessage):
class AIMessageChunk(AIMessage, BaseMessageChunk): class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
pass pass
@ -178,6 +184,8 @@ class SystemMessage(BaseMessage):
class SystemMessageChunk(SystemMessage, BaseMessageChunk): class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
pass pass
@ -194,6 +202,8 @@ class FunctionMessage(BaseMessage):
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
pass pass
@ -210,6 +220,8 @@ class ChatMessage(BaseMessage):
class ChatMessageChunk(ChatMessage, BaseMessageChunk): class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
pass pass

View File

@ -29,6 +29,8 @@ class Generation(Serializable):
class GenerationChunk(Generation): class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk: def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk): if isinstance(other, GenerationChunk):
generation_info = ( generation_info = (
@ -62,6 +64,13 @@ class ChatGeneration(Generation):
class ChatGenerationChunk(ChatGeneration): class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk message: BaseMessageChunk
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:

View File

@ -56,6 +56,8 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseGenerationOutputParser( class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
): ):
"""Base class to parse the output of an LLM call."""
def invoke( def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T: ) -> T:

View File

@ -49,6 +49,8 @@ async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> li
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str] tags: List[str]
""" """
Tags for this call and any sub-calls (eg. a Chain calling an LLM). Tags for this call and any sub-calls (eg. a Chain calling an LLM).
@ -104,6 +106,9 @@ Other = TypeVar("Other")
class Runnable(Generic[Input, Output], ABC): class Runnable(Generic[Input, Output], ABC):
"""A Runnable is a unit of work that can be invoked, batched, streamed, or
transformed."""
def __or__( def __or__(
self, self,
other: Union[ other: Union[
@ -1300,6 +1305,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
class RouterInput(TypedDict): class RouterInput(TypedDict):
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str key: str
input: Any input: Any

View File

@ -31,6 +31,14 @@ class CreateSessionSchema(BaseModel):
class MultionCreateSession(BaseTool): class MultionCreateSession(BaseTool):
"""Tool that creates a new Multion Browser Window with provided fields.
Attributes:
name: The name of the tool. Default: "create_multion_session"
description: The description of the tool.
args_schema: The schema for the tool's arguments.
"""
name: str = "create_multion_session" name: str = "create_multion_session"
description: str = """Use this tool to create a new Multion Browser Window \ description: str = """Use this tool to create a new Multion Browser Window \
with provided fields.Always the first step to run \ with provided fields.Always the first step to run \

View File

@ -34,6 +34,14 @@ class UpdateSessionSchema(BaseModel):
class MultionUpdateSession(BaseTool): class MultionUpdateSession(BaseTool):
"""Tool that updates an existing Multion Browser Window with provided fields.
Attributes:
name: The name of the tool. Default: "update_multion_session"
description: The description of the tool.
args_schema: The schema for the tool's arguments. Default: UpdateSessionSchema
"""
name: str = "update_multion_session" name: str = "update_multion_session"
description: str = """Use this tool to update \ description: str = """Use this tool to update \
a existing corresponding \ a existing corresponding \

View File

@ -28,6 +28,15 @@ logger = logging.getLogger(__name__)
class NUASchema(BaseModel): class NUASchema(BaseModel):
"""Input for Nuclia Understanding API.
Attributes:
action: Action to perform. Either `push` or `pull`.
id: ID of the file to push or pull.
path: Path to the file to push (needed only for `push` action).
text: Text content to process (needed only for `push` action).
"""
action: str = Field( action: str = Field(
..., ...,
description="Action to perform. Either `push` or `pull`.", description="Action to perform. Either `push` or `pull`.",

View File

@ -4,6 +4,13 @@ from typing import Dict, Optional
class Portkey: class Portkey:
"""Portkey configuration.
Attributes:
base: The base URL for the Portkey API.
Default: "https://api.portkey.ai/v1/proxy"
"""
base = "https://api.portkey.ai/v1/proxy" base = "https://api.portkey.ai/v1/proxy"
@staticmethod @staticmethod

View File

@ -7,6 +7,8 @@ if TYPE_CHECKING:
class SparkSQL: class SparkSQL:
"""SparkSQL is a utility class for interacting with Spark SQL."""
def __init__( def __init__(
self, self,
spark_session: Optional[SparkSession] = None, spark_session: Optional[SparkSession] = None,
@ -16,10 +18,26 @@ class SparkSQL:
include_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3, sample_rows_in_table_info: int = 3,
): ):
"""Initialize a SparkSQL object.
Args:
spark_session: A SparkSession object.
If not provided, one will be created.
catalog: The catalog to use.
If not provided, the default catalog will be used.
schema: The schema to use.
If not provided, the default schema will be used.
ignore_tables: A list of tables to ignore.
If not provided, all tables will be used.
include_tables: A list of tables to include.
If not provided, all tables will be used.
sample_rows_in_table_info: The number of rows to include in the table info.
Defaults to 3.
"""
try: try:
from pyspark.sql import SparkSession from pyspark.sql import SparkSession
except ImportError: except ImportError:
raise ValueError( raise ImportError(
"pyspark is not installed. Please install it with `pip install pyspark`" "pyspark is not installed. Please install it with `pip install pyspark`"
) )

View File

@ -20,6 +20,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
X_norm = np.linalg.norm(X, axis=1) X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1) Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide-by-zero and invalid-value runtime warnings; the resulting NaN/inf entries are zeroed out below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity return similarity

View File

@ -141,7 +141,13 @@ def build_extra_kwargs(
values: Dict[str, Any], values: Dict[str, Any],
all_required_field_names: Set[str], all_required_field_names: Set[str],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""""" """Build extra kwargs from values and extra_kwargs.
Args:
extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class.
"""
for field_name in list(values): for field_name in list(values):
if field_name in extra_kwargs: if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.") raise ValueError(f"Found {field_name} supplied twice.")

View File

@ -12,7 +12,8 @@ logger = logging.getLogger()
class AlibabaCloudOpenSearchSettings: class AlibabaCloudOpenSearchSettings:
"""Opensearch Client Configuration """Alibaba Cloud Opensearch Client Configuration.
Attribute: Attribute:
endpoint (str) : The endpoint of opensearch instance, You can find it endpoint (str) : The endpoint of opensearch instance, You can find it
from the console of Alibaba Cloud OpenSearch. from the console of Alibaba Cloud OpenSearch.

View File

@ -16,7 +16,17 @@ _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_pg_embedding"
class HologresWrapper: class HologresWrapper:
"""Wrapper around Hologres service."""
def __init__(self, connection_string: str, ndims: int, table_name: str) -> None: def __init__(self, connection_string: str, ndims: int, table_name: str) -> None:
"""Initialize the wrapper.
Args:
connection_string: Hologres connection string.
ndims: Number of dimensions of the embedding output.
table_name: Name of the table to store embeddings and data.
"""
import psycopg2 import psycopg2
self.table_name = table_name self.table_name = table_name

View File

@ -87,6 +87,8 @@ class EmbeddingStore(BaseModel):
class QueryResult: class QueryResult:
"""QueryResult is a result from a query."""
EmbeddingStore: EmbeddingStore EmbeddingStore: EmbeddingStore
distance: float distance: float

View File

@ -18,6 +18,7 @@ from langchain.vectorstores.utils import DistanceStrategy
def normalize(x: np.ndarray) -> np.ndarray:
    """Normalize vectors to unit length along the last axis.

    Floating-point (and complex) arrays are scaled in place and the same
    array object is returned. Inputs of any other dtype (e.g. integers)
    cannot hold the result in place, so they are converted to a new
    ``float64`` array and the caller's data is left untouched.

    Args:
        x: Array whose last axis holds the vector components.

    Returns:
        The array with every vector scaled to unit L2 norm. Vectors whose
        norm is below ``1e-12`` are divided by ``1e-12`` instead, so
        all-zero vectors stay zero rather than producing NaN.
    """
    x = np.asanyarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        # In-place true division into an integer/bool array raises a casting
        # error, so work on a float copy instead.
        x = x.astype(np.float64)
    # Clip the norm away from zero to avoid division by zero.
    x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
    return x

View File

@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"] cohere = ["cohere"]
docarray = ["docarray"] docarray = ["docarray"]
embeddings = ["sentence-transformers"] embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"] extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
javascript = ["esprima"] javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"] llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
openai = ["openai", "tiktoken"] openai = ["openai", "tiktoken"]
@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0" content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"

View File

@ -373,6 +373,7 @@ extended_testing = [
"feedparser", "feedparser",
"xata", "xata",
"xmltodict", "xmltodict",
"anthropic",
] ]
scheduled_testing = [ scheduled_testing = [

View File

@ -0,0 +1,27 @@
"""Test Anthropic Chat API wrapper."""
import os
import pytest
from langchain.chat_models import ChatAnthropic
# Dummy API key set at import time so the tests below can construct ChatAnthropic
# without real credentials (presumably the client requires a key to be present —
# TODO confirm). Note this mutates the process environment for the whole test run.
os.environ["ANTHROPIC_API_KEY"] = "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
    """Check that ``model_kwargs`` given at construction are kept as-is."""
    extra_kwargs = {"foo": "bar"}
    chat_model = ChatAnthropic(model_kwargs=extra_kwargs)
    assert chat_model.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
    """Check that ``max_tokens_to_sample`` inside ``model_kwargs`` is rejected."""
    rejected_kwargs = {"max_tokens_to_sample": 5}
    with pytest.raises(ValueError):
        ChatAnthropic(model_kwargs=rejected_kwargs)
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
    """Check that an unknown keyword warns and is routed into ``model_kwargs``."""
    expected_kwargs = {"foo": "bar"}
    # The constructor must emit a warning whose message matches this pattern.
    with pytest.warns(match="not default parameter"):
        chat_model = ChatAnthropic(foo="bar")
    assert chat_model.model_kwargs == expected_kwargs