mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 12:07:36 +00:00
REFACTOR: Refactor langchain_core (#13627)
Changes: - remove langchain_core/schema since no clear distinction b/n schema and non-schema modules - make every module that doesn't end in -y plural - where easy have 1-2 classes per file - no more than one level of nesting in directories - only import from top level core modules in langchain
This commit is contained in:
@@ -6,6 +6,8 @@ EXPECTED_ALL = [
|
||||
"suppress_langchain_deprecation_warning",
|
||||
"surface_langchain_deprecation_warnings",
|
||||
"warn_deprecated",
|
||||
"as_import_path",
|
||||
"get_relative_path",
|
||||
]
|
||||
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
"""Test functionality related to length based selector."""
|
||||
import pytest
|
||||
|
||||
from langchain_core.prompts.example_selector.length_based import (
|
||||
from langchain_core.example_selectors import (
|
||||
LengthBasedExampleSelector,
|
||||
)
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
EXAMPLES = [
|
||||
{"question": "Question: who are you?\nAnswer: foo"},
|
@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.schema.messages import BaseMessage
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
|
@@ -7,10 +7,9 @@ from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.chat_model import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.schema import ChatResult
|
||||
from langchain_core.schema.messages import AIMessageChunk, BaseMessage
|
||||
from langchain_core.schema.output import ChatGeneration, ChatGenerationChunk
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
|
||||
class FakeMessagesListChatModel(BaseChatModel):
|
||||
|
@@ -6,9 +6,8 @@ from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.llm import LLM
|
||||
from langchain_core.language_models import LLM, LanguageModelInput
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.schema.language_model import LanguageModelInput
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.schema import (
|
||||
from langchain_core.chat_history import (
|
||||
BaseChatMessageHistory,
|
||||
)
|
||||
from langchain_core.schema.messages import BaseMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||
|
@@ -3,6 +3,13 @@ from typing import Any, List, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
@@ -15,13 +22,6 @@ from langchain_core.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
_convert_to_message,
|
||||
)
|
||||
from langchain_core.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
|
||||
|
||||
def create_messages() -> List[BaseMessagePromptTemplate]:
|
||||
|
@@ -3,19 +3,19 @@ from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import (
|
||||
AIMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.chat import SystemMessagePromptTemplate
|
||||
from langchain_core.prompts.example_selector.base import BaseExampleSelector
|
||||
from langchain_core.prompts.few_shot import (
|
||||
FewShotChatMessagePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
EXAMPLE_PROMPT = PromptTemplate(
|
||||
input_variables=["question", "answer"], template="{question}: {answer}"
|
||||
|
@@ -6,20 +6,22 @@ EXPECTED_ALL = [
|
||||
"BasePromptTemplate",
|
||||
"ChatMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"ChatPromptValueConcrete",
|
||||
"FewShotPromptTemplate",
|
||||
"FewShotPromptWithTemplates",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"format_document",
|
||||
"ChatPromptValue",
|
||||
"PromptValue",
|
||||
"StringPromptValue",
|
||||
"HumanMessagePromptTemplate",
|
||||
"LengthBasedExampleSelector",
|
||||
"MaxMarginalRelevanceExampleSelector",
|
||||
"MessagesPlaceholder",
|
||||
"PipelinePromptTemplate",
|
||||
"Prompt",
|
||||
"PromptTemplate",
|
||||
"SemanticSimilarityExampleSelector",
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
]
|
||||
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
"""Test functionality related to prompt utils."""
|
||||
from langchain_core.prompts.example_selector.semantic_similarity import sorted_values
|
||||
from langchain_core.example_selectors import sorted_values
|
||||
|
||||
|
||||
def test_sorted_vals() -> None:
|
||||
|
File diff suppressed because one or more lines are too long
@@ -1,8 +1,8 @@
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain_core.runnables.config import RunnableConfig, merge_configs
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
|
||||
|
||||
def test_merge_config_callbacks() -> None:
|
@@ -1,9 +1,10 @@
|
||||
from typing import Any, Callable, Sequence, Union
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.schema import AIMessage, BaseMessage, HumanMessage
|
||||
from tests.unit_tests.fake.memory import ChatMessageHistory
|
||||
|
||||
|
@@ -26,53 +26,55 @@ from langchain_core.callbacks.manager import (
|
||||
collect_runs,
|
||||
trace_as_chain_group,
|
||||
)
|
||||
from langchain_core.callbacks.tracers.base import BaseTracer
|
||||
from langchain_core.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain_core.callbacks.tracers.schemas import Run
|
||||
from langchain_core.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
from langchain_core.load.dump import dumpd, dumps
|
||||
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.prompts.base import StringPromptValue
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import (
|
||||
RouterRunnable,
|
||||
Runnable,
|
||||
RunnableBranch,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain_core.runnables.base import (
|
||||
ConfigurableField,
|
||||
RunnableBinding,
|
||||
RunnableGenerator,
|
||||
)
|
||||
from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldMultiOption,
|
||||
ConfigurableFieldSingleOption,
|
||||
add,
|
||||
)
|
||||
from langchain_core.schema.document import Document
|
||||
from langchain_core.schema.messages import (
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import dumpd, dumps
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.schema.output_parser import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.schema.retriever import BaseRetriever
|
||||
from langchain_core.tool import BaseTool, tool
|
||||
from langchain_core.output_parsers import (
|
||||
BaseOutputParser,
|
||||
CommaSeparatedListOutputParser,
|
||||
StrOutputParser,
|
||||
)
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
ChatPromptValue,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
PromptTemplate,
|
||||
StringPromptValue,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import (
|
||||
ConfigurableField,
|
||||
ConfigurableFieldMultiOption,
|
||||
ConfigurableFieldSingleOption,
|
||||
RouterRunnable,
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableBranch,
|
||||
RunnableConfig,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
add,
|
||||
)
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from langchain_core.tracers import (
|
||||
BaseTracer,
|
||||
ConsoleCallbackHandler,
|
||||
Run,
|
||||
RunLog,
|
||||
RunLogPatch,
|
||||
)
|
||||
from tests.unit_tests.fake.chat_model import FakeListChatModel
|
||||
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
|
||||
|
||||
@@ -1539,7 +1541,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
chain: Runnable = prompt | chat
|
||||
|
||||
mock_start = mocker.Mock()
|
||||
mock_end = mocker.Mock()
|
||||
@@ -1572,7 +1574,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
chain: Runnable = prompt | chat
|
||||
|
||||
mock_start = mocker.Mock()
|
||||
mock_end = mocker.Mock()
|
||||
@@ -1608,7 +1610,7 @@ def test_prompt_with_chat_model(
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
chain: Runnable = prompt | chat
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
@@ -1712,7 +1714,7 @@ async def test_prompt_with_chat_model_async(
|
||||
)
|
||||
chat = FakeListChatModel(responses=["foo"])
|
||||
|
||||
chain = prompt | chat
|
||||
chain: Runnable = prompt | chat
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
@@ -1819,7 +1821,7 @@ async def test_prompt_with_llm(
|
||||
)
|
||||
llm = FakeListLLM(responses=["foo", "bar"])
|
||||
|
||||
chain = prompt | llm
|
||||
chain: Runnable = prompt | llm
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
@@ -2325,13 +2327,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
||||
async def test_router_runnable(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
chain1 = ChatPromptTemplate.from_template(
|
||||
chain1: Runnable = ChatPromptTemplate.from_template(
|
||||
"You are a math genius. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["4"])
|
||||
chain2 = ChatPromptTemplate.from_template(
|
||||
chain2: Runnable = ChatPromptTemplate.from_template(
|
||||
"You are an english major. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["2"])
|
||||
router = RouterRunnable({"math": chain1, "english": chain2})
|
||||
router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
|
||||
chain: Runnable = {
|
||||
"key": lambda x: x["key"],
|
||||
"input": {"question": lambda x: x["question"]},
|
||||
@@ -2377,10 +2379,10 @@ async def test_router_runnable(
|
||||
async def test_higher_order_lambda_runnable(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
math_chain = ChatPromptTemplate.from_template(
|
||||
math_chain: Runnable = ChatPromptTemplate.from_template(
|
||||
"You are a math genius. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["4"])
|
||||
english_chain = ChatPromptTemplate.from_template(
|
||||
english_chain: Runnable = ChatPromptTemplate.from_template(
|
||||
"You are an english major. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["2"])
|
||||
input_map: Runnable = RunnableParallel(
|
||||
@@ -3096,7 +3098,7 @@ async def test_deep_astream_assign() -> None:
|
||||
def test_runnable_sequence_transform() -> None:
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain = llm | StrOutputParser()
|
||||
chain: Runnable = llm | StrOutputParser()
|
||||
|
||||
stream = chain.transform(llm.stream("Hi there!"))
|
||||
|
||||
@@ -3111,7 +3113,7 @@ def test_runnable_sequence_transform() -> None:
|
||||
async def test_runnable_sequence_atransform() -> None:
|
||||
llm = FakeStreamingListLLM(responses=["foo-lish"])
|
||||
|
||||
chain = llm | StrOutputParser()
|
||||
chain: Runnable = llm | StrOutputParser()
|
||||
|
||||
stream = chain.atransform(llm.astream("Hi there!"))
|
||||
|
@@ -1,43 +0,0 @@
|
||||
from langchain_core.schema import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"BaseCache",
|
||||
"BaseMemory",
|
||||
"BaseStore",
|
||||
"AgentFinish",
|
||||
"AgentAction",
|
||||
"Document",
|
||||
"BaseChatMessageHistory",
|
||||
"BaseDocumentTransformer",
|
||||
"BaseMessage",
|
||||
"ChatMessage",
|
||||
"FunctionMessage",
|
||||
"HumanMessage",
|
||||
"AIMessage",
|
||||
"SystemMessage",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"_message_to_dict",
|
||||
"_message_from_dict",
|
||||
"get_buffer_string",
|
||||
"RunInfo",
|
||||
"LLMResult",
|
||||
"ChatResult",
|
||||
"ChatGeneration",
|
||||
"Generation",
|
||||
"PromptValue",
|
||||
"LangChainException",
|
||||
"BaseRetriever",
|
||||
"RUN_KEY",
|
||||
"Memory",
|
||||
"OutputParserException",
|
||||
"StrOutputParser",
|
||||
"BaseOutputParser",
|
||||
"BaseLLMOutputParser",
|
||||
"BasePromptTemplate",
|
||||
"format_document",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from langchain_core.schema.messages import (
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
@@ -1,5 +1,5 @@
|
||||
from langchain_core.schema.messages import HumanMessageChunk
|
||||
from langchain_core.schema.output import ChatGenerationChunk, GenerationChunk
|
||||
from langchain_core.messages import HumanMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
|
||||
def test_generation_chunk() -> None:
|
@@ -7,12 +7,12 @@ from typing import Any, List, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.tool import (
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
SchemaAnnotationError,
|
||||
StructuredTool,
|
@@ -14,6 +14,7 @@ EXPECTED_ALL = [
|
||||
"print_text",
|
||||
"raise_for_status_with_text",
|
||||
"xor_args",
|
||||
"try_load_from_hub",
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user