commit 70f0f337df (parent aea44eb246)
Author: Ankush Gola
Date: 2023-06-07 23:36:36 -07:00

@@ -120,15 +120,21 @@ class BaseLLM(BaseLanguageModel, ABC):
         llm_outputs = []
         for prompt, gens in zip(prompts, result.generations):
             try:
-                llm_output = {
+                token_usage = {
                     "completion_tokens": self.get_num_tokens(
                         "".join([gen.text for gen in gens])
                     ),
                     "prompt_tokens": self.get_num_tokens(prompt),
                 }
-                llm_output["total_tokens"] = (
-                    llm_output["completion_tokens"] + llm_output["prompt_tokens"]
+                token_usage["total_tokens"] = (
+                    token_usage["completion_tokens"] + token_usage["prompt_tokens"]
                 )
+                llm_output = {
+                    "token_usage": token_usage,
+                }
+                if result.llm_output and result.llm_output["model_name"]:
+                    llm_output["model_name"] = result.llm_output["model_name"]
             except ImportError:
                 llm_output = None
             llm_outputs.append(llm_output)
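Note: the refactor nests the counts under a `token_usage` key so `llm_output` matches the shape providers such as OpenAI report. A minimal standalone sketch of the same accounting; `count_tokens` is a hypothetical stand-in for `self.get_num_tokens`:

    # Standalone sketch of the token accounting above; count_tokens is a
    # hypothetical stand-in for self.get_num_tokens (e.g. a tiktoken count).
    def count_tokens(text: str) -> int:
        return len(text.split())  # crude whitespace proxy, illustration only

    prompt = "Tell me a joke"
    completions = ["Why did the chicken cross the road?"]

    token_usage = {
        "completion_tokens": count_tokens("".join(completions)),
        "prompt_tokens": count_tokens(prompt),
    }
    token_usage["total_tokens"] = (
        token_usage["completion_tokens"] + token_usage["prompt_tokens"]
    )
    llm_output = {"token_usage": token_usage}
    print(llm_output)
    # {'token_usage': {'completion_tokens': 7, 'prompt_tokens': 4, 'total_tokens': 11}}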
@@ -176,6 +182,32 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
 
+    def _generate_helper(self, prompts, stop, run_managers, new_arg_supported):
+        try:
+            output = (
+                self._generate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                )
+                if new_arg_supported
+                else self._generate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
+            raise e
+        flattened_outputs = (
+            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
+        )
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     def generate(
         self,
         prompts: List[str],
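Note: the `new_arg_supported` flag passed into `_generate_helper` reflects whether the subclass's `_generate` accepts a `run_manager` keyword; the caller detects this with `inspect.signature`, visible as context in a later hunk. A small sketch of that detection, with hypothetical function names:

    import inspect

    def _legacy_generate(prompts, stop=None):
        """Pre-callbacks signature: no run_manager parameter."""

    def _modern_generate(prompts, stop=None, run_manager=None):
        """Post-callbacks signature: accepts run_manager."""

    def supports_run_manager(fn) -> bool:
        return inspect.signature(fn).parameters.get("run_manager") is not None

    print(supports_run_manager(_legacy_generate))  # False
    print(supports_run_manager(_modern_generate))  # True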
@@ -183,8 +215,6 @@ class BaseLLM(BaseLanguageModel, ABC):
         callbacks: Callbacks = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        # If string is passed in directly no errors will be raised but outputs will
-        # not make sense.
         if not isinstance(prompts, list):
             raise ValueError(
                 "Argument 'prompts' is expected to be of type List[str], received"
@@ -206,7 +236,6 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
@@ -214,27 +243,9 @@ class BaseLLM(BaseLanguageModel, ABC):
             run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
-            try:
-                output = (
-                    self._generate(
-                        prompts,
-                        stop=stop,
-                        run_manager=run_managers[0] if run_managers else None,
-                    )
-                    if new_arg_supported
-                    else self._generate(prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
-                raise e
-            flattened_outputs = self._flatten_llm_result(prompts, output)
-            for manager, flattened_output in zip(run_managers, flattened_outputs):
-                manager.on_llm_end(flattened_output)
-            if run_managers:
-                output.run = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            output = self._generate_helper(
+                prompts, stop, run_managers, new_arg_supported
+            )
             return output
         if len(missing_prompts) > 0:
             run_managers = callback_manager.on_llm_start(
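Note: `_generate_helper` now flattens a multi-prompt `LLMResult` only when there is more than one prompt, so each per-prompt run manager receives just its own generations in `on_llm_end`. A hypothetical, simplified sketch of that flattening; `SimpleResult` is not LangChain's `LLMResult`:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class SimpleResult:
        generations: List[List[str]]  # one inner list of generations per prompt
        llm_output: Optional[dict] = None

    def flatten(result: SimpleResult) -> List[SimpleResult]:
        # Each per-prompt result keeps the shared llm_output metadata.
        return [
            SimpleResult(generations=[gens], llm_output=result.llm_output)
            for gens in result.generations
        ]

    batched = SimpleResult(generations=[["hi"], ["hello"]], llm_output={"model_name": "m"})
    for per_prompt in flatten(batched):
        print(per_prompt.generations)  # [['hi']] then [['hello']]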
@@ -242,37 +253,56 @@ class BaseLLM(BaseLanguageModel, ABC):
                 missing_prompts,
                 invocation_params=params,
             )
-            try:
-                new_results = (
-                    self._generate(
-                        missing_prompts,
-                        stop=stop,
-                        run_manager=run_managers[0] if run_managers else None,
-                    )
-                    if new_arg_supported
-                    else self._generate(missing_prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
-                raise e
-            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
-            for manager, flattened_output in zip(run_managers, flattened_outputs):
-                manager.on_llm_end(flattened_output)
+            new_results = self._generate_helper(
+                missing_prompts, stop, run_managers, new_arg_supported
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_managers:
-                run_info = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
+    async def _agenerate_helper(self, prompts, stop, run_managers, new_arg_supported):
+        try:
+            output = (
+                await self._agenerate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                )
+                if new_arg_supported
+                else await self._agenerate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
+            raise e
+        flattened_outputs = (
+            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
+        )
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     async def agenerate(
         self,
         prompts: List[str],
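Note: `_agenerate_helper` mirrors the sync helper but fans callback notifications out concurrently with `asyncio.gather`. A runnable sketch of the same error-handling pattern with hypothetical handlers:

    import asyncio

    class Handler:
        """Hypothetical stand-in for a per-run callback manager."""

        def __init__(self, name: str) -> None:
            self.name = name

        async def on_llm_error(self, error: BaseException) -> None:
            print(f"{self.name} saw {error!r}")

    async def main() -> None:
        handlers = [Handler("run-1"), Handler("run-2")]
        try:
            raise RuntimeError("provider timed out")
        except (KeyboardInterrupt, Exception) as e:
            # Notify every handler concurrently, as the helper does, then
            # continue instead of re-raising so the demo exits cleanly.
            await asyncio.gather(*[h.on_llm_error(e) for h in handlers])

    asyncio.run(main())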
@@ -296,7 +326,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
         if langchain.llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
@@ -304,32 +333,9 @@ class BaseLLM(BaseLanguageModel, ABC):
             run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
-            try:
-                output = (
-                    await self._agenerate(
-                        prompts, stop=stop, run_manager=run_managers[0]
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
-                )
-                raise e
-            flattened_outputs = self._flatten_llm_result(prompts, output)
-            await asyncio.gather(
-                *[
-                    run_manager.on_llm_end(flattened_output)
-                    for run_manager, flattened_output in zip(
-                        run_managers, flattened_outputs
-                    )
-                ]
-            )
-            if run_managers:
-                output.run = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            output = await self._agenerate_helper(
+                prompts, stop, run_managers, new_arg_supported
+            )
             return output
         if len(missing_prompts) > 0:
             run_managers = await callback_manager.on_llm_start(
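Note: after this change `agenerate` delegates to the helper, and the returned result carries one `RunInfo` per started run. A hypothetical sketch of the caller-facing shape; `FakeLLM`, `FakeResult`, and `FakeRunInfo` are stand-ins, not LangChain classes:

    import asyncio

    class FakeRunInfo:
        def __init__(self, run_id: int) -> None:
            self.run_id = run_id

    class FakeResult:
        def __init__(self, generations, run) -> None:
            self.generations = generations  # one list of generations per prompt
            self.run = run                  # one FakeRunInfo per prompt

    class FakeLLM:
        async def agenerate(self, prompts):
            return FakeResult(
                generations=[[f"echo: {p}"] for p in prompts],
                run=[FakeRunInfo(run_id=i) for i, _ in enumerate(prompts)],
            )

    async def main() -> None:
        result = await FakeLLM().agenerate(["Tell me a joke", "Tell me a fact"])
        for gens in result.generations:
            print(gens[0])
        print([info.run_id for info in result.run])

    asyncio.run(main())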
@@ -337,36 +343,17 @@ class BaseLLM(BaseLanguageModel, ABC):
                 missing_prompts,
                 invocation_params=params,
             )
-            try:
-                new_results = (
-                    await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_managers[0]
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(missing_prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
-                )
-                raise e
-            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
-            await asyncio.gather(
-                *[
-                    run_manager.on_llm_end(flattened_output)
-                    for run_manager, flattened_output in zip(
-                        run_managers, flattened_outputs
-                    )
-                ]
-            )
+            new_results = await self._agenerate_helper(
+                missing_prompts, stop, run_managers, new_arg_supported
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_managers:
-                run_info = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None