diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index 15827f6577c..53befcff134 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -49,6 +49,10 @@ class Chain(BaseModel, ABC):
         self._validate_outputs(outputs)
         return {**inputs, **outputs}
 
+    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+        """Call the chain on all inputs in the list."""
+        return [self(inputs) for inputs in input_list]
+
     def run(self, text: str) -> str:
         """Run text in, text out (if applicable)."""
         if len(self.input_keys) != 1:
diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py
index 8286e49cca1..623b95cdfcd 100644
--- a/langchain/chains/mapreduce.py
+++ b/langchain/chains/mapreduce.py
@@ -60,13 +60,12 @@ class MapReduceChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         # Split the larger text into smaller chunks.
         docs = self.text_splitter.split_text(inputs[self.input_key])
+        # Now that we have the chunks, we send them to the LLM and track results.
         # This is the "map" part.
-        summaries = []
-        for d in docs:
-            inputs = {self.map_llm.prompt.input_variables[0]: d}
-            res = self.map_llm.predict(**inputs)
-            summaries.append(res)
+        input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
+        summary_results = self.map_llm.apply(input_list)
+        summaries = [res[self.map_llm.output_key] for res in summary_results]
         # We then need to combine these individual parts into one.
         # This is the reduce part.