Rm RunTypeEnum (#8553)

We already support raw strings in the SDK but would like to deprecate
client-side validation of run types. This removes its usage
This commit is contained in:
William FH 2023-07-31 23:32:07 -07:00 committed by GitHub
parent 2a26cc6d2b
commit e83250cc5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 39 deletions

View File

@ -10,7 +10,7 @@ from uuid import UUID
from tenacity import RetryCallState from tenacity import RetryCallState
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum from langchain.callbacks.tracers.schemas import Run
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.schema.document import Document from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult from langchain.schema.output import ChatGeneration, LLMResult
@ -110,7 +110,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
start_time=start_time, start_time=start_time,
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
run_type=RunTypeEnum.llm, run_type="llm",
tags=tags or [], tags=tags or [],
) )
self._start_trace(llm_run) self._start_trace(llm_run)
@ -130,7 +130,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id) run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_) llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm: if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced") raise TracerException("No LLM Run found to be traced")
llm_run.events.append( llm_run.events.append(
{ {
@ -182,7 +182,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id) run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_) llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm: if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced") raise TracerException("No LLM Run found to be traced")
llm_run.outputs = response.dict() llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations): for i, generations in enumerate(response.generations):
@ -210,7 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id) run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_) llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm: if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced") raise TracerException("No LLM Run found to be traced")
llm_run.error = repr(error) llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow() llm_run.end_time = datetime.utcnow()
@ -246,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
child_runs=[], child_runs=[],
run_type=RunTypeEnum.chain, run_type="chain",
tags=tags or [], tags=tags or [],
) )
self._start_trace(chain_run) self._start_trace(chain_run)
@ -259,7 +259,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.") raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id)) chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain: if chain_run is None or chain_run.run_type != "chain":
raise TracerException("No chain Run found to be traced") raise TracerException("No chain Run found to be traced")
chain_run.outputs = outputs chain_run.outputs = outputs
@ -279,7 +279,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.") raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id)) chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain: if chain_run is None or chain_run.run_type != "chain":
raise TracerException("No chain Run found to be traced") raise TracerException("No chain Run found to be traced")
chain_run.error = repr(error) chain_run.error = repr(error)
@ -316,7 +316,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
child_runs=[], child_runs=[],
run_type=RunTypeEnum.tool, run_type="tool",
tags=tags or [], tags=tags or [],
) )
self._start_trace(tool_run) self._start_trace(tool_run)
@ -327,7 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.") raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id)) tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool: if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced") raise TracerException("No tool Run found to be traced")
tool_run.outputs = {"output": output} tool_run.outputs = {"output": output}
@ -347,7 +347,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.") raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id)) tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool: if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced") raise TracerException("No tool Run found to be traced")
tool_run.error = repr(error) tool_run.error = repr(error)
@ -386,7 +386,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order, child_execution_order=execution_order,
tags=tags, tags=tags,
child_runs=[], child_runs=[],
run_type=RunTypeEnum.retriever, run_type="retriever",
) )
self._start_trace(retrieval_run) self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run) self._on_retriever_start(retrieval_run)
@ -402,7 +402,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.") raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id)) retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever: if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced") raise TracerException("No retriever Run found to be traced")
retrieval_run.error = repr(error) retrieval_run.error = repr(error)
@ -418,7 +418,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.") raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id)) retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever: if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced") raise TracerException("No retriever Run found to be traced")
retrieval_run.outputs = {"documents": documents} retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow() retrieval_run.end_time = datetime.utcnow()

View File

@ -11,7 +11,7 @@ from uuid import UUID
from langsmith import Client from langsmith import Client
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession from langchain.callbacks.tracers.schemas import Run, TracerSession
from langchain.env import get_runtime_environment from langchain.env import get_runtime_environment
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.schema.messages import BaseMessage from langchain.schema.messages import BaseMessage
@ -107,7 +107,7 @@ class LangChainTracer(BaseTracer):
start_time=start_time, start_time=start_time,
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
run_type=RunTypeEnum.llm, run_type="llm",
tags=tags, tags=tags,
) )
self._start_trace(chat_model_run) self._start_trace(chat_model_run)

View File

@ -2,16 +2,27 @@
from __future__ import annotations from __future__ import annotations
import datetime import datetime
import warnings
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2 from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from pydantic import BaseModel, Field, root_validator from pydantic import BaseModel, Field, root_validator
from langchain.schema import LLMResult from langchain.schema import LLMResult
def RunTypeEnum() -> RunTypeEnumDep:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
" (e.g. 'llm', 'chain', 'tool').",
DeprecationWarning,
)
return RunTypeEnumDep
class TracerSessionV1Base(BaseModel): class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1.""" """Base class for TracerSessionV1."""

View File

@ -15,7 +15,7 @@ from typing import (
) )
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum from langchain.callbacks.tracers.schemas import Run
if TYPE_CHECKING: if TYPE_CHECKING:
from wandb import Settings as WBSettings from wandb import Settings as WBSettings
@ -154,11 +154,11 @@ class RunProcessor:
:param run: The LangChain Run to convert. :param run: The LangChain Run to convert.
:return: The converted W&B Trace Span. :return: The converted W&B Trace Span.
""" """
if run.run_type == RunTypeEnum.llm: if run.run_type == "llm":
return self._convert_llm_run_to_wb_span(run) return self._convert_llm_run_to_wb_span(run)
elif run.run_type == RunTypeEnum.chain: elif run.run_type == "chain":
return self._convert_chain_run_to_wb_span(run) return self._convert_chain_run_to_wb_span(run)
elif run.run_type == RunTypeEnum.tool: elif run.run_type == "tool":
return self._convert_tool_run_to_wb_span(run) return self._convert_tool_run_to_wb_span(run)
else: else:
return self._convert_run_to_wb_span(run) return self._convert_run_to_wb_span(run)

View File

@ -22,7 +22,7 @@ from typing import (
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
from langsmith import Client, RunEvaluator from langsmith import Client, RunEvaluator
from langsmith.schemas import Dataset, DataType, Example, RunTypeEnum from langsmith.schemas import Dataset, DataType, Example
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
@ -341,9 +341,9 @@ def _setup_evaluation(
first_example, examples = _first_example(examples) first_example, examples = _first_example(examples)
if isinstance(llm_or_chain_factory, BaseLanguageModel): if isinstance(llm_or_chain_factory, BaseLanguageModel):
run_inputs, run_outputs = None, None run_inputs, run_outputs = None, None
run_type = RunTypeEnum.llm run_type = "llm"
else: else:
run_type = RunTypeEnum.chain run_type = "chain"
if data_type in (DataType.chat, DataType.llm): if data_type in (DataType.chat, DataType.llm):
raise ValueError( raise ValueError(
"Cannot evaluate a chain on dataset with " "Cannot evaluate a chain on dataset with "
@ -370,13 +370,13 @@ def _setup_evaluation(
def _determine_input_key( def _determine_input_key(
config: RunEvalConfig, config: RunEvalConfig,
run_inputs: Optional[List[str]], run_inputs: Optional[List[str]],
run_type: RunTypeEnum, run_type: str,
) -> Optional[str]: ) -> Optional[str]:
if config.input_key: if config.input_key:
input_key = config.input_key input_key = config.input_key
if run_inputs and input_key not in run_inputs: if run_inputs and input_key not in run_inputs:
raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}") raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
elif run_type == RunTypeEnum.llm: elif run_type == "llm":
input_key = None input_key = None
elif run_inputs and len(run_inputs) == 1: elif run_inputs and len(run_inputs) == 1:
input_key = run_inputs[0] input_key = run_inputs[0]
@ -391,7 +391,7 @@ def _determine_input_key(
def _determine_prediction_key( def _determine_prediction_key(
config: RunEvalConfig, config: RunEvalConfig,
run_outputs: Optional[List[str]], run_outputs: Optional[List[str]],
run_type: RunTypeEnum, run_type: str,
) -> Optional[str]: ) -> Optional[str]:
if config.prediction_key: if config.prediction_key:
prediction_key = config.prediction_key prediction_key = config.prediction_key
@ -399,7 +399,7 @@ def _determine_prediction_key(
raise ValueError( raise ValueError(
f"Prediction key {prediction_key} not in run outputs {run_outputs}" f"Prediction key {prediction_key} not in run outputs {run_outputs}"
) )
elif run_type == RunTypeEnum.llm: elif run_type == "llm":
prediction_key = None prediction_key = None
elif run_outputs and len(run_outputs) == 1: elif run_outputs and len(run_outputs) == 1:
prediction_key = run_outputs[0] prediction_key = run_outputs[0]
@ -432,7 +432,7 @@ def _determine_reference_key(
def _construct_run_evaluator( def _construct_run_evaluator(
eval_config: Union[EvaluatorType, EvalConfig], eval_config: Union[EvaluatorType, EvalConfig],
eval_llm: BaseLanguageModel, eval_llm: BaseLanguageModel,
run_type: RunTypeEnum, run_type: str,
data_type: DataType, data_type: DataType,
example_outputs: Optional[List[str]], example_outputs: Optional[List[str]],
reference_key: Optional[str], reference_key: Optional[str],
@ -472,7 +472,7 @@ def _construct_run_evaluator(
def _load_run_evaluators( def _load_run_evaluators(
config: RunEvalConfig, config: RunEvalConfig,
run_type: RunTypeEnum, run_type: str,
data_type: DataType, data_type: DataType,
example_outputs: Optional[List[str]], example_outputs: Optional[List[str]],
run_inputs: Optional[List[str]], run_inputs: Optional[List[str]],

View File

@ -5,7 +5,7 @@ from abc import abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langsmith import EvaluationResult, RunEvaluator from langsmith import EvaluationResult, RunEvaluator
from langsmith.schemas import DataType, Example, Run, RunTypeEnum from langsmith.schemas import DataType, Example, Run
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -327,7 +327,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
def from_run_and_data_type( def from_run_and_data_type(
cls, cls,
evaluator: StringEvaluator, evaluator: StringEvaluator,
run_type: RunTypeEnum, run_type: str,
data_type: DataType, data_type: DataType,
input_key: Optional[str] = None, input_key: Optional[str] = None,
prediction_key: Optional[str] = None, prediction_key: Optional[str] = None,
@ -343,7 +343,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
Args: Args:
evaluator (StringEvaluator): The string evaluator to use. evaluator (StringEvaluator): The string evaluator to use.
run_type (RunTypeEnum): The type of run being evaluated. run_type (str): The type of run being evaluated.
Supported types are LLM and Chain. Supported types are LLM and Chain.
data_type (DataType): The type of dataset used in the run. data_type (DataType): The type of dataset used in the run.
input_key (str, optional): The key used to map the input from the run. input_key (str, optional): The key used to map the input from the run.
@ -361,9 +361,9 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
""" # noqa: E501 """ # noqa: E501
# Configure how run inputs/predictions are passed to the evaluator # Configure how run inputs/predictions are passed to the evaluator
if run_type == RunTypeEnum.llm: if run_type == "llm":
run_mapper: StringRunMapper = LLMStringRunMapper() run_mapper: StringRunMapper = LLMStringRunMapper()
elif run_type == RunTypeEnum.chain: elif run_type == "chain":
run_mapper = ChainStringRunMapper( run_mapper = ChainStringRunMapper(
input_key=input_key, prediction_key=prediction_key input_key=input_key, prediction_key=prediction_key
) )

View File

@ -17,7 +17,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
ToolRun, ToolRun,
TracerSessionV1, TracerSessionV1,
) )
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base
from langchain.schema import LLMResult from langchain.schema import LLMResult
from langchain.schema.messages import HumanMessage from langchain.schema.messages import HumanMessage
@ -589,7 +589,7 @@ def test_convert_run(
outputs=LLMResult(generations=[[]]).dict(), outputs=LLMResult(generations=[[]]).dict(),
serialized={}, serialized={},
extra={}, extra={},
run_type=RunTypeEnum.llm, run_type="llm",
) )
chain_run = Run( chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad", id="57a08cc4-73d2-4236-8371-549099d07fad",
@ -603,7 +603,7 @@ def test_convert_run(
outputs={}, outputs={},
child_runs=[llm_run], child_runs=[llm_run],
extra={}, extra={},
run_type=RunTypeEnum.chain, run_type="chain",
) )
tool_run = Run( tool_run = Run(
@ -618,7 +618,7 @@ def test_convert_run(
serialized={}, serialized={},
child_runs=[], child_runs=[],
extra={}, extra={},
run_type=RunTypeEnum.tool, run_type="tool",
) )
expected_llm_run = LLMRun( expected_llm_run = LLMRun(