Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-19 19:11:33 +00:00
Refactored math_utils (#7961)
`math_utils.py` is in the root code folder. This creates the `langchain.math_utils: Math Utils` group on the API Reference navigation ToC, at the same level as `Chains` and `Agents`, which is not correct.

Refactoring:
- created the `utils/` folder
- moved `math_utils.py` to `utils/math.py`
- moved `utils.py` to `utils/utils.py`
- split `utils.py` into `utils.py`, `env.py`, and `strings.py`
- added module descriptions

@baskaryan
parent 5137f40dd6
commit 995220b797
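For downstream code, the visible effect of this refactor is an import-path change: helpers that used to live in the flat `langchain.math_utils` module now live under the `langchain.utils` package. A minimal before/after sketch (illustrative only, not part of the diff; it assumes `cosine_similarity` keeps accepting list-of-lists input, as it did in `math_utils`):

    # Before this commit:
    # from langchain.math_utils import cosine_similarity

    # After this commit the helper lives in the utils package
    # (and is also re-exported from langchain.utils, per __init__.py below):
    from langchain.utils.math import cosine_similarity

    a = [[1.0, 0.0], [0.0, 1.0]]
    b = [[1.0, 1.0]]
    print(cosine_similarity(a, b))  # 2x1 matrix of pairwise cosine similarities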
@@ -5,8 +5,8 @@ import numpy as np
 from pydantic import BaseModel, Field
 
 from langchain.embeddings.base import Embeddings
-from langchain.math_utils import cosine_similarity
 from langchain.schema import BaseDocumentTransformer, Document
+from langchain.utils.math import cosine_similarity
 
 
 class _DocumentWithState(Document):
@@ -14,8 +14,8 @@ from langchain.chains.base import Chain
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
-from langchain.math_utils import cosine_similarity
 from langchain.schema import RUN_KEY
+from langchain.utils.math import cosine_similarity
 
 
 class EmbeddingDistance(str, Enum):
@@ -9,11 +9,11 @@ from langchain.document_transformers.embeddings_redundant_filter import (
     get_stateful_documents,
 )
 from langchain.embeddings.base import Embeddings
-from langchain.math_utils import cosine_similarity
 from langchain.retrievers.document_compressors.base import (
     BaseDocumentCompressor,
 )
 from langchain.schema import Document
+from langchain.utils.math import cosine_similarity
 
 
 class EmbeddingsFilter(BaseDocumentCompressor):
@@ -10,7 +10,7 @@ from sqlalchemy.engine import Engine
 from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
 from sqlalchemy.schema import CreateTable
 
-from langchain import utils
+from langchain.utils import get_from_env
 
 
 def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
@@ -192,13 +192,11 @@ class SQLDatabase:
 
         default_host = context.browserHostName if context else None
         if host is None:
-            host = utils.get_from_env("host", "DATABRICKS_HOST", default_host)
+            host = get_from_env("host", "DATABRICKS_HOST", default_host)
 
         default_api_token = context.apiToken if context else None
         if api_token is None:
-            api_token = utils.get_from_env(
-                "api_token", "DATABRICKS_TOKEN", default_api_token
-            )
+            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)
 
         if warehouse_id is None and cluster_id is None:
             if context:
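Behaviour of the Databricks setup above is unchanged; it just calls `get_from_env` directly instead of going through the `langchain.utils` module object. A hedged sketch of the fallback it relies on, with a made-up host value (not from the commit):

    import os
    from langchain.utils import get_from_env

    # Hypothetical value; in practice this comes from your workspace or runtime context.
    os.environ["DATABRICKS_HOST"] = "adb-1234567890123456.7.azuredatabricks.net"

    # No explicit host was passed, so the env var wins over the default (None here).
    host = get_from_env("host", "DATABRICKS_HOST", None)
    print(host)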
langchain/utils/__init__.py (new file)
@@ -0,0 +1,33 @@
+"""
+Utility functions for langchain.
+
+These functions do not depend on any other langchain modules.
+"""
+
+from langchain.utils.env import get_from_dict_or_env, get_from_env
+from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
+from langchain.utils.strings import comma_list, stringify_dict, stringify_value
+from langchain.utils.utils import (
+    check_package_version,
+    get_pydantic_field_names,
+    guard_import,
+    mock_now,
+    raise_for_status_with_text,
+    xor_args,
+)
+
+__all__ = [
+    "check_package_version",
+    "comma_list",
+    "cosine_similarity",
+    "cosine_similarity_top_k",
+    "get_from_dict_or_env",
+    "get_from_env",
+    "get_pydantic_field_names",
+    "guard_import",
+    "mock_now",
+    "raise_for_status_with_text",
+    "stringify_dict",
+    "stringify_value",
+    "xor_args",
+]
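Because the package `__init__` re-exports everything listed in `__all__`, existing imports such as `from langchain.utils import get_from_env` keep working after the split; the names simply resolve to the new submodules. A tiny illustrative check (not part of the commit):

    import langchain.utils as lc_utils

    # Package-level names now come from the split submodules.
    print(lc_utils.get_from_env.__module__)       # langchain.utils.env
    print(lc_utils.cosine_similarity.__module__)  # langchain.utils.math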
langchain/utils/env.py (new file)
@@ -0,0 +1,26 @@
+import os
+from typing import Any, Dict, Optional
+
+
+def get_from_dict_or_env(
+    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
+) -> str:
+    """Get a value from a dictionary or an environment variable."""
+    if key in data and data[key]:
+        return data[key]
+    else:
+        return get_from_env(key, env_key, default=default)
+
+
+def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
+    """Get a value from a dictionary or an environment variable."""
+    if env_key in os.environ and os.environ[env_key]:
+        return os.environ[env_key]
+    elif default is not None:
+        return default
+    else:
+        raise ValueError(
+            f"Did not find {key}, please add an environment variable"
+            f" `{env_key}` which contains it, or pass"
+            f" `{key}` as a named parameter."
+        )
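`get_from_dict_or_env` prefers a non-empty value in the passed dict and only then falls back to the environment, while `get_from_env` raises a descriptive `ValueError` when neither an env var nor a default is available. A quick sketch of both paths (illustrative; the key and variable names are made up):

    import os
    from langchain.utils.env import get_from_dict_or_env, get_from_env

    os.environ["MY_SERVICE_KEY"] = "from-env"  # hypothetical env var

    print(get_from_dict_or_env({"api_key": "from-dict"}, "api_key", "MY_SERVICE_KEY"))  # from-dict
    print(get_from_dict_or_env({}, "api_key", "MY_SERVICE_KEY"))                        # from-env

    try:
        get_from_env("api_key", "MISSING_VAR")
    except ValueError as err:
        print(err)  # tells you which env var to set or parameter to pass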
langchain/utils/strings.py (new file)
@@ -0,0 +1,39 @@
+from typing import Any, List
+
+
+def stringify_value(val: Any) -> str:
+    """Stringify a value.
+
+    Args:
+        val: The value to stringify.
+
+    Returns:
+        str: The stringified value.
+    """
+    if isinstance(val, str):
+        return val
+    elif isinstance(val, dict):
+        return "\n" + stringify_dict(val)
+    elif isinstance(val, list):
+        return "\n".join(stringify_value(v) for v in val)
+    else:
+        return str(val)
+
+
+def stringify_dict(data: dict) -> str:
+    """Stringify a dictionary.
+
+    Args:
+        data: The dictionary to stringify.
+
+    Returns:
+        str: The stringified dictionary.
+    """
+    text = ""
+    for key, value in data.items():
+        text += key + ": " + stringify_value(value) + "\n"
+    return text
+
+
+def comma_list(items: List[Any]) -> str:
+    return ", ".join(str(item) for item in items)
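The string helpers are plain formatting utilities: `stringify_value` recurses through dicts and lists, `stringify_dict` renders one `key: value` line per entry, and `comma_list` joins items with `, `. An illustrative run (not part of the diff; the dictionary contents are made up):

    from langchain.utils.strings import comma_list, stringify_dict, stringify_value

    print(comma_list(["a", "b", 3]))   # a, b, 3
    print(stringify_value(["x", 42]))  # "x\n42" -- list items are newline-joined
    print(stringify_dict({"model": "gpt-x", "temperature": 0}), end="")
    # model: gpt-x
    # temperature: 0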
@@ -2,38 +2,13 @@
 import contextlib
 import datetime
 import importlib
-import os
 from importlib.metadata import version
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import Any, Callable, Optional, Set, Tuple
 
 from packaging.version import parse
 from requests import HTTPError, Response
 
 
-def get_from_dict_or_env(
-    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
-) -> str:
-    """Get a value from a dictionary or an environment variable."""
-    if key in data and data[key]:
-        return data[key]
-    else:
-        return get_from_env(key, env_key, default=default)
-
-
-def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
-    """Get a value from a dictionary or an environment variable."""
-    if env_key in os.environ and os.environ[env_key]:
-        return os.environ[env_key]
-    elif default is not None:
-        return default
-    else:
-        raise ValueError(
-            f"Did not find {key}, please add an environment variable"
-            f" `{env_key}` which contains it, or pass"
-            f" `{key}` as a named parameter."
-        )
-
-
 def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
     """Validate specified keyword args are mutually exclusive."""
 
@@ -67,44 +42,6 @@ def raise_for_status_with_text(response: Response) -> None:
         raise ValueError(response.text) from e
 
 
-def stringify_value(val: Any) -> str:
-    """Stringify a value.
-
-    Args:
-        val: The value to stringify.
-
-    Returns:
-        str: The stringified value.
-    """
-    if isinstance(val, str):
-        return val
-    elif isinstance(val, dict):
-        return "\n" + stringify_dict(val)
-    elif isinstance(val, list):
-        return "\n".join(stringify_value(v) for v in val)
-    else:
-        return str(val)
-
-
-def stringify_dict(data: dict) -> str:
-    """Stringify a dictionary.
-
-    Args:
-        data: The dictionary to stringify.
-
-    Returns:
-        str: The stringified dictionary.
-    """
-    text = ""
-    for key, value in data.items():
-        text += key + ": " + stringify_value(value) + "\n"
-    return text
-
-
-def comma_list(items: List[Any]) -> str:
-    return ", ".join(str(item) for item in items)
-
-
 @contextlib.contextmanager
 def mock_now(dt_value):  # type: ignore
     """Context manager for mocking out datetime.now() in unit tests.
@@ -5,7 +5,7 @@ from typing import List
 
 import numpy as np
 
-from langchain.math_utils import cosine_similarity
+from langchain.utils.math import cosine_similarity
 
 
 class DistanceStrategy(str, Enum):
@@ -2,7 +2,7 @@
 from langchain.document_transformers.embeddings_redundant_filter import (
     _filter_similar_embeddings,
 )
-from langchain.math_utils import cosine_similarity
+from langchain.utils.math import cosine_similarity
 
 
 def test__filter_similar_embeddings() -> None:
@@ -4,7 +4,7 @@ from typing import List
 import numpy as np
 import pytest
 
-from langchain.math_utils import cosine_similarity, cosine_similarity_top_k
+from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
 
 
 @pytest.fixture
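Only the import path changes for these tests; the math helpers move from `math_utils.py` to `utils/math.py`. A hedged usage sketch, assuming the signatures are unchanged by the move (pairwise `cosine_similarity(X, Y)`, and `cosine_similarity_top_k(X, Y, top_k=...)` returning index pairs with scores):

    import numpy as np

    from langchain.utils.math import cosine_similarity, cosine_similarity_top_k

    X = np.array([[1.0, 0.0], [0.0, 1.0]])
    Y = np.array([[1.0, 1.0], [1.0, 0.0]])

    # Full pairwise similarity matrix, shape (len(X), len(Y)).
    print(cosine_similarity(X, Y))

    # Best-scoring (x_index, y_index) pairs and their scores.
    idxs, scores = cosine_similarity_top_k(X, Y, top_k=2)
    print(idxs, scores)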