mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
add more import checks (#11033)
This commit is contained in:
parent
efb7c459a2
commit
e355606b11
@ -15,7 +15,7 @@ from typing import (
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from langchain.chat_loaders.base import ChatSession
|
||||
from langchain.schema.chat import ChatSession
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
|
@ -1,15 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterator, List, Sequence, TypedDict
|
||||
from typing import Iterator, List
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
class ChatSession(TypedDict):
|
||||
"""Chat Session represents a single
|
||||
conversation, channel, or other group of messages."""
|
||||
|
||||
messages: Sequence[BaseMessage]
|
||||
"""The LangChain chat messages loaded from the source."""
|
||||
from langchain.schema.chat import ChatSession
|
||||
|
||||
|
||||
class BaseChatLoader(ABC):
|
||||
|
@ -3,7 +3,8 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Union
|
||||
|
||||
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
|
||||
from langchain.chat_loaders.base import BaseChatLoader
|
||||
from langchain.schema.chat import ChatSession
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
@ -2,7 +2,8 @@ import base64
|
||||
import re
|
||||
from typing import Any, Iterator
|
||||
|
||||
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
|
||||
from langchain.chat_loaders.base import BaseChatLoader
|
||||
from langchain.schema.chat import ChatSession
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
|
||||
|
@ -3,8 +3,9 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
|
||||
|
||||
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
|
||||
from langchain.chat_loaders.base import BaseChatLoader
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain.schema.chat import ChatSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sqlite3
|
||||
|
@ -5,8 +5,9 @@ import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List, Union
|
||||
|
||||
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
|
||||
from langchain.chat_loaders.base import BaseChatLoader
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
from langchain.schema.chat import ChatSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -6,8 +6,9 @@ import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Union
|
||||
|
||||
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
|
||||
from langchain.chat_loaders.base import BaseChatLoader
|
||||
from langchain.schema import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain.schema.chat import ChatSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
from copy import deepcopy
|
||||
from typing import Iterable, Iterator, List
|
||||
|
||||
from langchain.chat_loaders.base import ChatSession
|
||||
from langchain.schema.chat import ChatSession
|
||||
from langchain.schema.messages import AIMessage, BaseMessage
|
||||
|
||||
|
||||
|
@ -4,8 +4,9 @@ import re
|
||||
import zipfile
|
||||
from typing import Iterator, List, Union
|
||||
|
||||
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
|
||||
from langchain.chat_loaders.base import BaseChatLoader
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
from langchain.schema.chat import ChatSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.requests import Requests
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utilities.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain.requests import Requests
|
||||
from langchain.utilities.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -5,10 +5,10 @@ import warnings
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
from langchain.formatting import formatter
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
from langchain.utils.formatting import formatter
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
|
@ -6,11 +6,10 @@ from typing import Callable, Dict, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
from langchain.utils.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -77,6 +76,8 @@ def _load_output_parser(config: dict) -> dict:
|
||||
_config = config.pop("output_parser")
|
||||
output_parser_type = _config.pop("_type")
|
||||
if output_parser_type == "regex_parser":
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
|
||||
output_parser: BaseLLMOutputParser = RegexParser(**_config)
|
||||
elif output_parser_type == "default":
|
||||
output_parser = StrOutputParser(**_config)
|
||||
|
11
libs/langchain/langchain/schema/chat.py
Normal file
11
libs/langchain/langchain/schema/chat.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Sequence, TypedDict
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
|
||||
class ChatSession(TypedDict):
|
||||
"""Chat Session represents a single
|
||||
conversation, channel, or other group of messages."""
|
||||
|
||||
messages: Sequence[BaseMessage]
|
||||
"""The LangChain chat messages loaded from the source."""
|
@ -32,10 +32,9 @@ if TYPE_CHECKING:
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.callbacks.tracers.log_stream import RunLogPatch
|
||||
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import Field
|
||||
@ -216,6 +215,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
The jsonpatch ops can be applied in order to construct state.
|
||||
"""
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.tracers.log_stream import (
|
||||
LogStreamCallbackHandler,
|
||||
RunLogPatch,
|
||||
)
|
||||
|
||||
# Create a stream handler that will emit Log objects
|
||||
stream = LogStreamCallbackHandler(
|
||||
auto_close=False,
|
||||
|
@ -1,4 +1,15 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
@ -10,14 +21,16 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
|
||||
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
|
||||
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
@ -54,7 +67,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
def _patch_config(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
run_manager: T,
|
||||
run_manager: "T",
|
||||
retry_state: RetryCallState,
|
||||
) -> RunnableConfig:
|
||||
attempt = retry_state.attempt_number
|
||||
@ -64,7 +77,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
def _patch_config_list(
|
||||
self,
|
||||
config: List[RunnableConfig],
|
||||
run_manager: List[T],
|
||||
run_manager: List["T"],
|
||||
retry_state: RetryCallState,
|
||||
) -> List[RunnableConfig]:
|
||||
return [
|
||||
@ -74,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
def _invoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
run_manager: "CallbackManagerForChainRun",
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
for attempt in self._sync_retrying(reraise=True):
|
||||
@ -95,7 +108,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
run_manager: "AsyncCallbackManagerForChainRun",
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
async for attempt in self._async_retrying(reraise=True):
|
||||
@ -116,7 +129,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
def _batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
run_manager: List[CallbackManagerForChainRun],
|
||||
run_manager: List["CallbackManagerForChainRun"],
|
||||
config: List[RunnableConfig],
|
||||
) -> List[Union[Output, Exception]]:
|
||||
results_map: Dict[int, Output] = {}
|
||||
@ -180,7 +193,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
async def _abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
run_manager: List[AsyncCallbackManagerForChainRun],
|
||||
run_manager: List["AsyncCallbackManagerForChainRun"],
|
||||
config: List[RunnableConfig],
|
||||
) -> List[Union[Output, Exception]]:
|
||||
results_map: Dict[int, Output] = {}
|
||||
|
@ -1,10 +1,12 @@
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
from langchain.document_loaders import ApifyDatasetLoader
|
||||
from langchain.document_loaders.base import Document
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.document import Document
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.document_loaders import ApifyDatasetLoader
|
||||
|
||||
|
||||
class ApifyWrapper(BaseModel):
|
||||
"""Wrapper around Apify.
|
||||
@ -48,7 +50,7 @@ class ApifyWrapper(BaseModel):
|
||||
build: Optional[str] = None,
|
||||
memory_mbytes: Optional[int] = None,
|
||||
timeout_secs: Optional[int] = None,
|
||||
) -> ApifyDatasetLoader:
|
||||
) -> "ApifyDatasetLoader":
|
||||
"""Run an Actor on the Apify platform and wait for results to be ready.
|
||||
Args:
|
||||
actor_id (str): The ID or name of the Actor on the Apify platform.
|
||||
@ -65,6 +67,8 @@ class ApifyWrapper(BaseModel):
|
||||
ApifyDatasetLoader: A loader that will fetch the records from the
|
||||
Actor run's default dataset.
|
||||
"""
|
||||
from langchain.document_loaders import ApifyDatasetLoader
|
||||
|
||||
actor_call = self.apify_client.actor(actor_id).call(
|
||||
run_input=run_input,
|
||||
build=build,
|
||||
@ -86,7 +90,7 @@ class ApifyWrapper(BaseModel):
|
||||
build: Optional[str] = None,
|
||||
memory_mbytes: Optional[int] = None,
|
||||
timeout_secs: Optional[int] = None,
|
||||
) -> ApifyDatasetLoader:
|
||||
) -> "ApifyDatasetLoader":
|
||||
"""Run an Actor on the Apify platform and wait for results to be ready.
|
||||
Args:
|
||||
actor_id (str): The ID or name of the Actor on the Apify platform.
|
||||
@ -103,6 +107,8 @@ class ApifyWrapper(BaseModel):
|
||||
ApifyDatasetLoader: A loader that will fetch the records from the
|
||||
Actor run's default dataset.
|
||||
"""
|
||||
from langchain.document_loaders import ApifyDatasetLoader
|
||||
|
||||
actor_call = await self.apify_client_async.actor(actor_id).call(
|
||||
run_input=run_input,
|
||||
build=build,
|
||||
@ -124,7 +130,7 @@ class ApifyWrapper(BaseModel):
|
||||
build: Optional[str] = None,
|
||||
memory_mbytes: Optional[int] = None,
|
||||
timeout_secs: Optional[int] = None,
|
||||
) -> ApifyDatasetLoader:
|
||||
) -> "ApifyDatasetLoader":
|
||||
"""Run a saved Actor task on Apify and wait for results to be ready.
|
||||
Args:
|
||||
task_id (str): The ID or name of the task on the Apify platform.
|
||||
@ -142,6 +148,8 @@ class ApifyWrapper(BaseModel):
|
||||
ApifyDatasetLoader: A loader that will fetch the records from the
|
||||
task run's default dataset.
|
||||
"""
|
||||
from langchain.document_loaders import ApifyDatasetLoader
|
||||
|
||||
task_call = self.apify_client.task(task_id).call(
|
||||
task_input=task_input,
|
||||
build=build,
|
||||
@ -163,7 +171,7 @@ class ApifyWrapper(BaseModel):
|
||||
build: Optional[str] = None,
|
||||
memory_mbytes: Optional[int] = None,
|
||||
timeout_secs: Optional[int] = None,
|
||||
) -> ApifyDatasetLoader:
|
||||
) -> "ApifyDatasetLoader":
|
||||
"""Run a saved Actor task on Apify and wait for results to be ready.
|
||||
Args:
|
||||
task_id (str): The ID or name of the task on the Apify platform.
|
||||
@ -181,6 +189,8 @@ class ApifyWrapper(BaseModel):
|
||||
ApifyDatasetLoader: A loader that will fetch the records from the
|
||||
task run's default dataset.
|
||||
"""
|
||||
from langchain.document_loaders import ApifyDatasetLoader
|
||||
|
||||
task_call = await self.apify_client_async.task(task_id).call(
|
||||
task_input=task_input,
|
||||
build=build,
|
||||
|
@ -1,54 +1,4 @@
|
||||
"""Utilities for loading configurations from langchain-hub."""
|
||||
from langchain.utils.loading import try_load_from_hub
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, Callable, Optional, Set, TypeVar, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
|
||||
URL_BASE = os.environ.get(
|
||||
"LANGCHAIN_HUB_URL_BASE",
|
||||
"https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
|
||||
)
|
||||
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def try_load_from_hub(
|
||||
path: Union[str, Path],
|
||||
loader: Callable[[str], T],
|
||||
valid_prefix: str,
|
||||
valid_suffixes: Set[str],
|
||||
**kwargs: Any,
|
||||
) -> Optional[T]:
|
||||
"""Load configuration from hub. Returns None if path is not a hub path."""
|
||||
if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
|
||||
return None
|
||||
ref, remote_path_str = match.groups()
|
||||
ref = ref[1:] if ref else DEFAULT_REF
|
||||
remote_path = Path(remote_path_str)
|
||||
if remote_path.parts[0] != valid_prefix:
|
||||
return None
|
||||
if remote_path.suffix[1:] not in valid_suffixes:
|
||||
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
|
||||
|
||||
# Using Path with URLs is not recommended, because on Windows
|
||||
# the backslash is used as the path separator, which can cause issues
|
||||
# when working with URLs that use forward slashes as the path separator.
|
||||
# Instead, use PurePosixPath to ensure that forward slashes are used as the
|
||||
# path separator, regardless of the operating system.
|
||||
full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
|
||||
|
||||
r = requests.get(full_url, timeout=5)
|
||||
if r.status_code != 200:
|
||||
raise ValueError(f"Could not find file at {full_url}")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file = Path(tmpdirname) / remote_path.name
|
||||
with open(file, "wb") as f:
|
||||
f.write(r.content)
|
||||
return loader(str(file), **kwargs)
|
||||
# For backwards compatibility
|
||||
__all__ = ["try_load_from_hub"]
|
||||
|
54
libs/langchain/langchain/utils/loading.py
Normal file
54
libs/langchain/langchain/utils/loading.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""Utilities for loading configurations from langchain-hub."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, Callable, Optional, Set, TypeVar, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
|
||||
URL_BASE = os.environ.get(
|
||||
"LANGCHAIN_HUB_URL_BASE",
|
||||
"https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
|
||||
)
|
||||
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def try_load_from_hub(
|
||||
path: Union[str, Path],
|
||||
loader: Callable[[str], T],
|
||||
valid_prefix: str,
|
||||
valid_suffixes: Set[str],
|
||||
**kwargs: Any,
|
||||
) -> Optional[T]:
|
||||
"""Load configuration from hub. Returns None if path is not a hub path."""
|
||||
if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
|
||||
return None
|
||||
ref, remote_path_str = match.groups()
|
||||
ref = ref[1:] if ref else DEFAULT_REF
|
||||
remote_path = Path(remote_path_str)
|
||||
if remote_path.parts[0] != valid_prefix:
|
||||
return None
|
||||
if remote_path.suffix[1:] not in valid_suffixes:
|
||||
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
|
||||
|
||||
# Using Path with URLs is not recommended, because on Windows
|
||||
# the backslash is used as the path separator, which can cause issues
|
||||
# when working with URLs that use forward slashes as the path separator.
|
||||
# Instead, use PurePosixPath to ensure that forward slashes are used as the
|
||||
# path separator, regardless of the operating system.
|
||||
full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
|
||||
|
||||
r = requests.get(full_url, timeout=5)
|
||||
if r.status_code != 200:
|
||||
raise ValueError(f"Could not find file at {full_url}")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file = Path(tmpdirname) / remote_path.name
|
||||
with open(file, "wb") as f:
|
||||
f.write(r.content)
|
||||
return loader(str(file), **kwargs)
|
@ -6,8 +6,7 @@ import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings import TensorflowHubEmbeddings
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
|
||||
@ -16,6 +15,8 @@ if TYPE_CHECKING:
|
||||
from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
|
||||
from google.oauth2.service_account import Credentials
|
||||
|
||||
from langchain.embeddings import TensorflowHubEmbeddings
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@ -443,10 +444,13 @@ class MatchingEngine(VectorStore):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_default_embeddings(cls) -> TensorflowHubEmbeddings:
|
||||
def _get_default_embeddings(cls) -> "TensorflowHubEmbeddings":
|
||||
"""This function returns the default embedding.
|
||||
|
||||
Returns:
|
||||
Default TensorflowHubEmbeddings to use.
|
||||
"""
|
||||
|
||||
from langchain.embeddings import TensorflowHubEmbeddings
|
||||
|
||||
return TensorflowHubEmbeddings()
|
||||
|
@ -18,8 +18,8 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
|
@ -2,13 +2,31 @@
|
||||
|
||||
set -eu
|
||||
|
||||
git grep 'from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && exit 1 || exit 0
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# Pydantic bridge should not import from any other module
|
||||
git grep 'from langchain ' langchain/pydantic_v1 && exit 1 || exit 0
|
||||
# Check the conditions
|
||||
git grep '^from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && errors=$((errors+1))
|
||||
git grep '^from langchain ' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1|load)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/adapters | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/callbacks | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env)' && errors=$((errors+1))
|
||||
# TODO: it's probably not amazing so that so many other modules depend on `langchain.utilities`, because there can be a lot of imports there
|
||||
git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|utilities)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1))
|
||||
git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1))
|
||||
|
||||
# load should not import from anything except itself and pydantic_v1
|
||||
git grep 'from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1)' && exit 1 || exit 0
|
||||
|
||||
# utils should not import from anything except itself and pydantic_v1
|
||||
git grep 'from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && exit 1 || exit 0
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
||||
|
@ -10,7 +10,7 @@ from urllib.parse import urljoin
|
||||
import pytest
|
||||
import responses
|
||||
|
||||
from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
|
||||
from langchain.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
Loading…
Reference in New Issue
Block a user