diff --git a/langchain/VERSION b/langchain/VERSION index 32786aa437b..44517d5188e 100644 --- a/langchain/VERSION +++ b/langchain/VERSION @@ -1 +1 @@ -0.0.18 +0.0.19 diff --git a/langchain/model_laboratory.py b/langchain/model_laboratory.py index 2bf30904878..614bba344f9 100644 --- a/langchain/model_laboratory.py +++ b/langchain/model_laboratory.py @@ -1,6 +1,7 @@ """Experiment with different models.""" -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Union +from langchain.agents.agent import Agent from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping, print_text @@ -11,29 +12,32 @@ from langchain.prompts.prompt import PromptTemplate class ModelLaboratory: """Experiment with different models.""" - def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None): + def __init__( + self, chains: Sequence[Union[Chain, Agent]], names: Optional[List[str]] = None + ): """Initialize with chains to experiment with. Args: chains: list of chains to experiment with. """ - if not isinstance(chains[0], Chain): - raise ValueError( - "ModelLaboratory should now be initialized with Chains. " - "If you want to initialize with LLMs, use the `from_llms` method " - "instead (`ModelLaboratory.from_llms(...)`)" - ) for chain in chains: - if len(chain.input_keys) != 1: + if not isinstance(chain, (Chain, Agent)): raise ValueError( - "Currently only support chains with one input variable, " - f"got {chain.input_keys}" - ) - if len(chain.output_keys) != 1: - raise ValueError( - "Currently only support chains with one output variable, " - f"got {chain.output_keys}" + "ModelLaboratory should now be initialized with Chains or Agents. " + "If you want to initialize with LLMs, use the `from_llms` method " + "instead (`ModelLaboratory.from_llms(...)`)" ) + if isinstance(chain, Chain): + if len(chain.input_keys) != 1: + raise ValueError( + "Currently only support chains with one input variable, " + f"got {chain.input_keys}" + ) + if len(chain.output_keys) != 1: + raise ValueError( + "Currently only support chains with one output variable, " + f"got {chain.output_keys}" + ) if names is not None: if len(names) != len(chains): raise ValueError("Length of chains does not match length of names.")