mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
add apply functionality (#150)
This commit is contained in:
parent
47e35d7d0e
commit
d775ddd749
@ -49,6 +49,10 @@ class Chain(BaseModel, ABC):
|
||||
self._validate_outputs(outputs)
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs) for inputs in input_list]
|
||||
|
||||
def run(self, text: str) -> str:
|
||||
"""Run text in, text out (if applicable)."""
|
||||
if len(self.input_keys) != 1:
|
||||
|
@ -60,13 +60,12 @@ class MapReduceChain(Chain, BaseModel):
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
# Split the larger text into smaller chunks.
|
||||
docs = self.text_splitter.split_text(inputs[self.input_key])
|
||||
|
||||
# Now that we have the chunks, we send them to the LLM and track results.
|
||||
# This is the "map" part.
|
||||
summaries = []
|
||||
for d in docs:
|
||||
inputs = {self.map_llm.prompt.input_variables[0]: d}
|
||||
res = self.map_llm.predict(**inputs)
|
||||
summaries.append(res)
|
||||
input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
|
||||
summary_results = self.map_llm.apply(input_list)
|
||||
summaries = [res[self.map_llm.output_key] for res in summary_results]
|
||||
|
||||
# We then need to combine these individual parts into one.
|
||||
# This is the reduce part.
|
||||
|
Loading…
Reference in New Issue
Block a user