diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index c30799816c2..d09c4eff925 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -120,15 +120,21 @@ class BaseLLM(BaseLanguageModel, ABC):
         llm_outputs = []
         for prompt, gens in zip(prompts, result.generations):
             try:
-                llm_output = {
+                token_usage = {
                     "completion_tokens": self.get_num_tokens(
                         "".join([gen.text for gen in gens])
                     ),
                     "prompt_tokens": self.get_num_tokens(prompt),
                 }
-                llm_output["total_tokens"] = (
-                    llm_output["completion_tokens"] + llm_output["prompt_tokens"]
+                token_usage["total_tokens"] = (
+                    token_usage["completion_tokens"] + token_usage["prompt_tokens"]
                 )
+                llm_output = {
+                    "token_usage": token_usage,
+                }
+                if result.llm_output and result.llm_output["model_name"]:
+                    llm_output["model_name"] = result.llm_output["model_name"]
+
             except ImportError:
                 llm_output = None
             llm_outputs.append(llm_output)
@@ -176,6 +182,32 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
 
+    def _generate_helper(self, prompts, stop, run_managers, new_arg_supported):
+        try:
+            output = (
+                self._generate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                )
+                if new_arg_supported
+                else self._generate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
+            raise e
+        flattened_outputs = (
+            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
+        )
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     def generate(
         self,
         prompts: List[str],
@@ -183,8 +215,6 @@ class BaseLLM(BaseLanguageModel, ABC):
         callbacks: Callbacks = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        # If string is passed in directly no errors will be raised but outputs will
-        # not make sense.
         if not isinstance(prompts, list):
             raise ValueError(
                 "Argument 'prompts' is expected to be of type List[str], received"
@@ -206,7 +236,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
        if langchain.llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
@@ -214,27 +243,9 @@ class BaseLLM(BaseLanguageModel, ABC):
             run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
-            try:
-                output = (
-                    self._generate(
-                        prompts,
-                        stop=stop,
-                        run_manager=run_managers[0] if run_managers else None,
-                    )
-                    if new_arg_supported
-                    else self._generate(prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
-                raise e
-            flattened_outputs = self._flatten_llm_result(prompts, output)
-            for manager, flattened_output in zip(run_managers, flattened_outputs):
-                manager.on_llm_end(flattened_output)
-            if run_managers:
-                output.run = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            output = self._generate_helper(
+                prompts, stop, run_managers, new_arg_supported
+            )
             return output
         if len(missing_prompts) > 0:
             run_managers = callback_manager.on_llm_start(
@@ -242,37 +253,56 @@ class BaseLLM(BaseLanguageModel, ABC):
                 missing_prompts,
                 invocation_params=params,
             )
-            try:
-                new_results = (
-                    self._generate(
-                        missing_prompts,
-                        stop=stop,
-                        run_manager=run_managers[0] if run_managers else None,
-                    )
-                    if new_arg_supported
-                    else self._generate(missing_prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
-                raise e
-            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
-            for manager, flattened_output in zip(run_managers, flattened_outputs):
-                manager.on_llm_end(flattened_output)
+            new_results = self._generate_helper(
+                missing_prompts, stop, run_managers, new_arg_supported
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_managers:
-                run_info = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
+    async def _agenerate_helper(self, prompts, stop, run_managers, new_arg_supported):
+        try:
+            output = (
+                await self._agenerate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                )
+                if new_arg_supported
+                else await self._agenerate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
+            raise e
+        flattened_outputs = (
+            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
+        )
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     async def agenerate(
         self,
         prompts: List[str],
@@ -296,7 +326,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
        if langchain.llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
@@ -304,32 +333,9 @@ class BaseLLM(BaseLanguageModel, ABC):
             run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
-            try:
-                output = (
-                    await self._agenerate(
-                        prompts, stop=stop, run_manager=run_managers[0]
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
-                )
-                raise e
-            flattened_outputs = self._flatten_llm_result(prompts, output)
-            await asyncio.gather(
-                *[
-                    run_manager.on_llm_end(flattened_output)
-                    for run_manager, flattened_output in zip(
-                        run_managers, flattened_outputs
-                    )
-                ]
+            output = await self._agenerate_helper(
+                prompts, stop, run_managers, new_arg_supported
             )
-            if run_managers:
-                output.run = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
             return output
         if len(missing_prompts) > 0:
             run_managers = await callback_manager.on_llm_start(
@@ -337,36 +343,17 @@ class BaseLLM(BaseLanguageModel, ABC):
                 missing_prompts,
                 invocation_params=params,
             )
-            try:
-                new_results = (
-                    await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_managers[0]
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(missing_prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
-                )
-                raise e
-            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
-            await asyncio.gather(
-                *[
-                    run_manager.on_llm_end(flattened_output)
-                    for run_manager, flattened_output in zip(
-                        run_managers, flattened_outputs
-                    )
-                ]
+            new_results = await self._agenerate_helper(
+                missing_prompts, stop, run_managers, new_arg_supported
             )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_managers:
-                run_info = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
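
Reviewer note: a minimal sketch of how the refactored path is exercised from the outside. The FakeLLM subclass below is illustrative only and not part of this diff; the point is that generate() now delegates to _generate_helper(), which flattens a multi-prompt result into one per-prompt LLMResult before firing on_llm_end for each run manager.

# Hypothetical illustration only; FakeLLM is not part of this change.
from typing import List, Optional

from langchain.llms.base import LLM


class FakeLLM(LLM):
    """Toy LLM that upper-cases its prompt, used only to exercise generate()."""

    @property
    def _llm_type(self) -> str:
        return "fake"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        return prompt.upper()


llm = FakeLLM()
# With more than one prompt, _generate_helper() flattens the batched output, so
# any attached callback handler receives a separate LLMResult per prompt in
# on_llm_end(); the combined result returned here still has one generation list
# per input prompt.
result = llm.generate(["hello", "world"])
print(result.generations)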