mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
Refactored input
(#8202)
Refactored `input.py`. The same as https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099 input.py is in the root code folder. This creates the `langchain.input: Input` group on the API Reference navigation ToC, on the same level as Chains and Agents which is incorrect. Refactoring: - copied input.py file into utils/input.py - I added the backwards compatibility ref in the original input.py. - changed several imports to a new ref @hwchase17, @baskaryan
This commit is contained in:
parent
72eb4fa4e8
commit
7cbe28ba9b
@ -25,7 +25,6 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
@ -39,6 +38,7 @@ from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.asyncio import asyncio_timeout
|
||||
from langchain.utils.input import get_color_mapping
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -25,11 +25,11 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.utilities.asyncio import asyncio_timeout
|
||||
from langchain.utils.input import get_color_mapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
@ -2,8 +2,8 @@
|
||||
from typing import Any, Dict, Optional, TextIO, cast
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from langchain.utils.input import print_text
|
||||
|
||||
|
||||
class FileCallbackHandler(BaseCallbackHandler):
|
||||
|
@ -2,8 +2,8 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.utils.input import print_text
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
|
@ -3,7 +3,7 @@ from typing import Any, Callable, List
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.input import get_bolded_text, get_colored_text
|
||||
from langchain.utils.input import get_bolded_text, get_colored_text
|
||||
|
||||
|
||||
def try_json_stringify(obj: Any, fallback: str) -> str:
|
||||
|
@ -14,7 +14,6 @@ from langchain.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
@ -25,6 +24,7 @@ from langchain.schema import (
|
||||
PromptValue,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.utils.input import get_colored_text
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
|
@ -12,13 +12,13 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import APIOperation
|
||||
from langchain.utilities.openapi import OpenAPISpec
|
||||
from langchain.utils.input import get_colored_text
|
||||
|
||||
|
||||
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
|
||||
|
@ -8,7 +8,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.utils.input import get_color_mapping
|
||||
|
||||
|
||||
class SequentialChain(Chain):
|
||||
|
@ -1,42 +1,14 @@
|
||||
"""Handle chained inputs."""
|
||||
from typing import Dict, List, Optional, TextIO
|
||||
"""DEPRECATED: Kept for backwards compatibility."""
|
||||
from langchain.utils.input import (
|
||||
get_bolded_text,
|
||||
get_color_mapping,
|
||||
get_colored_text,
|
||||
print_text,
|
||||
)
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
"red": "31;1",
|
||||
}
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Get mapping for items to a support color."""
|
||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||
if excluded_colors is not None:
|
||||
colors = [c for c in colors if c not in excluded_colors]
|
||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
return color_mapping
|
||||
|
||||
|
||||
def get_colored_text(text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
|
||||
|
||||
def get_bolded_text(text: str) -> str:
|
||||
"""Get bolded text."""
|
||||
return f"\033[1m{text}\033[0m"
|
||||
|
||||
|
||||
def print_text(
|
||||
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
|
||||
) -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = get_colored_text(text, color) if color else text
|
||||
print(text_to_print, end=end, file=file)
|
||||
if file:
|
||||
file.flush() # ensure all printed content are written to file
|
||||
__all__ = [
|
||||
"get_bolded_text",
|
||||
"get_color_mapping",
|
||||
"get_colored_text",
|
||||
"print_text",
|
||||
]
|
||||
|
@ -5,9 +5,9 @@ from typing import List, Optional, Sequence
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.utils.input import get_color_mapping, print_text
|
||||
|
||||
|
||||
class ModelLaboratory:
|
||||
|
@ -6,6 +6,12 @@ These functions do not depend on any other langchain modules.
|
||||
|
||||
from langchain.utils.env import get_from_dict_or_env, get_from_env
|
||||
from langchain.utils.formatting import StrictFormatter, formatter
|
||||
from langchain.utils.input import (
|
||||
get_bolded_text,
|
||||
get_color_mapping,
|
||||
get_colored_text,
|
||||
print_text,
|
||||
)
|
||||
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
||||
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
|
||||
from langchain.utils.utils import (
|
||||
@ -24,11 +30,15 @@ __all__ = [
|
||||
"cosine_similarity",
|
||||
"cosine_similarity_top_k",
|
||||
"formatter",
|
||||
"get_bolded_text",
|
||||
"get_color_mapping",
|
||||
"get_colored_text",
|
||||
"get_from_dict_or_env",
|
||||
"get_from_env",
|
||||
"get_pydantic_field_names",
|
||||
"guard_import",
|
||||
"mock_now",
|
||||
"print_text",
|
||||
"raise_for_status_with_text",
|
||||
"stringify_dict",
|
||||
"stringify_value",
|
||||
|
42
libs/langchain/langchain/utils/input.py
Normal file
42
libs/langchain/langchain/utils/input.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Handle chained inputs."""
|
||||
from typing import Dict, List, Optional, TextIO
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
"red": "31;1",
|
||||
}
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Get mapping for items to a support color."""
|
||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||
if excluded_colors is not None:
|
||||
colors = [c for c in colors if c not in excluded_colors]
|
||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
return color_mapping
|
||||
|
||||
|
||||
def get_colored_text(text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
|
||||
|
||||
def get_bolded_text(text: str) -> str:
|
||||
"""Get bolded text."""
|
||||
return f"\033[1m{text}\033[0m"
|
||||
|
||||
|
||||
def print_text(
|
||||
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
|
||||
) -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = get_colored_text(text, color) if color else text
|
||||
print(text_to_print, end=end, file=file)
|
||||
if file:
|
||||
file.flush() # ensure all printed content are written to file
|
Loading…
Reference in New Issue
Block a user