Convert Chain to a Chain Factory (#4605)
## Change Chain argument in client to accept a chain factory

The `run_on_dataset` functionality seeks to treat each iteration over an example as an independent trial. Chains can have memory, so it is easier to permit this kind of behavior if we accept a factory method rather than the chain object directly.

There are still corner cases and UX pains people will likely run into, such as:
- Caching may cause issues.
- If memory is persisted to a shared object (e.g., the same Redis queue), this could impact what is retrieved.
- If we run the async methods with concurrency using local models, and someone naively instantiates and loads the chain each time, it could lead to excessive disk I/O or OOM.
Parent: ed0d557ede
Commit: 0c6ed657ef
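For orientation, a minimal sketch of the new calling convention against the updated client. The dataset name, tool set, and client import path below are illustrative assumptions, not taken from this commit; the agent construction mirrors the notebook changed in this diff.

```python
# Sketch only: assumes an OpenAI-backed agent and a dataset that already exists
# on the LangChainPlus server; the client import path is an assumption here.
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.chat_models import ChatOpenAI
from langchain.client import LangChainPlusClient

llm = ChatOpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)  # assumed tool set

# Before: the chain/agent object itself was passed as `llm_or_chain`.
# After: a zero-argument factory is passed as `llm_or_chain_factory`, so every
# example (and every repetition) gets a fresh chain with fresh memory.
chain_factory = lambda: initialize_agent(
    tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
)

client = LangChainPlusClient()  # reads API settings from the environment
results = client.run_on_dataset(
    dataset_name="agent-search-calculator",  # hypothetical dataset name
    llm_or_chain_factory=chain_factory,
    num_repetitions=1,
    verbose=True,
)

# Bare language models are still accepted directly, since they are stateless:
# results = client.run_on_dataset(dataset_name="...", llm_or_chain_factory=llm)
```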
@@ -40,6 +40,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
+
 
 def _get_link_stem(url: str) -> str:
     scheme = urlsplit(url).scheme
@@ -99,6 +101,21 @@ class LangChainPlusClient(BaseSettings):
             raise ValueError("No seeded tenant found")
         return results[0]["id"]
 
+    @staticmethod
+    def _get_session_name(
+        session_name: Optional[str],
+        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+        dataset_name: str,
+    ) -> str:
+        if session_name is not None:
+            return session_name
+        current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+        if isinstance(llm_or_chain_factory, BaseLanguageModel):
+            model_name = llm_or_chain_factory.__class__.__name__
+        else:
+            model_name = llm_or_chain_factory().__class__.__name__
+        return f"{dataset_name}-{model_name}-{current_time}"
+
     def _repr_html_(self) -> str:
         """Return an HTML representation of the instance with a link to the URL."""
         link = _get_link_stem(self.api_url)
@@ -312,7 +329,7 @@ class LangChainPlusClient(BaseSettings):
     async def _arun_llm_or_chain(
         example: Example,
         langchain_tracer: LangChainTracerV2,
-        llm_or_chain: Union[Chain, BaseLanguageModel],
+        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
         n_repetitions: int,
     ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
         """Run the chain asynchronously."""
@@ -321,12 +338,13 @@ class LangChainPlusClient(BaseSettings):
         outputs = []
         for _ in range(n_repetitions):
             try:
-                if isinstance(llm_or_chain, BaseLanguageModel):
+                if isinstance(llm_or_chain_factory, BaseLanguageModel):
                     output: Any = await LangChainPlusClient._arun_llm(
-                        llm_or_chain, example.inputs, langchain_tracer
+                        llm_or_chain_factory, example.inputs, langchain_tracer
                     )
                 else:
-                    output = await llm_or_chain.arun(
+                    chain = llm_or_chain_factory()
+                    output = await chain.arun(
                         example.inputs, callbacks=[langchain_tracer]
                     )
                 outputs.append(output)
@@ -388,7 +406,8 @@ class LangChainPlusClient(BaseSettings):
     async def arun_on_dataset(
         self,
         dataset_name: str,
-        llm_or_chain: Union[Chain, BaseLanguageModel],
+        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
         *,
         concurrency_level: int = 5,
         num_repetitions: int = 1,
         session_name: Optional[str] = None,
@@ -399,7 +418,9 @@ class LangChainPlusClient(BaseSettings):
 
         Args:
             dataset_name: Name of the dataset to run the chain on.
-            llm_or_chain: Chain or language model to run over the dataset.
+            llm_or_chain_factory: Language model or Chain constructor to run
+                over the dataset. The Chain constructor is used to permit
+                independent calls on each example without carrying over state.
             concurrency_level: The number of async tasks to run concurrently.
             num_repetitions: Number of times to run the model on each example.
                 This is useful when testing success rates or generating confidence
@@ -411,11 +432,9 @@ class LangChainPlusClient(BaseSettings):
         Returns:
             A dictionary mapping example ids to the model outputs.
         """
-        if session_name is None:
-            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-            session_name = (
-                f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
-            )
+        session_name = LangChainPlusClient._get_session_name(
+            session_name, llm_or_chain_factory, dataset_name
+        )
         dataset = self.read_dataset(dataset_name=dataset_name)
         examples = self.list_examples(dataset_id=str(dataset.id))
         results: Dict[str, List[Any]] = {}
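As a quick illustration of the updated async entry point above (a sketch that assumes a `client`, a `chain_factory`, and a dataset name like the ones used in the notebook later in this diff):

```python
# Sketch: run each example up to 5 at a time, each against a freshly built chain.
chain_results = await client.arun_on_dataset(
    dataset_name=dataset_name,
    llm_or_chain_factory=chain_factory,  # zero-arg callable returning a fresh Chain
    concurrency_level=5,                 # number of concurrent async tasks
    num_repetitions=1,
    verbose=True,
)
```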
@@ -427,7 +446,7 @@ class LangChainPlusClient(BaseSettings):
             result = await LangChainPlusClient._arun_llm_or_chain(
                 example,
                 tracer,
-                llm_or_chain,
+                llm_or_chain_factory,
                 num_repetitions,
             )
             results[str(example.id)] = result
@@ -474,7 +493,7 @@ class LangChainPlusClient(BaseSettings):
     def run_llm_or_chain(
         example: Example,
         langchain_tracer: LangChainTracerV2,
-        llm_or_chain: Union[Chain, BaseLanguageModel],
+        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
         n_repetitions: int,
     ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
         """Run the chain synchronously."""
@@ -483,14 +502,13 @@ class LangChainPlusClient(BaseSettings):
         outputs = []
         for _ in range(n_repetitions):
             try:
-                if isinstance(llm_or_chain, BaseLanguageModel):
+                if isinstance(llm_or_chain_factory, BaseLanguageModel):
                     output: Any = LangChainPlusClient.run_llm(
-                        llm_or_chain, example.inputs, langchain_tracer
+                        llm_or_chain_factory, example.inputs, langchain_tracer
                     )
                 else:
-                    output = llm_or_chain.run(
-                        example.inputs, callbacks=[langchain_tracer]
-                    )
+                    chain = llm_or_chain_factory()
+                    output = chain.run(example.inputs, callbacks=[langchain_tracer])
                 outputs.append(output)
             except Exception as e:
                 logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@@ -502,7 +520,8 @@ class LangChainPlusClient(BaseSettings):
     def run_on_dataset(
         self,
         dataset_name: str,
-        llm_or_chain: Union[Chain, BaseLanguageModel],
+        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
         *,
         num_repetitions: int = 1,
         session_name: Optional[str] = None,
         verbose: bool = False,
@@ -511,7 +530,9 @@ class LangChainPlusClient(BaseSettings):
 
         Args:
             dataset_name: Name of the dataset to run the chain on.
-            llm_or_chain: Chain or language model to run over the dataset.
+            llm_or_chain_factory: Language model or Chain constructor to run
+                over the dataset. The Chain constructor is used to permit
+                independent calls on each example without carrying over state.
             concurrency_level: Number of async workers to run in parallel.
             num_repetitions: Number of times to run the model on each example.
                 This is useful when testing success rates or generating confidence
@@ -523,11 +544,9 @@ class LangChainPlusClient(BaseSettings):
         Returns:
             A dictionary mapping example ids to the model outputs.
         """
-        if session_name is None:
-            current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
-            session_name = (
-                f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
-            )
+        session_name = LangChainPlusClient._get_session_name(
+            session_name, llm_or_chain_factory, dataset_name
+        )
         dataset = self.read_dataset(dataset_name=dataset_name)
         examples = list(self.list_examples(dataset_id=str(dataset.id)))
         results: Dict[str, Any] = {}
@@ -539,7 +558,7 @@ class LangChainPlusClient(BaseSettings):
             result = self.run_llm_or_chain(
                 example,
                 tracer,
-                llm_or_chain,
+                llm_or_chain_factory,
                 num_repetitions,
             )
             if verbose:
@@ -133,21 +133,19 @@
      "output_type": "stream",
      "text": [
       "The current population of Canada as of 2023 is 39,566,248.\n",
-      "Anwar Hadid's age raised to the 0.43 power is approximately 3.87.\n",
+      "Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\n",
       "LLMMathChain._evaluate(\"\n",
       "(age)**0.43\n",
       "\") raised error: 'age'. Please try again with a valid numerical expression\n",
-      "The distance between Paris and Boston is 3448 miles.\n",
-      "unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n",
+      "The distance between Paris and Boston is approximately 3448 miles.\n",
+      "unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n",
       "LLMMathChain._evaluate(\"\n",
       "(total number of points scored in the 2023 super bowl)**0.23\n",
       "\") raised error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n",
-      "3 points were scored more in the 2023 Super Bowl than in the 2022 Super Bowl.\n",
+      "Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n",
       "1.9347796717823205\n",
-      "81\n",
-      "LLMMathChain._evaluate(\"\n",
-      "round(0.2791714614499425, 2)\n",
-      "\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n"
+      "77\n",
+      "0.2791714614499425\n"
      ]
     }
    ],
@@ -254,12 +252,109 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 7,
    "id": "60d14593-c61f-449f-a38f-772ca43707c2",
    "metadata": {
     "tags": []
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c34edde8de5340888b3278d1ac427417",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       " 0%| | 0/1 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       " .dataframe tbody tr th:only-of-type {\n",
+       " vertical-align: middle;\n",
+       " }\n",
+       "\n",
+       " .dataframe tbody tr th {\n",
+       " vertical-align: top;\n",
+       " }\n",
+       "\n",
+       " .dataframe thead th {\n",
+       " text-align: right;\n",
+       " }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       " <thead>\n",
+       " <tr style=\"text-align: right;\">\n",
+       " <th></th>\n",
+       " <th>input</th>\n",
+       " <th>output</th>\n",
+       " </tr>\n",
+       " </thead>\n",
+       " <tbody>\n",
+       " <tr>\n",
+       " <th>0</th>\n",
+       " <td>How many people live in canada as of 2023?</td>\n",
+       " <td>approximately 38,625,801</td>\n",
+       " </tr>\n",
+       " <tr>\n",
+       " <th>1</th>\n",
+       " <td>who is dua lipa's boyfriend? what is his age r...</td>\n",
+       " <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
+       " </tr>\n",
+       " <tr>\n",
+       " <th>2</th>\n",
+       " <td>what is dua lipa's boyfriend age raised to the...</td>\n",
+       " <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
+       " </tr>\n",
+       " <tr>\n",
+       " <th>3</th>\n",
+       " <td>how far is it from paris to boston in miles</td>\n",
+       " <td>approximately 3,435 mi</td>\n",
+       " </tr>\n",
+       " <tr>\n",
+       " <th>4</th>\n",
+       " <td>what was the total number of points scored in ...</td>\n",
+       " <td>approximately 2.682651500990882</td>\n",
+       " </tr>\n",
+       " </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       " input \\\n",
+       "0 How many people live in canada as of 2023? \n",
+       "1 who is dua lipa's boyfriend? what is his age r... \n",
+       "2 what is dua lipa's boyfriend age raised to the... \n",
+       "3 how far is it from paris to boston in miles \n",
+       "4 what was the total number of points scored in ... \n",
+       "\n",
+       " output \n",
+       "0 approximately 38,625,801 \n",
+       "1 her boyfriend is Romain Gravas. his age raised... \n",
+       "2 her boyfriend is Romain Gravas. his age raised... \n",
+       "3 approximately 3,435 mi \n",
+       "4 approximately 2.682651500990882 "
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
   "source": [
    "# import pandas as pd\n",
    "# from langchain.evaluation.loading import load_dataset\n",
@@ -272,7 +367,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "52a7ea76-79ca-4765-abf7-231e884040d6",
    "metadata": {
     "tags": []
@@ -308,7 +403,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "id": "c2b59104-b90e-466a-b7ea-c5bd0194263b",
    "metadata": {
     "tags": []
@@ -336,7 +431,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 10,
    "id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee",
    "metadata": {
     "tags": []
@@ -348,7 +443,8 @@
    "\u001b[0;31mSignature:\u001b[0m\n",
    "\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
    "\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
-   "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Union[Chain, BaseLanguageModel]'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
+   "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
    "\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
    "\u001b[0;34m\u001b[0m \u001b[0mconcurrency_level\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
    "\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
    "\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
@@ -359,7 +455,9 @@
    "\n",
    "Args:\n",
    " dataset_name: Name of the dataset to run the chain on.\n",
-   " llm_or_chain: Chain or language model to run over the dataset.\n",
+   " llm_or_chain_factory: Language model or Chain constructor to run\n",
+   " over the dataset. The Chain constructor is used to permit\n",
+   " independent calls on each example without carrying over state.\n",
    " concurrency_level: The number of async tasks to run concurrently.\n",
    " num_repetitions: Number of times to run the model on each example.\n",
    " This is useful when testing success rates or generating confidence\n",
@@ -384,7 +482,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 11,
+   "id": "6e10f823",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Since chains can be stateful (e.g. they can have memory), we need provide\n",
+    "# a way to initialize a new chain for each row in the dataset. This is done\n",
+    "# by passing in a factory function that returns a new chain for each row.\n",
+    "chain_factory = lambda: initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n",
+    "\n",
+    "# If your chain is NOT stateful, your lambda can return the object directly\n",
+    "# to improve runtime performance. For example:\n",
+    "# chain_factory = lambda: agent"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
    "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
    "metadata": {
     "tags": []
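The factory requirement in the notebook cell above exists because a single chain instance can carry state between examples. A rough sketch of the difference, using ConversationChain and ConversationBufferMemory purely for illustration (they are not part of this commit):

```python
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory

# Reusing one instance: memory written while evaluating example 1 is still
# present when example 2 runs, so the trials are no longer independent.
shared_chain = ConversationChain(llm=OpenAI(temperature=0), memory=ConversationBufferMemory())

# Factory: each call builds a new chain with empty memory, so every dataset
# row (and every repetition) starts from a clean slate.
chain_factory = lambda: ConversationChain(
    llm=OpenAI(temperature=0), memory=ConversationBufferMemory()
)
```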
@@ -396,7 +513,9 @@
     "text": [
      "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
      " warnings.warn(\n",
-     "Chain failed for example 92c75ce4-f807-4d44-8f7e-027610f7fcbd. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n"
+     "Chain failed for example 5523e460-6bb4-4a64-be37-bec0a98699a4. Error: LLMMathChain._evaluate(\"\n",
+     "(total number of points scored in the 2023 super bowl)**0.23\n",
+     "\") raised error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n"
     ]
    },
    {
@@ -410,25 +529,23 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "Chain failed for example 9f5d1426-3e21-4628-b5f9-d2ad354bfa8d. Error: LLMMathChain._evaluate(\"\n",
-     "(age ** 0.43)\n",
-     "\") raised error: 'age'. Please try again with a valid numerical expression\n"
+     "Chain failed for example f193a3f6-1147-4ce6-a83e-fab1157dc88d. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Processed examples: 4\r"
+     "Processed examples: 6\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "Chain failed for example e480f086-6d3f-4659-8669-26316db7e772. Error: LLMMathChain._evaluate(\"\n",
-     "(total number of points scored in the 2023 super bowl)**0.23\n",
-     "\") raised error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n"
+     "Chain failed for example 6d7bbb45-1dc0-4adc-be21-4f76a208a8d2. Error: LLMMathChain._evaluate(\"\n",
+     "(age ** 0.43)\n",
+     "\") raised error: 'age'. Please try again with a valid numerical expression\n"
    ]
   },
   {
@@ -442,7 +559,7 @@
   "source": [
    "chain_results = await client.arun_on_dataset(\n",
    " dataset_name=dataset_name,\n",
-   " llm_or_chain=agent,\n",
+   " llm_or_chain_factory=chain_factory,\n",
    " verbose=True\n",
    ")\n",
    "\n",
@@ -463,7 +580,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 13,
    "id": "136db492-d6ca-4215-96f9-439c23538241",
    "metadata": {
     "tags": []
@@ -478,7 +595,7 @@
       "LangChainPlusClient (API URL: http://localhost:8000)"
      ]
     },
-    "execution_count": 14,
+    "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -508,7 +625,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 14,
    "id": "64490d7c-9a18-49ed-a3ac-36049c522cb4",
    "metadata": {
     "tags": []
@@ -524,7 +641,7 @@
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "31a576ae98634602b349046ec0821c0d",
+      "model_id": "047a8094367f43938f74e863b3e01711",
       "version_major": 2,
       "version_minor": 0
      },
@@ -606,7 +723,7 @@
       "4 [{'data': {'content': 'Here is the topic for a... "
      ]
     },
-    "execution_count": 8,
+    "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -622,7 +739,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 15,
    "id": "348acd86-a927-4d60-8d52-02e64585e4fc",
    "metadata": {
     "tags": []
@@ -652,7 +769,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 16,
    "id": "a69dd183-ad5e-473d-b631-db90706e837f",
    "metadata": {
     "tags": []
@@ -691,7 +808,7 @@
   "source": [
    "chat_model_results = await client.arun_on_dataset(\n",
    " dataset_name=chat_dataset_name,\n",
-   " llm_or_chain=chat_model,\n",
+   " llm_or_chain_factory=chat_model,\n",
    " concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
    " num_repetitions=3,\n",
    " verbose=True\n",
@@ -936,7 +1053,7 @@
    "# We also offer a synchronous method for running examples if a chain or llm's async methods aren't yet implemented\n",
    "completions_model_results = client.run_on_dataset(\n",
    " dataset_name=completions_dataset_name,\n",
-   " llm_or_chain=llm,\n",
+   " llm_or_chain_factory=llm,\n",
    " num_repetitions=1,\n",
    " verbose=True\n",
    ")"
@@ -218,7 +218,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
     num_repetitions = 3
     results = await client.arun_on_dataset(
         dataset_name="test",
-        llm_or_chain=chain,
+        llm_or_chain_factory=lambda: chain,
         concurrency_level=2,
         session_name="test_session",
         num_repetitions=num_repetitions,