diff --git a/libs/experimental/langchain_experimental/smart_llm/base.py b/libs/experimental/langchain_experimental/smart_llm/base.py index 8301c5df534..d9d8929cb07 100644 --- a/libs/experimental/langchain_experimental/smart_llm/base.py +++ b/libs/experimental/langchain_experimental/smart_llm/base.py @@ -66,6 +66,7 @@ class SmartLLMChain(Chain): prompt: BasePromptTemplate """Prompt object to use.""" + output_key: str = "resolution" ideation_llm: Optional[BaseLanguageModel] = None """LLM to use in ideation step. If None given, 'llm' will be used.""" critique_llm: Optional[BaseLanguageModel] = None @@ -132,8 +133,8 @@ class SmartLLMChain(Chain): def output_keys(self) -> List[str]: """Defines the output keys.""" if self.return_intermediate_steps: - return ["ideas", "critique", "resolution"] - return ["resolution"] + return ["ideas", "critique", self.output_key] + return [self.output_key] def prep_prompts( self, @@ -169,8 +170,8 @@ class SmartLLMChain(Chain): self.history.critique = critique resolution = self._resolve(stop, run_manager) if self.return_intermediate_steps: - return {"ideas": ideas, "critique": critique, "resolution": resolution} - return {"resolution": resolution} + return {"ideas": ideas, "critique": critique, self.output_key: resolution} + return {self.output_key: resolution} def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str: """Between steps, only the LLM result text is passed, not the LLMResult object.