Compare commits

...

9 Commits

Author SHA1 Message Date
Erick Friis
63f380d349 ignore ones 2023-11-22 14:47:04 -05:00
Erick Friis
9c55d65606 document transformers gone 2023-11-22 14:32:33 -05:00
Bagatur
fabebb2042 more 2023-11-22 11:31:51 -08:00
Bagatur
b2b8122249 Merge remote-tracking branch 'origin/erick/core-namespace-same' into bagatur/fix_core_namespace 2023-11-22 11:20:07 -08:00
Bagatur
5c6973a1b5 more 2023-11-22 11:19:41 -08:00
Bagatur
4449481cd3 BUGFIX: backwards compatible core namespacing 2023-11-22 11:10:27 -08:00
Erick Friis
70f6be32e0 number 2023-11-22 13:44:10 -05:00
Erick Friis
1548f32f3d lint 2023-11-22 13:42:59 -05:00
Erick Friis
6c59db482b unit test for core 2023-11-22 13:38:00 -05:00
37 changed files with 643 additions and 172 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Literal, Sequence, Union
from typing import Any, List, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.messages import BaseMessage
@@ -34,6 +34,11 @@ class AgentAction(Serializable):
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "agent",
    ]
class AgentActionMessageLog(AgentAction):
message_log: Sequence[BaseMessage]
@@ -49,6 +54,11 @@ class AgentActionMessageLog(AgentAction):
# The type literal is used for serialization purposes.
type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "agent",
    ]
class AgentFinish(Serializable):
"""The final return value of an ActionAgent."""
@@ -72,3 +82,8 @@ class AgentFinish(Serializable):
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "agent",
    ]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Literal
from typing import List, Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
@@ -21,3 +21,8 @@ class Document(Serializable):
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "document",
    ]

View File

@@ -37,6 +37,11 @@ class BaseMessage(Serializable):
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "messages",
    ]
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],

View File

@@ -18,6 +18,11 @@ class ListOutputParser(BaseOutputParser[List[str]]):
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "output_parsers",
        "list",
    ]
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list."""

View File

@@ -1,3 +1,5 @@
from typing import List
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@@ -17,3 +19,8 @@ class StrOutputParser(BaseTransformOutputParser[str]):
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "output_parser",
    ]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional
from langchain_core.load import Serializable
@@ -24,6 +24,11 @@ class Generation(Serializable):
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "schema",
        "output",
    ]
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""

View File

@@ -48,6 +48,11 @@ class StringPromptValue(PromptValue):
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "prompts",
        "base",
    ]
class ChatPromptValue(PromptValue):
"""Chat prompt value.
@@ -66,6 +71,11 @@ class ChatPromptValue(PromptValue):
"""Return prompt as a list of messages."""
return list(self.messages)
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "prompts",
        "chat",
    ]
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.

View File

@@ -191,6 +191,11 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
else:
raise ValueError(f"{save_path} must be json or yaml")
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace derived from the dotted path of ``cls.__module__``.

    The first component of the module path is replaced by ``"langchain"``
    (for backwards compatibility); every later component is kept unchanged.
    """
    _, *module_tail = cls.__module__.split(".")
    return ["langchain", *module_tail]
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template.

View File

@@ -75,6 +75,11 @@ class BaseMessagePromptTemplate(Serializable, ABC):
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace derived from the dotted path of ``cls.__module__``.

    The first component of the module path is replaced by ``"langchain"``
    (for backwards compatibility); every later component is kept unchanged.
    """
    _, *module_tail = cls.__module__.split(".")
    return ["langchain", *module_tail]
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages."""

View File

@@ -1316,7 +1316,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
@@ -1842,7 +1843,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
@@ -2600,7 +2602,8 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
def _invoke(
self,
@@ -2772,7 +2775,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = merge_configs(self.config, *configs)

View File

@@ -128,8 +128,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
def get_input_schema(
self, config: Optional[RunnableConfig] = None

View File

@@ -52,7 +52,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
@property
def InputType(self) -> Type[Input]:

View File

@@ -125,7 +125,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:

View File

@@ -150,7 +150,8 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
@property
def InputType(self) -> Any:

View File

@@ -77,7 +77,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
# For backwards compatibility, replace langchain_core with langchain, schema.
return ["langchain", "schema", "runnable"]
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None

View File

@@ -448,6 +448,11 @@ class ChildTool(BaseTool):
"""Make tool callable."""
return self.run(tool_input, callbacks=callbacks)
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace reported for this class.

    The path is hard-coded for backwards compatibility.
    """
    return [
        "langchain",
        "tools",
        "base",
    ]
class Tool(BaseTool):
"""Tool that takes in function or coroutine directly."""

View File

@@ -0,0 +1,9 @@
from langchain_core.documents import Document
def test_lc_namespace() -> None:
    """``Document`` reports the ``langchain.schema.document`` namespace."""
    expected = ["langchain", "schema", "document"]
    assert Document.get_lc_namespace() == expected

View File

@@ -0,0 +1,89 @@
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
def test_lc_namespace() -> None:
    """Every message class and chunk class reports ``langchain.schema.messages``."""
    expected = ["langchain", "schema", "messages"]
    message_classes = (
        BaseMessage,
        AIMessage,
        HumanMessage,
        SystemMessage,
        FunctionMessage,
        ToolMessage,
        ChatMessage,
        BaseMessageChunk,
        AIMessageChunk,
        HumanMessageChunk,
        SystemMessageChunk,
        FunctionMessageChunk,
        ToolMessageChunk,
        ChatMessageChunk,
    )
    # Classes are checked in order; the first mismatch fails the test.
    for message_cls in message_classes:
        assert message_cls.get_lc_namespace() == expected

View File

@@ -0,0 +1,23 @@
from langchain_core.output_parsers import (
CommaSeparatedListOutputParser,
MarkdownListOutputParser,
NumberedListOutputParser,
)
def test_lc_namespace() -> None:
    """All list output parsers report ``langchain.output_parsers.list``."""
    expected = ["langchain", "output_parsers", "list"]
    list_parsers = (
        CommaSeparatedListOutputParser,
        NumberedListOutputParser,
        MarkdownListOutputParser,
    )
    for parser_cls in list_parsers:
        assert parser_cls.get_lc_namespace() == expected

View File

@@ -0,0 +1,9 @@
from langchain_core.output_parsers import StrOutputParser
def test_lc_namespace() -> None:
    """``StrOutputParser`` reports ``langchain.schema.output_parser``."""
    expected = ["langchain", "schema", "output_parser"]
    assert StrOutputParser.get_lc_namespace() == expected

View File

@@ -19,6 +19,7 @@ from langchain_core.prompts.chat import (
ChatMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
_convert_to_message,
)
@@ -360,3 +361,15 @@ def test_chat_message_partial() -> None:
]
assert res == expected
assert template2.format(input="hello") == get_buffer_string(expected)
def test_lc_namespace() -> None:
    """All chat prompt classes report the ``langchain.prompts.chat`` namespace."""
    expected = ["langchain", "prompts", "chat"]
    chat_prompt_classes = (
        MessagesPlaceholder,
        ChatMessagePromptTemplate,
        HumanMessagePromptTemplate,
        AIMessagePromptTemplate,
        SystemMessagePromptTemplate,
        ChatPromptTemplate,
    )
    for prompt_cls in chat_prompt_classes:
        assert prompt_cls.get_lc_namespace() == expected

View File

@@ -416,3 +416,8 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
AIMessage(content="5", additional_kwargs={}, example=False),
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
]
def test_lc_namespace() -> None:
    """Both few-shot prompt templates report ``langchain.prompts.few_shot``."""
    expected = ["langchain", "prompts", "few_shot"]
    few_shot_classes = (FewShotPromptTemplate, FewShotChatMessagePromptTemplate)
    for prompt_cls in few_shot_classes:
        assert prompt_cls.get_lc_namespace() == expected

View File

@@ -74,3 +74,12 @@ def test_prompttemplate_validation() -> None:
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
).input_variables == ["content", "new_content"]
def test_lc_namespace() -> None:
    """``FewShotPromptWithTemplates`` reports its ``langchain.prompts`` namespace."""
    expected = ["langchain", "prompts", "few_shot_with_templates"]
    assert FewShotPromptWithTemplates.get_lc_namespace() == expected

View File

@@ -43,3 +43,8 @@ def test_partial_with_chat_prompts() -> None:
assert pipeline_prompt.input_variables == ["bar"]
output = pipeline_prompt.format_prompt(bar="okay")
assert output.to_messages()[0].content == "jim okay"
def test_lc_namespace() -> None:
    """``PipelinePromptTemplate`` reports ``langchain.prompts.pipeline``."""
    expected = ["langchain", "prompts", "pipeline"]
    assert PipelinePromptTemplate.get_lc_namespace() == expected

View File

@@ -297,3 +297,8 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
def test_lc_namespace() -> None:
    """``PromptTemplate`` reports the ``langchain.prompts.prompt`` namespace."""
    expected = ["langchain", "prompts", "prompt"]
    assert PromptTemplate.get_lc_namespace() == expected

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,25 @@
from langchain_core.runnables import RunnableSequence
from langchain_core.runnables.base import (
RunnableBindingBase,
RunnableEachBase,
RunnableParallel,
)
def test_lc_namespace() -> None:
    """The base runnable classes all report ``langchain.schema.runnable``."""
    expected = ["langchain", "schema", "runnable"]
    runnable_classes = (
        RunnableSequence,
        RunnableEachBase,
        RunnableParallel,
        RunnableBindingBase,
    )
    for runnable_cls in runnable_classes:
        assert runnable_cls.get_lc_namespace() == expected

View File

@@ -0,0 +1,5 @@
from langchain_core.runnables import RunnableBranch
def test_lc_namespace() -> None:
    """``RunnableBranch`` reports the ``langchain.schema.runnable`` namespace."""
    expected = ["langchain", "schema", "runnable"]
    assert RunnableBranch.get_lc_namespace() == expected

View File

@@ -0,0 +1,5 @@
from langchain_core.runnables.configurable import DynamicRunnable
def test_lc_namespace() -> None:
    """``DynamicRunnable`` reports the ``langchain.schema.runnable`` namespace."""
    expected = ["langchain", "schema", "runnable"]
    assert DynamicRunnable.get_lc_namespace() == expected

View File

@@ -0,0 +1,9 @@
from langchain_core.runnables import RunnableWithFallbacks
def test_lc_namespace() -> None:
    """``RunnableWithFallbacks`` reports ``langchain.schema.runnable``."""
    expected = ["langchain", "schema", "runnable"]
    assert RunnableWithFallbacks.get_lc_namespace() == expected

View File

@@ -0,0 +1,5 @@
from langchain_core.runnables import RunnablePassthrough
def test_lc_namespace() -> None:
    """``RunnablePassthrough`` reports ``langchain.schema.runnable``."""
    expected = ["langchain", "schema", "runnable"]
    assert RunnablePassthrough.get_lc_namespace() == expected

View File

@@ -0,0 +1,5 @@
from langchain_core.runnables import RouterRunnable
def test_lc_namespace() -> None:
    """``RouterRunnable`` reports the ``langchain.schema.runnable`` namespace."""
    expected = ["langchain", "schema", "runnable"]
    assert RouterRunnable.get_lc_namespace() == expected

View File

@@ -0,0 +1,19 @@
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
def test_lc_namespace() -> None:
    """All agent schema classes report the ``langchain.schema.agent`` namespace."""
    expected = ["langchain", "schema", "agent"]
    agent_classes = (AgentAction, AgentActionMessageLog, AgentFinish)
    for agent_cls in agent_classes:
        assert agent_cls.get_lc_namespace() == expected

View File

@@ -0,0 +1,15 @@
from langchain_core.prompt_values import (
ChatPromptValue,
ChatPromptValueConcrete,
StringPromptValue,
)
def test_lc_namespace() -> None:
    """Each prompt value class reports its expected ``langchain.prompts`` namespace."""
    expected_by_class = (
        (StringPromptValue, ["langchain", "prompts", "base"]),
        (ChatPromptValue, ["langchain", "prompts", "chat"]),
        (ChatPromptValueConcrete, ["langchain", "prompts", "chat"]),
    )
    for value_cls, expected in expected_by_class:
        assert value_cls.get_lc_namespace() == expected

View File

@@ -23,6 +23,16 @@ from langchain_core.tools import (
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
def test_lc_namespace() -> None:
    """``BaseTool``, ``Tool`` and ``StructuredTool`` report ``langchain.tools.base``."""
    expected = ["langchain", "tools", "base"]
    tool_classes = (BaseTool, Tool, StructuredTool)
    for tool_cls in tool_classes:
        assert tool_cls.get_lc_namespace() == expected
def test_unnamed_decorator() -> None:
"""Test functionality with unnamed decorator."""

View File

@@ -0,0 +1,95 @@
import importlib
import inspect
from langchain_core.load import Serializable
# Sub-modules of ``langchain_core`` that the export test imports and scans.
core_modules = [
    "agents",
    "caches",
    "callbacks",
    "chat_history",
    "chat_sessions",
    "documents",
    "embeddings",
    "env",
    "example_selectors",
    "exceptions",
    "globals",
    "language_models",
    "load",
    "memory",
    "messages",
    "output_parsers",
    "outputs",
    "prompt_values",
    "prompts",
    "pydantic_v1",
    "retrievers",
    "runnables",
    "stores",
    "tools",
    "tracers",
    "utils",
    "vectorstores",
]

# Fully qualified object names that the export test skips.
# Written as a set literal rather than ``set([...])`` to avoid building a
# throwaway list first.
ignore_objects = {
    "langchain_core.vectorstores.VectorStoreRetriever",  # this is a base class
}
def test_core_exported_from_langchain() -> None:
    """Check that concrete Serializable classes exported by core have a valid LC ID.

    For each module listed in ``core_modules``, every public, non-abstract
    ``Serializable`` subclass (other than ``Serializable`` itself and anything
    named in ``ignore_objects``) must:

    * report an ``lc_id()`` whose first element is ``"langchain"``;
    * be importable from the module path formed by all but the last element
      of the ID, under the name given by the last element;
    * resolve to an object that reports the same ID and is itself a
      ``Serializable`` subclass.

    Namespace and import problems are collected across all modules and
    reported together in a single failure; an ID mismatch or a
    non-serializable target fails immediately.

    Raises:
        AssertionError: if any scanned class violates the rules above.
    """
    wrong_module = []
    does_not_exist = []
    for module_name in core_modules:
        module = importlib.import_module(f"langchain_core.{module_name}")
        for name in dir(module):
            # Private names are not part of the exported surface.
            if name.startswith("_"):
                continue
            obj_name = f"langchain_core.{module_name}.{name}"
            obj = getattr(module, name)
            # ``isinstance`` must run first: ``issubclass`` rejects non-classes.
            if not (isinstance(obj, type) and issubclass(obj, Serializable)):
                continue
            if inspect.isabstract(obj) or obj is Serializable:
                continue
            if obj_name in ignore_objects:
                continue

            lc_id = obj.lc_id()  # type: ignore
            if lc_id[0] != "langchain":
                wrong_module.append(f"{obj_name} -> {lc_id}")
                continue

            # see if importable
            [*id_namespace, id_name] = lc_id
            import_name = ".".join(id_namespace)
            try:
                import_module = importlib.import_module(import_name)
                import_obj = getattr(import_module, id_name)
            except (ImportError, AttributeError):
                does_not_exist.append(f"{obj_name} -> {lc_id}")
                continue

            # assert same id
            assert import_obj.lc_id() == lc_id, f"{obj_name} -> {lc_id}"
            # assert serializable
            assert issubclass(
                import_obj, Serializable
            ), f"Referenced object not serializable: {obj_name} -> {lc_id}"

    if not wrong_module and not does_not_exist:
        return

    wrong_module_message = "\n".join(f"- {m}" for m in wrong_module) or "None! Passed"
    does_not_exist_message = (
        "\n".join(f"- {m}" for m in does_not_exist) or "None! Passed"
    )
    # Raise explicitly: ``assert False`` is removed when Python runs with -O,
    # which would let the collected problems pass silently.
    raise AssertionError(
        f"""LC ID must be from langchain.x ({len(wrong_module)}):
{wrong_module_message}
The following LC IDs do not exist ({len(does_not_exist)}):
{does_not_exist_message}"""
    )