Harrison/version 0019 (#172)

This commit is contained in:
Harrison Chase 2022-11-22 06:51:51 -08:00 committed by GitHub
parent d3a7429f61
commit d70b5a2cbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 17 deletions

View File

@ -1 +1 @@
0.0.18
0.0.19

View File

@ -1,6 +1,7 @@
"""Experiment with different models."""
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union
from langchain.agents.agent import Agent
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
@ -11,29 +12,32 @@ from langchain.prompts.prompt import PromptTemplate
class ModelLaboratory:
"""Experiment with different models."""
def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
def __init__(
self, chains: Sequence[Union[Chain, Agent]], names: Optional[List[str]] = None
):
"""Initialize with chains to experiment with.
Args:
chains: list of chains to experiment with.
"""
if not isinstance(chains[0], Chain):
raise ValueError(
"ModelLaboratory should now be initialized with Chains. "
"If you want to initialize with LLMs, use the `from_llms` method "
"instead (`ModelLaboratory.from_llms(...)`)"
)
for chain in chains:
if len(chain.input_keys) != 1:
if not isinstance(chain, (Chain, Agent)):
raise ValueError(
"Currently only support chains with one input variable, "
f"got {chain.input_keys}"
)
if len(chain.output_keys) != 1:
raise ValueError(
"Currently only support chains with one output variable, "
f"got {chain.output_keys}"
"ModelLaboratory should now be initialized with Chains or Agents. "
"If you want to initialize with LLMs, use the `from_llms` method "
"instead (`ModelLaboratory.from_llms(...)`)"
)
if isinstance(chain, Chain):
if len(chain.input_keys) != 1:
raise ValueError(
"Currently only support chains with one input variable, "
f"got {chain.input_keys}"
)
if len(chain.output_keys) != 1:
raise ValueError(
"Currently only support chains with one output variable, "
f"got {chain.output_keys}"
)
if names is not None:
if len(names) != len(chains):
raise ValueError("Length of chains does not match length of names.")