Compare commits

...

7 Commits

Author SHA1 Message Date
William Fu-Hinthorn
c3ef56ad5f fix 2024-02-29 09:13:56 -08:00
William Fu-Hinthorn
8e45bb3b50 Merge remote-tracking branch 'origin/master' into wfh/add_warnings 2024-02-29 09:12:36 -08:00
William Fu-Hinthorn
4ee6386721 Merge branch 'master' into wfh/add_warnings 2024-02-26 15:39:36 -08:00
William Fu-Hinthorn
0e1f42c5a8 Update docs 2024-02-26 11:56:44 -08:00
William Fu-Hinthorn
7a7a5eb03c fixup 2024-02-26 11:48:39 -08:00
William Fu-Hinthorn
cbc5cbee63 Merge branch 'master' into wfh/add_warnings 2024-02-26 11:48:21 -08:00
William Fu-Hinthorn
966d03f61a Warn against implicit generator coercion 2024-02-26 10:58:22 -08:00
5 changed files with 52 additions and 29 deletions

View File

@@ -123,7 +123,9 @@
"metadata": {},
"outputs": [],
"source": [
"list_chain = str_chain | split_into_list"
"from langchain_core.runnables import RunnableGenerator\n",
"\n",
"list_chain = str_chain | RunnableGenerator(split_into_list)"
]
},
{
@@ -199,7 +201,7 @@
" yield [buffer.strip()]\n",
"\n",
"\n",
"list_chain = str_chain | asplit_into_list"
"list_chain = str_chain | RunnableGenerator(asplit_into_list)"
]
},
{

View File

@@ -360,6 +360,7 @@
],
"source": [
"from langchain_core.output_parsers import JsonOutputParser\n",
"from langchain_core.runnables import RunnableGenerator\n",
"\n",
"\n",
"async def _extract_country_names_streaming(input_stream):\n",
@@ -387,7 +388,7 @@
" country_names_so_far.add(name)\n",
"\n",
"\n",
"chain = model | JsonOutputParser() | _extract_country_names_streaming\n",
"chain = model | JsonOutputParser() | RunnableGenerator(_extract_country_names_streaming)\n",
"\n",
"async for text in chain.astream(\n",
" 'output a list of the countries france, spain and japan and their populations in JSON format. Use a dict with an outer key of \"countries\" which contains a list of countries. Each country should have the key `name` and `population`'\n",
@@ -1383,7 +1384,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".docs-venv",
"language": "python",
"name": "python3"
},

View File

@@ -155,31 +155,10 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "808a5df5-b11e-42a0-bd7a-6b95ca0c3eba",
"metadata": {},
"outputs": [
{
"ename": "ParseError",
"evalue": "syntax error: line 1, column 1 (<string>)",
"output_type": "error",
"traceback": [
"Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
"\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.1/envs/langchain/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3508\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
"\u001b[0m Cell \u001b[1;32mIn[7], line 1\u001b[0m\n for s in chain.stream({\"query\": actor_query}):\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/runnables/base.py:1984\u001b[0m in \u001b[1;35mstream\u001b[0m\n yield from self.transform(iter([input]), config, **kwargs)\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/runnables/base.py:1974\u001b[0m in \u001b[1;35mtransform\u001b[0m\n yield from self._transform_stream_with_config(\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/runnables/base.py:1141\u001b[0m in \u001b[1;35m_transform_stream_with_config\u001b[0m\n for chunk in iterator:\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/runnables/base.py:1938\u001b[0m in \u001b[1;35m_transform\u001b[0m\n for output in final_pipeline:\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/output_parsers/transform.py:50\u001b[0m in \u001b[1;35mtransform\u001b[0m\n yield from self._transform_stream_with_config(\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/runnables/base.py:1141\u001b[0m in \u001b[1;35m_transform_stream_with_config\u001b[0m\n for chunk in iterator:\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/workplace/langchain/libs/core/langchain_core/output_parsers/xml.py:71\u001b[0m in \u001b[1;35m_transform\u001b[0m\n for event, elem in parser.read_events():\u001b[0m\n",
"\u001b[0m File \u001b[1;32m~/.pyenv/versions/3.10.1/lib/python3.10/xml/etree/ElementTree.py:1329\u001b[0m in \u001b[1;35mread_events\u001b[0m\n raise event\u001b[0m\n",
"\u001b[0;36m File \u001b[0;32m~/.pyenv/versions/3.10.1/lib/python3.10/xml/etree/ElementTree.py:1301\u001b[0;36m in \u001b[0;35mfeed\u001b[0;36m\n\u001b[0;31m self._parser.feed(data)\u001b[0;36m\n",
"\u001b[0;36m File \u001b[0;32m<string>\u001b[0;36m\u001b[0m\n\u001b[0;31mParseError\u001b[0m\u001b[0;31m:\u001b[0m syntax error: line 1, column 1\n"
]
}
],
"outputs": [],
"source": [
"for s in chain.stream({\"query\": actor_query}):\n",
" print(s)"

View File

@@ -36,6 +36,7 @@ from typing import (
from typing_extensions import Literal, get_args
from langchain_core._api import beta_decorator
from langchain_core._api.deprecation import warn_deprecated
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import (
Serializable,
@@ -4366,7 +4367,21 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
result = RunnableGenerator(thing)
message = (
"Implicit coercion of custom functions to a "
"RunnableGenerator will be deprecated, and all implicit function coercion"
" will convert to a RunnableLambda."
" Please explicitly convert to a RunnableGenerator using:\n"
"from langchain_core.runnable import RunnableGenerator\n\n"
"runnable = (\n ...\n"
f" | RunnableGenerator({thing.__name__})\n"
")\n"
)
warn_deprecated("0.1.26", message=message, removal="0.2.0")
return result
elif callable(thing):
return RunnableLambda(cast(Callable[[Input], Output], thing))
elif isinstance(thing, dict):

View File

@@ -22,6 +22,7 @@ from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from typing_extensions import TypedDict
from langchain_core._api.deprecation import LangChainDeprecationWarning
from langchain_core.callbacks.manager import (
Callbacks,
atrace_as_chain_group,
@@ -69,7 +70,7 @@ from langchain_core.runnables import (
add,
chain,
)
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.base import RunnableSerializable, coerce_to_runnable
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
BaseTracer,
@@ -5183,3 +5184,28 @@ async def test_astream_log_deep_copies() -> None:
"name": "add_one",
"type": "chain",
}
def test_coerce_to_runnable() -> None:
"""Test that a callable can be coerced to a runnable."""
def add_one(x: int) -> int:
"""Add one."""
return x + 1
runnable = coerce_to_runnable(add_one)
assert runnable.invoke(1) == 2
assert isinstance(runnable, RunnableLambda)
def some_generator(x: Any) -> Iterator[int]:
"""Return 1."""
yield 1
# Assert that a warning is raised when trying to coerce a generator
with pytest.warns(
LangChainDeprecationWarning,
):
runnable = coerce_to_runnable(some_generator)
assert runnable.invoke(1) == 1
# To change in 0.2.0
assert isinstance(runnable, RunnableGenerator)