mirror of https://github.com/hwchase17/langchain.git
refactor
@@ -120,15 +120,21 @@ class BaseLLM(BaseLanguageModel, ABC):
         llm_outputs = []
         for prompt, gens in zip(prompts, result.generations):
             try:
-                llm_output = {
+                token_usage = {
                     "completion_tokens": self.get_num_tokens(
                         "".join([gen.text for gen in gens])
                     ),
                     "prompt_tokens": self.get_num_tokens(prompt),
                 }
-                llm_output["total_tokens"] = (
-                    llm_output["completion_tokens"] + llm_output["prompt_tokens"]
+                token_usage["total_tokens"] = (
+                    token_usage["completion_tokens"] + token_usage["prompt_tokens"]
                 )
+                llm_output = {
+                    "token_usage": token_usage,
+                }
+                if result.llm_output and result.llm_output["model_name"]:
+                    llm_output["model_name"] = result.llm_output["model_name"]
+
             except ImportError:
                 llm_output = None
             llm_outputs.append(llm_output)
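For context, a minimal standalone sketch of the per-prompt structure this hunk now builds: token counts are grouped under a "token_usage" key instead of sitting at the top level of llm_output. The whitespace-based get_num_tokens below is only a stand-in for the tokenizer-backed method on the class, and the prompt/completion values are illustrative.

def get_num_tokens(text: str) -> int:
    # Stand-in tokenizer: one token per whitespace-separated word.
    return len(text.split())

prompt = "What is the capital of France?"
completion = "The capital of France is Paris."

token_usage = {
    "completion_tokens": get_num_tokens(completion),
    "prompt_tokens": get_num_tokens(prompt),
}
token_usage["total_tokens"] = (
    token_usage["completion_tokens"] + token_usage["prompt_tokens"]
)
llm_output = {"token_usage": token_usage}
print(llm_output)
# {'token_usage': {'completion_tokens': 6, 'prompt_tokens': 6, 'total_tokens': 12}}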
@@ -176,6 +182,32 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)

+    def _generate_helper(self, prompts, stop, run_managers, new_arg_supported):
+        try:
+            output = (
+                self._generate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                )
+                if new_arg_supported
+                else self._generate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
+            raise e
+        flattened_outputs = (
+            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
+        )
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     def generate(
         self,
         prompts: List[str],
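The new_arg_supported flag that _generate_helper receives is the result of checking whether a subclass's _generate declares a run_manager parameter (the inspect.signature lookup on "run_manager" is visible as context in the hunks below). A self-contained sketch of that dispatch pattern, assuming nothing beyond the standard library; the two classes and call_generate are illustrative, not part of this commit:

import inspect


class OldStyleLLM:
    # Hypothetical subclass written before the run_manager kwarg existed.
    def _generate(self, prompts, stop=None):
        return [p.upper() for p in prompts]


class NewStyleLLM:
    # Hypothetical subclass that accepts the newer run_manager kwarg.
    def _generate(self, prompts, stop=None, run_manager=None):
        return [p.upper() for p in prompts]


def call_generate(llm, prompts, stop=None, run_manager=None):
    # Only pass run_manager through if the subclass's _generate declares it.
    new_arg_supported = inspect.signature(llm._generate).parameters.get("run_manager")
    if new_arg_supported:
        return llm._generate(prompts, stop=stop, run_manager=run_manager)
    return llm._generate(prompts, stop=stop)


print(call_generate(OldStyleLLM(), ["hello"]))  # old signature still works
print(call_generate(NewStyleLLM(), ["hello"], run_manager=object()))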
@@ -183,8 +215,6 @@ class BaseLLM(BaseLanguageModel, ABC):
         callbacks: Callbacks = None,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        # If string is passed in directly no errors will be raised but outputs will
-        # not make sense.
         if not isinstance(prompts, list):
             raise ValueError(
                 "Argument 'prompts' is expected to be of type List[str], received"
@@ -206,7 +236,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
         if langchain.llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
@@ -214,27 +243,9 @@ class BaseLLM(BaseLanguageModel, ABC):
             run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
-            try:
-                output = (
-                    self._generate(
-                        prompts,
-                        stop=stop,
-                        run_manager=run_managers[0] if run_managers else None,
-                    )
-                    if new_arg_supported
-                    else self._generate(prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
-                raise e
-            flattened_outputs = self._flatten_llm_result(prompts, output)
-            for manager, flattened_output in zip(run_managers, flattened_outputs):
-                manager.on_llm_end(flattened_output)
-            if run_managers:
-                output.run = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            output = self._generate_helper(
+                prompts, stop, run_managers, new_arg_supported
+            )
             return output
         if len(missing_prompts) > 0:
             run_managers = callback_manager.on_llm_start(
@@ -242,37 +253,56 @@ class BaseLLM(BaseLanguageModel, ABC):
                 missing_prompts,
                 invocation_params=params,
             )
-            try:
-                new_results = (
-                    self._generate(
-                        missing_prompts,
-                        stop=stop,
-                        run_manager=run_managers[0] if run_managers else None,
-                    )
-                    if new_arg_supported
-                    else self._generate(missing_prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
-                raise e
-            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
-            for manager, flattened_output in zip(run_managers, flattened_outputs):
-                manager.on_llm_end(flattened_output)
+            new_results = self._generate_helper(
+                missing_prompts, stop, run_managers, new_arg_supported
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_managers:
-                run_info = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

+    async def _agenerate_helper(self, prompts, stop, run_managers, new_arg_supported):
+        try:
+            output = (
+                await self._agenerate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                )
+                if new_arg_supported
+                else await self._agenerate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
+            raise e
+        flattened_outputs = (
+            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
+        )
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     async def agenerate(
         self,
         prompts: List[str],
@@ -296,7 +326,6 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
         if langchain.llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
@@ -304,32 +333,9 @@ class BaseLLM(BaseLanguageModel, ABC):
             run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
-            try:
-                output = (
-                    await self._agenerate(
-                        prompts, stop=stop, run_manager=run_managers[0]
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
-                )
-                raise e
-            flattened_outputs = self._flatten_llm_result(prompts, output)
-            await asyncio.gather(
-                *[
-                    run_manager.on_llm_end(flattened_output)
-                    for run_manager, flattened_output in zip(
-                        run_managers, flattened_outputs
-                    )
-                ]
+            output = await self._agenerate_helper(
+                prompts, stop, run_managers, new_arg_supported
             )
-            if run_managers:
-                output.run = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
             return output
         if len(missing_prompts) > 0:
             run_managers = await callback_manager.on_llm_start(
@@ -337,36 +343,17 @@ class BaseLLM(BaseLanguageModel, ABC):
                 missing_prompts,
                 invocation_params=params,
             )
-            try:
-                new_results = (
-                    await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_managers[0]
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(missing_prompts, stop=stop)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
-                )
-                raise e
-            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
-            await asyncio.gather(
-                *[
-                    run_manager.on_llm_end(flattened_output)
-                    for run_manager, flattened_output in zip(
-                        run_managers, flattened_outputs
-                    )
-                ]
+            new_results = await self._agenerate_helper(
+                missing_prompts, stop, run_managers, new_arg_supported
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_managers:
-                run_info = [
-                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
-                ]
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
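Taken together, the refactor routes every path (cached and uncached, sync and async) through a shared helper whose control flow is: try the underlying call, report a failure to every run manager before re-raising, otherwise flatten the combined result into one piece per prompt and fire on_llm_end once per manager. A rough standalone sketch of that flow with stub objects; StubRunManager and generate_helper here are illustrative stand-ins, not langchain's real callback manager or LLMResult types.

from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class StubRunManager:
    # Minimal stand-in for a per-prompt callback run manager.
    run_id: int
    events: List[str] = field(default_factory=list)

    def on_llm_error(self, error: BaseException) -> None:
        self.events.append(f"error: {error!r}")

    def on_llm_end(self, result) -> None:
        self.events.append(f"end: {result}")


def generate_helper(
    generate: Callable[[List[str]], List[str]],
    prompts: List[str],
    run_managers: List[StubRunManager],
) -> List[str]:
    try:
        result = generate(prompts)  # one combined result covering all prompts
    except (KeyboardInterrupt, Exception) as error:
        for run_manager in run_managers:  # every manager hears about the failure
            run_manager.on_llm_error(error)
        raise
    # Split the combined result into one piece per prompt (the role played by
    # _flatten_llm_result), then close out each run individually.
    flattened = [[gen] for gen in result] if len(prompts) > 1 else [result]
    for run_manager, flat in zip(run_managers, flattened):
        run_manager.on_llm_end(flat)
    return result


managers = [StubRunManager(run_id=i) for i in range(2)]
generate_helper(lambda ps: [p.upper() for p in ps], ["foo", "bar"], managers)
print([m.events for m in managers])
# [["end: ['FOO']"], ["end: ['BAR']"]]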