REFACTOR: Refactor langchain_core (#13627)

Changes:
- remove langchain_core/schema, since there is no clear distinction between schema and non-schema modules
- make every module name that doesn't end in -y plural (e.g. tool -> tools, output -> outputs; chat_history keeps its -y singular)
- where easy, keep to 1-2 classes per file
- allow no more than one level of nesting in directories
- in langchain, import only from top-level langchain_core modules
Author: Bagatur
Date: 2023-11-21 08:35:29 -08:00
Committed by: GitHub
Parent: 17c6551c18
Commit: d32e511826
783 changed files with 2992 additions and 2899 deletions
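In practice, the refactor boils down to one-for-one import moves out of `langchain_core.schema` and the singular modules. A minimal before/after sketch, assembled from the hunks below (the commented-out paths are the ones this commit deletes):

```python
# Before this commit (old layout, removed here):
# from langchain_core.schema.messages import AIMessage, BaseMessage
# from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk
# from langchain_core.tool import BaseTool, tool

# After this commit (flat, plural, top-level modules):
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk
from langchain_core.tools import BaseTool, tool
```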

View File

@@ -6,6 +6,8 @@ EXPECTED_ALL = [
     "suppress_langchain_deprecation_warning",
     "surface_langchain_deprecation_warnings",
     "warn_deprecated",
+    "as_import_path",
+    "get_relative_path",
 ]
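These `EXPECTED_ALL` tests pin a module's public surface so that accidental re-export changes fail CI. A sketch of the pattern; judging by the deprecation helpers, the module under test is `langchain_core._api`, but treat that path as an assumption:

```python
from langchain_core._api import __all__  # assumed module under test

EXPECTED_ALL = ["warn_deprecated", "as_import_path"]  # abbreviated


def test_all_imports() -> None:
    # Set comparison: order doesn't matter, additions and removals both fail.
    assert set(__all__) == set(EXPECTED_ALL)
```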

View File

@@ -1,10 +1,10 @@
 """Test functionality related to length based selector."""
 import pytest
-from langchain_core.prompts.example_selector.length_based import (
+from langchain_core.example_selectors import (
     LengthBasedExampleSelector,
 )
-from langchain_core.prompts.prompt import PromptTemplate
+from langchain_core.prompts import PromptTemplate
 EXAMPLES = [
     {"question": "Question: who are you?\nAnswer: foo"},

View File

@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+from langchain_core.messages import BaseMessage
 from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.schema.messages import BaseMessage
 class BaseFakeCallbackHandler(BaseModel):
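The fakes in this file subclass the callback base classes, which keep their `langchain_core.callbacks.base` home; only the message types move. A minimal handler along the same lines, as a sketch:

```python
from typing import Any

from langchain_core.callbacks.base import BaseCallbackHandler


class CountingCallbackHandler(BaseCallbackHandler):
    """Illustrative handler that counts how often an LLM run starts."""

    def __init__(self) -> None:
        self.llm_starts = 0

    def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        self.llm_starts += 1
```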

View File

@@ -7,10 +7,9 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.chat_model import BaseChatModel, SimpleChatModel
-from langchain_core.schema import ChatResult
-from langchain_core.schema.messages import AIMessageChunk, BaseMessage
-from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk
+from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 class FakeMessagesListChatModel(BaseChatModel):
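`SimpleChatModel`, now under `langchain_core.language_models.chat_models`, only asks subclasses for `_call` and `_llm_type`. A minimal echo model as a sketch of the contract these fakes implement:

```python
from typing import Any, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage


class EchoChatModel(SimpleChatModel):
    """Illustrative chat model that parrots the last message back."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return str(messages[-1].content)

    @property
    def _llm_type(self) -> str:
        return "echo-chat"
```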

View File

@@ -6,9 +6,8 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.llm import LLM
+from langchain_core.language_models import LLM, LanguageModelInput
 from langchain_core.runnables import RunnableConfig
-from langchain_core.schema.language_model import LanguageModelInput
 class FakeListLLM(LLM):
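Similarly, the `LLM` base class, now importable from `langchain_core.language_models`, needs just `_call` and `_llm_type`; a sketch:

```python
from typing import Any, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import LLM


class StaticLLM(LLM):
    """Illustrative LLM that always returns the same completion."""

    response: str = "foo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return self.response

    @property
    def _llm_type(self) -> str:
        return "static"
```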

View File

@@ -1,10 +1,10 @@
 from typing import List
-from langchain_core.pydantic_v1 import BaseModel, Field
-from langchain_core.schema import (
+from langchain_core.chat_history import (
     BaseChatMessageHistory,
 )
-from langchain_core.schema.messages import BaseMessage
+from langchain_core.messages import BaseMessage
+from langchain_core.pydantic_v1 import BaseModel, Field
 class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
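`BaseChatMessageHistory`, now in `langchain_core.chat_history`, requires only `add_message` and `clear`; the list-backed fake tested here boils down to roughly this sketch:

```python
from typing import List

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel, Field


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """Sketch of a list-backed chat message history."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []
```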

View File

@@ -3,6 +3,13 @@ from typing import Any, List, Union
 import pytest
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+    get_buffer_string,
+)
 from langchain_core.prompts import PromptTemplate
 from langchain_core.prompts.chat import (
     AIMessagePromptTemplate,
@@ -15,13 +22,6 @@ from langchain_core.prompts.chat import (
     SystemMessagePromptTemplate,
     _convert_to_message,
 )
-from langchain_core.schema.messages import (
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    get_buffer_string,
-)
 def create_messages() -> List[BaseMessagePromptTemplate]:
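Message classes and prompt templates now come from separate top-level modules, but they compose the same way; a small sketch:

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts.chat import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("system", "You are a {role}."), ("human", "{question}")]
)
messages = prompt.format_messages(role="pirate", question="Where's the treasure?")
assert isinstance(messages[0], SystemMessage)
assert isinstance(messages[1], HumanMessage)
```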

View File

@@ -3,19 +3,19 @@ from typing import Any, Dict, List, Sequence, Tuple
 import pytest
+from langchain_core.example_selectors import BaseExampleSelector
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.prompts import (
+    AIMessagePromptTemplate,
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+)
 from langchain_core.prompts.chat import SystemMessagePromptTemplate
-from langchain_core.prompts.example_selector.base import BaseExampleSelector
 from langchain_core.prompts.few_shot import (
     FewShotChatMessagePromptTemplate,
     FewShotPromptTemplate,
 )
 from langchain_core.prompts.prompt import PromptTemplate
-from langchain_core.schema import AIMessage, HumanMessage, SystemMessage
 EXAMPLE_PROMPT = PromptTemplate(
     input_variables=["question", "answer"], template="{question}: {answer}"
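Using an example prompt like `EXAMPLE_PROMPT` above, a few-shot prompt under the new top-level imports looks roughly like this sketch:

```python
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["question", "answer"], template="{question}: {answer}"
)
few_shot = FewShotPromptTemplate(
    examples=[{"question": "2+2", "answer": "4"}],
    example_prompt=example_prompt,
    suffix="{question}:",
    input_variables=["question"],
)
# Renders the worked example followed by the new question.
print(few_shot.format(question="3+3"))
```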

View File

@@ -6,20 +6,22 @@ EXPECTED_ALL = [
     "BasePromptTemplate",
     "ChatMessagePromptTemplate",
     "ChatPromptTemplate",
+    "ChatPromptValueConcrete",
     "FewShotPromptTemplate",
     "FewShotPromptWithTemplates",
     "FewShotChatMessagePromptTemplate",
     "format_document",
+    "ChatPromptValue",
+    "PromptValue",
+    "StringPromptValue",
     "HumanMessagePromptTemplate",
     "LengthBasedExampleSelector",
     "MaxMarginalRelevanceExampleSelector",
     "MessagesPlaceholder",
     "PipelinePromptTemplate",
-    "Prompt",
     "PromptTemplate",
     "SemanticSimilarityExampleSelector",
     "StringPromptTemplate",
     "SystemMessagePromptTemplate",
     "load_prompt",
-    "FewShotChatMessagePromptTemplate",
 ]

View File

@@ -1,5 +1,5 @@
 """Test functionality related to prompt utils."""
-from langchain_core.prompts.example_selector.semantic_similarity import sorted_values
+from langchain_core.example_selectors import sorted_values
 def test_sorted_vals() -> None:
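`sorted_values` is a tiny helper; assuming it keeps its historical behavior of returning a dict's values ordered by key, the test it feeds amounts to:

```python
from langchain_core.example_selectors import sorted_values

# Values come back ordered by their keys, giving a stable representation
# regardless of dict insertion order (assumed behavior).
assert sorted_values({"b": "2", "a": "1"}) == ["1", "2"]
```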

View File

@@ -1,8 +1,8 @@
 from langchain_core.callbacks.manager import CallbackManager
 from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
 from langchain_core.runnables.config import RunnableConfig, merge_configs
+from langchain_core.tracers.stdout import ConsoleCallbackHandler
 def test_merge_config_callbacks() -> None:
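`merge_configs` combines several partial `RunnableConfig`s into one; a sketch of the kind of call the test exercises (the exact merge rules per key are the library's, not spelled out here):

```python
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.runnables.config import RunnableConfig, merge_configs

base: RunnableConfig = {"callbacks": [StdOutCallbackHandler()], "tags": ["a"]}
override: RunnableConfig = {"tags": ["b"]}

merged = merge_configs(base, override)
# Tags from both configs end up in the merged config; callbacks carry over.
print(merged["tags"], merged["callbacks"])
```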

View File

@@ -1,9 +1,10 @@
 from typing import Any, Callable, Sequence, Union
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.runnables.base import RunnableLambda
-from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables import RunnableConfig, RunnableLambda
 from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_core.schema import AIMessage, BaseMessage, HumanMessage
 from tests.unit_tests.fake.memory import ChatMessageHistory
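`RunnableWithMessageHistory` wraps a runnable with a per-session history lookup. A compressed sketch of how the pieces in this test fit together, using the fake `ChatMessageHistory` above and an illustrative in-memory session store:

```python
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from tests.unit_tests.fake.memory import ChatMessageHistory

store = {}  # session_id -> history; illustrative registry


def get_session_history(session_id: str) -> ChatMessageHistory:
    return store.setdefault(session_id, ChatMessageHistory())


chain = RunnableWithMessageHistory(
    RunnableLambda(lambda messages: "ok"),  # toy runnable: messages in, str out
    get_session_history,
)
chain.invoke(
    [HumanMessage(content="hi")],
    config={"configurable": {"session_id": "s1"}},
)
# The session history now holds the human turn plus the "ok" reply.
```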

View File

@@ -26,53 +26,55 @@ from langchain_core.callbacks.manager import (
     collect_runs,
     trace_as_chain_group,
 )
-from langchain_core.callbacks.tracers.base import BaseTracer
-from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch
-from langchain_core.callbacks.tracers.schemas import Run
-from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
-from langchain_core.load.dump import dumpd, dumps
-from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
-from langchain_core.prompts import PromptTemplate
-from langchain_core.prompts.base import StringPromptValue
-from langchain_core.prompts.chat import (
-    ChatPromptTemplate,
-    ChatPromptValue,
-    HumanMessagePromptTemplate,
-    MessagesPlaceholder,
-    SystemMessagePromptTemplate,
-)
-from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.runnables import (
-    RouterRunnable,
-    Runnable,
-    RunnableBranch,
-    RunnableConfig,
-    RunnableLambda,
-    RunnableParallel,
-    RunnablePassthrough,
-    RunnableSequence,
-    RunnableWithFallbacks,
-)
-from langchain_core.runnables.base import (
-    ConfigurableField,
-    RunnableBinding,
-    RunnableGenerator,
-)
-from langchain_core.runnables.utils import (
-    ConfigurableFieldMultiOption,
-    ConfigurableFieldSingleOption,
-    add,
-)
-from langchain_core.schema.document import Document
-from langchain_core.schema.messages import (
+from langchain_core.documents import Document
+from langchain_core.load import dumpd, dumps
+from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.schema.output_parser import BaseOutputParser, StrOutputParser
-from langchain_core.schema.retriever import BaseRetriever
-from langchain_core.tool import BaseTool, tool
+from langchain_core.output_parsers import (
+    BaseOutputParser,
+    CommaSeparatedListOutputParser,
+    StrOutputParser,
+)
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    ChatPromptValue,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+    PromptTemplate,
+    StringPromptValue,
+    SystemMessagePromptTemplate,
+)
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import (
+    ConfigurableField,
+    ConfigurableFieldMultiOption,
+    ConfigurableFieldSingleOption,
+    RouterRunnable,
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableConfig,
+    RunnableGenerator,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+    RunnableSequence,
+    RunnableWithFallbacks,
+    add,
+)
+from langchain_core.tools import BaseTool, tool
+from langchain_core.tracers import (
+    BaseTracer,
+    ConsoleCallbackHandler,
+    Run,
+    RunLog,
+    RunLogPatch,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel
 from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
@@ -1539,7 +1541,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
     )
     chat = FakeListChatModel(responses=["foo"])
-    chain = prompt | chat
+    chain: Runnable = prompt | chat
     mock_start = mocker.Mock()
     mock_end = mocker.Mock()
@@ -1572,7 +1574,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
     )
     chat = FakeListChatModel(responses=["foo"])
-    chain = prompt | chat
+    chain: Runnable = prompt | chat
     mock_start = mocker.Mock()
     mock_end = mocker.Mock()
@@ -1608,7 +1610,7 @@ def test_prompt_with_chat_model(
     )
     chat = FakeListChatModel(responses=["foo"])
-    chain = prompt | chat
+    chain: Runnable = prompt | chat
     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
@@ -1712,7 +1714,7 @@ async def test_prompt_with_chat_model_async(
     )
     chat = FakeListChatModel(responses=["foo"])
-    chain = prompt | chat
+    chain: Runnable = prompt | chat
     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
@@ -1819,7 +1821,7 @@ async def test_prompt_with_llm(
     )
     llm = FakeListLLM(responses=["foo", "bar"])
-    chain = prompt | llm
+    chain: Runnable = prompt | llm
     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
@@ -2325,13 +2327,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
 async def test_router_runnable(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
-    chain1 = ChatPromptTemplate.from_template(
+    chain1: Runnable = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    chain2 = ChatPromptTemplate.from_template(
+    chain2: Runnable = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    router = RouterRunnable({"math": chain1, "english": chain2})
+    router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
     chain: Runnable = {
         "key": lambda x: x["key"],
         "input": {"question": lambda x: x["question"]},
@@ -2377,10 +2379,10 @@ async def test_router_runnable(
 async def test_higher_order_lambda_runnable(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
-    math_chain = ChatPromptTemplate.from_template(
+    math_chain: Runnable = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    english_chain = ChatPromptTemplate.from_template(
+    english_chain: Runnable = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
     input_map: Runnable = RunnableParallel(
@@ -3096,7 +3098,7 @@ async def test_deep_astream_assign() -> None:
 def test_runnable_sequence_transform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])
-    chain = llm | StrOutputParser()
+    chain: Runnable = llm | StrOutputParser()
     stream = chain.transform(llm.stream("Hi there!"))
@@ -3111,7 +3113,7 @@ def test_runnable_sequence_transform() -> None:
 async def test_runnable_sequence_atransform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])
-    chain = llm | StrOutputParser()
+    chain: Runnable = llm | StrOutputParser()
     stream = chain.atransform(llm.astream("Hi there!"))
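The recurring `chain: Runnable = ...` edits in this file are type-checking hygiene rather than behavior changes: the `|` operator produces a concrete `RunnableSequence`, and annotating the variable with the `Runnable` interface keeps the type checker happy when it is reused generically. A self-contained sketch, with `RunnableLambda` standing in for a model:

```python
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda

prompt = ChatPromptTemplate.from_template("Answer the question: {question}")
fake_llm = RunnableLambda(lambda _prompt: AIMessage(content="4"))  # model stand-in

# `prompt | fake_llm` is a RunnableSequence; the annotation widens it to the
# generic Runnable interface, which is all the tests rely on.
chain: Runnable = prompt | fake_llm
print(chain.invoke({"question": "what is 2 + 2?"}))
```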

View File

@@ -1,43 +0,0 @@
-from langchain_core.schema import __all__
-
-EXPECTED_ALL = [
-    "BaseCache",
-    "BaseMemory",
-    "BaseStore",
-    "AgentFinish",
-    "AgentAction",
-    "Document",
-    "BaseChatMessageHistory",
-    "BaseDocumentTransformer",
-    "BaseMessage",
-    "ChatMessage",
-    "FunctionMessage",
-    "HumanMessage",
-    "AIMessage",
-    "SystemMessage",
-    "messages_from_dict",
-    "messages_to_dict",
-    "_message_to_dict",
-    "_message_from_dict",
-    "get_buffer_string",
-    "RunInfo",
-    "LLMResult",
-    "ChatResult",
-    "ChatGeneration",
-    "Generation",
-    "PromptValue",
-    "LangChainException",
-    "BaseRetriever",
-    "RUN_KEY",
-    "Memory",
-    "OutputParserException",
-    "StrOutputParser",
-    "BaseOutputParser",
-    "BaseLLMOutputParser",
-    "BasePromptTemplate",
-    "format_document",
-]
-
-
-def test_all_imports() -> None:
-    assert set(__all__) == set(EXPECTED_ALL)

View File

@@ -1,6 +1,6 @@
 import pytest
-from langchain_core.schema.messages import (
+from langchain_core.messages import (
     AIMessageChunk,
     ChatMessageChunk,
     FunctionMessageChunk,

View File

@@ -1,5 +1,5 @@
-from langchain_core.schema.messages import HumanMessageChunk
-from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
+from langchain_core.messages import HumanMessageChunk
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
 def test_generation_chunk() -> None:
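Chunk types exist so streamed pieces can be accumulated with `+`; that is what these tests exercise, now via `langchain_core.messages` and `langchain_core.outputs`. A sketch:

```python
from langchain_core.messages import HumanMessageChunk
from langchain_core.outputs import GenerationChunk

# Text chunks concatenate their payloads.
combined = GenerationChunk(text="Hello, ") + GenerationChunk(text="world!")
assert combined.text == "Hello, world!"

# Message chunks concatenate too, keeping the chunk type.
msg = HumanMessageChunk(content="Hello, ") + HumanMessageChunk(content="world!")
assert msg.content == "Hello, world!"
```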

View File

@@ -7,12 +7,12 @@ from typing import Any, List, Optional, Type, Union
 import pytest
-from langchain_core.callbacks.manager import (
+from langchain_core.callbacks import (
     AsyncCallbackManagerForToolRun,
     CallbackManagerForToolRun,
 )
 from langchain_core.pydantic_v1 import BaseModel
-from langchain_core.tool import (
+from langchain_core.tools import (
     BaseTool,
     SchemaAnnotationError,
     StructuredTool,
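For orientation, the simplest way to get a `BaseTool` out of the new `langchain_core.tools` module is the `@tool` decorator; a sketch:

```python
from langchain_core.tools import BaseTool, tool


@tool
def add_numbers(a: int, b: int) -> int:
    """Add two integers."""  # the docstring becomes the tool description
    return a + b


assert isinstance(add_numbers, BaseTool)
print(add_numbers.run({"a": 2, "b": 3}))  # -> 5
```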

View File

@@ -14,6 +14,7 @@ EXPECTED_ALL = [
     "print_text",
     "raise_for_status_with_text",
     "xor_args",
+    "try_load_from_hub",
 ]