add apply functionality (#150)

This commit is contained in:
Harrison Chase 2022-11-16 21:39:02 -08:00 committed by GitHub
parent 47e35d7d0e
commit d775ddd749
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 5 deletions

View File

@ -49,6 +49,10 @@ class Chain(BaseModel, ABC):
self._validate_outputs(outputs) self._validate_outputs(outputs)
return {**inputs, **outputs} return {**inputs, **outputs}
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]
def run(self, text: str) -> str: def run(self, text: str) -> str:
"""Run text in, text out (if applicable).""" """Run text in, text out (if applicable)."""
if len(self.input_keys) != 1: if len(self.input_keys) != 1:

View File

@ -60,13 +60,12 @@ class MapReduceChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
# Split the larger text into smaller chunks. # Split the larger text into smaller chunks.
docs = self.text_splitter.split_text(inputs[self.input_key]) docs = self.text_splitter.split_text(inputs[self.input_key])
# Now that we have the chunks, we send them to the LLM and track results. # Now that we have the chunks, we send them to the LLM and track results.
# This is the "map" part. # This is the "map" part.
summaries = [] input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
for d in docs: summary_results = self.map_llm.apply(input_list)
inputs = {self.map_llm.prompt.input_variables[0]: d} summaries = [res[self.map_llm.output_key] for res in summary_results]
res = self.map_llm.predict(**inputs)
summaries.append(res)
# We then need to combine these individual parts into one. # We then need to combine these individual parts into one.
# This is the reduce part. # This is the reduce part.