Compare commits

...

1 Commits

Author SHA1 Message Date
vowelparrot
b07cfbaa82 accept any callable 2023-06-28 15:27:06 -07:00
3 changed files with 154 additions and 48 deletions

View File

@@ -4,10 +4,12 @@ from __future__ import annotations
import asyncio
import functools
import inspect
import logging
from datetime import datetime
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
@@ -15,6 +17,7 @@ from typing import (
List,
Optional,
Sequence,
Tuple,
Union,
)
@@ -41,13 +44,43 @@ from langchain.schema import (
logger = logging.getLogger(__name__)
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
MODEL_OR_CHAIN_FACTORY = Union[
Callable[[], Chain],
BaseLanguageModel,
Callable[[Dict], Dict],
Callable[[Dict], Awaitable[Dict]],
]
def is_chain_factory(model: MODEL_OR_CHAIN_FACTORY) -> bool:
"""Check if a callback is a chain factory."""
return (
not isinstance(model, BaseLanguageModel)
and inspect.isfunction(model)
and len(inspect.signature(model).parameters) == 0
and isinstance(model(), Chain)
)
def get_model_split(
model: MODEL_OR_CHAIN_FACTORY,
) -> Tuple[Optional[BaseLanguageModel], Optional[Callable[[], Chain]], Optional[Any]]:
if isinstance(model, BaseLanguageModel):
return model, None, None
elif is_chain_factory(model):
return None, model, None # type: ignore
else:
return None, None, model
class InputFormatError(Exception):
"""Raised when the input format is invalid."""
class IllegalModelError(Exception):
"""Raised when the model is not allowed to run."""
def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
"""
Get prompts from inputs.
@@ -186,9 +219,11 @@ async def _arun_llm(
async def _arun_llm_or_chain(
example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
*,
llm: Optional[BaseLanguageModel] = None,
chain_factory: Optional[Callable[[], Chain]] = None,
model: Optional[Callable[[Dict], Awaitable[Dict]]] = None,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
@@ -197,8 +232,10 @@ async def _arun_llm_or_chain(
Args:
example: The example to run.
llm_or_chain_factory: The Chain or language model constructor to run.
n_repetitions: The number of times to run the model on each example.
llm: The language model to run.
chain_factory: The chain factory to run.
model: The model to run.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
@@ -217,19 +254,28 @@ async def _arun_llm_or_chain(
outputs = []
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain_factory, BaseLanguageModel):
if llm is not None:
output: Any = await _arun_llm(
llm_or_chain_factory,
llm,
example.inputs,
tags=tags,
callbacks=callbacks,
)
else:
chain = llm_or_chain_factory()
elif model is not None:
output = await model(example.inputs)
elif chain_factory is not None:
chain = chain_factory()
output = await chain.acall(
example.inputs, callbacks=callbacks, tags=tags
)
else:
raise IllegalModelError(
"Must specify either llm, chain_factory, or "
"a coroutine that accepts an inputs dictionary"
)
outputs.append(output)
except IllegalModelError:
raise
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
@@ -365,6 +411,7 @@ async def arun_on_examples(
evaluation_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [], client=client_
)
llm, chain_factory, model = get_model_split(llm_or_chain_factory)
async def process_example(
example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
@@ -372,8 +419,10 @@ async def arun_on_examples(
"""Process a single example."""
result = await _arun_llm_or_chain(
example,
llm_or_chain_factory,
num_repetitions,
llm=llm,
chain_factory=chain_factory,
model=model,
tags=tags,
callbacks=callbacks,
)
@@ -449,9 +498,11 @@ def run_llm(
def run_llm_or_chain(
example: Example,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
*,
llm: Optional[BaseLanguageModel] = None,
chain_factory: Optional[Callable[[], Chain]] = None,
model: Optional[Callable[[dict], dict]] = None,
tags: Optional[List[str]] = None,
callbacks: Optional[List[BaseCallbackHandler]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
@@ -460,8 +511,10 @@ def run_llm_or_chain(
Args:
example: The example to run.
llm_or_chain_factory: The Chain or language model constructor to run.
n_repetitions: The number of times to run the model on each example.
llm: The language model to run.
chain_factory: The Chain constructor to run.
model: The model to run.
tags: Optional tags to add to the run.
callbacks: Optional callbacks to use during the run.
@@ -480,14 +533,20 @@ def run_llm_or_chain(
outputs = []
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = run_llm(
llm_or_chain_factory, example.inputs, callbacks, tags=tags
)
else:
chain = llm_or_chain_factory()
if llm is not None:
output: Any = run_llm(llm, example.inputs, callbacks, tags=tags)
elif model is not None:
output = model(example.inputs)
elif chain_factory is not None:
chain = chain_factory()
output = chain(example.inputs, callbacks=callbacks, tags=tags)
else:
raise IllegalModelError(
"Either llm, chain_factory or model must be provided."
)
outputs.append(output)
except IllegalModelError as e:
raise e
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
@@ -540,12 +599,15 @@ def run_on_examples(
evalution_handler = EvaluatorCallbackHandler(
evaluators=run_evaluators or [], client=client_
)
llm, chain_factory, model = get_model_split(llm_or_chain_factory)
callbacks: List[BaseCallbackHandler] = [tracer, evalution_handler]
for i, example in enumerate(examples):
result = run_llm_or_chain(
example,
llm_or_chain_factory,
num_repetitions,
llm=llm,
chain_factory=chain_factory,
model=model,
tags=tags,
callbacks=callbacks,
)
@@ -578,8 +640,10 @@ def _get_project_name(
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if isinstance(llm_or_chain_factory, BaseLanguageModel):
model_name = llm_or_chain_factory.__class__.__name__
elif is_chain_factory(llm_or_chain_factory):
model_name = llm_or_chain_factory().__class__.__name__ # type: ignore[call-arg]
else:
model_name = llm_or_chain_factory().__class__.__name__
model_name = ""
dataset_prefix = f"{dataset_name}-" if dataset_name else ""
return f"{dataset_prefix}{model_name}-{current_time}"

View File

@@ -1,7 +1,7 @@
"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from unittest import mock
import pytest
@@ -17,6 +17,7 @@ from langchain.client.runner_utils import (
arun_on_dataset,
run_llm,
)
from tests.unit_tests.chains.test_base import FakeChain
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -104,9 +105,9 @@ def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
run_llm(llm, inputs, mock.MagicMock())
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = Dataset(
@pytest.fixture
def dataset() -> Dataset:
return Dataset(
id=uuid.uuid4(),
name="test",
description="Test dataset",
@@ -114,13 +115,21 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
created_at=_CREATED_AT,
tenant_id=_TENANT_ID,
)
uuids = [
@pytest.fixture
def uuids() -> List[str]:
return [
"0c193153-2309-4704-9a47-17aee4fb25c8",
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
]
@pytest.fixture
def examples(uuids: List[str]) -> List[Example]:
examples = [
Example(
id=uuids[0],
@@ -158,24 +167,35 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
dataset_id=str(uuid.uuid4()),
),
]
return examples
_RUN_OBJECTS = [
FakeLLM(
queres={str(i): f"Result for input {i}" for i in range(10)},
sequential_responses=True,
),
FakeChatModel(),
lambda: FakeChain(the_input_keys=["input"], the_output_keys=["output"]),
lambda input_: {"result": f"Result for input {input_}"},
]
@pytest.mark.asyncio
@pytest.mark.parametrize("model", _RUN_OBJECTS)
async def test_arun_on_dataset(
monkeypatch: pytest.MonkeyPatch,
examples: List[Example],
dataset: Dataset,
uuids: List[str],
model: Union[BaseLanguageModel, Chain, Callable[[dict], dict]],
) -> None:
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
return dataset
def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
return examples
async def mock_arun_chain(
example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
tags: Optional[List[str]] = None,
callbacks: Optional[Any] = None,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
def mock_create_project(*args: Any, **kwargs: Any) -> None:
pass
@@ -183,28 +203,28 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
LangChainPlusClient, "read_dataset", new=mock_read_dataset
), mock.patch.object(
LangChainPlusClient, "list_examples", new=mock_list_examples
), mock.patch(
"langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
), mock.patch.object(
LangChainPlusClient, "create_project", new=mock_create_project
), mock.patch(
"langchain.client.runner_utils.LangChainTracer", mock.MagicMock
):
client = LangChainPlusClient(api_url="http://localhost:1984", api_key="123")
chain = mock.MagicMock()
num_repetitions = 3
results = await arun_on_dataset(
dataset_name="test",
llm_or_chain_factory=lambda: chain,
llm_or_chain_factory=model,
concurrency_level=2,
project_name="test_project",
num_repetitions=num_repetitions,
client=client,
)
expected = {
uuid_: [
{"result": f"Result for example {uuid.UUID(uuid_)}"}
for _ in range(num_repetitions)
]
for uuid_ in uuids
}
assert results["results"] == expected
assert "results" in results
assert results["results"]
# expected = {
# uuid_: [
# {"result": f"Result for example {uuid.UUID(uuid_)}"}
# for _ in range(num_repetitions)
# ]
# for uuid_ in uuids
# }
# assert results["results"] == expected

View File

@@ -1,9 +1,13 @@
"""Fake LLM wrapper for testing purposes."""
import asyncio
from typing import Any, List, Mapping, Optional, cast
from pydantic import validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
@@ -50,6 +54,24 @@ class FakeLLM(LLM):
else:
return "bar"
async def _acall(
self,
prompt: str,
stop: List[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> str:
await asyncio.sleep(0)
if self.sequential_responses:
return self._get_next_response_in_sequence
if self.queries is not None:
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}