Compare commits


1 Commit

Author           SHA1        Message                       Date
Harrison Chase   ede12b3a43  add option to return prompt   2023-05-16 20:24:19 -07:00
4 changed files with 35 additions and 15 deletions


@@ -36,7 +36,7 @@ class _ResponseChain(LLMChain):
         *,
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Tuple[Sequence[str], Sequence[float]]:
-        llm_result = self.generate([_input], run_manager=run_manager)
+        _, llm_result = self.generate([_input], run_manager=run_manager)
         return self._extract_tokens_and_log_probs(llm_result.generations[0])

     @abstractmethod


@@ -53,7 +53,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     def embed_query(self, text: str) -> List[float]:
         """Generate a hypothetical document and embedded it."""
         var_name = self.llm_chain.input_keys[0]
-        result = self.llm_chain.generate([{var_name: text}])
+        _, result = self.llm_chain.generate([{var_name: text}])
         documents = [generation.text for generation in result.generations[0]]
         embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)


@@ -38,6 +38,7 @@ class LLMChain(Chain):
     """Prompt object to use."""
     llm: BaseLanguageModel
     output_key: str = "text"  #: :meta private:
+    return_formatted_prompt: bool = False

     class Config:
         """Configuration for this pydantic object."""
@@ -59,37 +60,45 @@ class LLMChain(Chain):
         :meta private:
         """
-        return [self.output_key]
+        if self.return_formatted_prompt:
+            return [self.output_key, "formatted_prompt"]
+        else:
+            return [self.output_key]

     def _call(
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
-        response = self.generate([inputs], run_manager=run_manager)
-        return self.create_outputs(response)[0]
+        prompts, response = self.generate([inputs], run_manager=run_manager)
+        output = self.create_outputs(response)[0]
+        if self.return_formatted_prompt:
+            output["formatted_prompt"] = prompts[0]
+        return output

     def generate(
         self,
         input_list: List[Dict[str, Any]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> LLMResult:
+    ) -> Tuple[List[PromptValue], LLMResult]:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        return self.llm.generate_prompt(
+        result = self.llm.generate_prompt(
             prompts, stop, callbacks=run_manager.get_child() if run_manager else None
         )
+        return prompts, result

     async def agenerate(
         self,
         input_list: List[Dict[str, Any]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> LLMResult:
+    ) -> Tuple[List[PromptValue], LLMResult]:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
-        return await self.llm.agenerate_prompt(
+        result = await self.llm.agenerate_prompt(
             prompts, stop, callbacks=run_manager.get_child() if run_manager else None
         )
+        return prompts, result

     def prep_prompts(
         self,
@@ -151,12 +160,15 @@ class LLMChain(Chain):
             {"input_list": input_list},
         )
         try:
-            response = self.generate(input_list, run_manager=run_manager)
+            prompts, response = self.generate(input_list, run_manager=run_manager)
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_chain_error(e)
             raise e
         outputs = self.create_outputs(response)
         run_manager.on_chain_end({"outputs": outputs})
+        if self.return_formatted_prompt:
+            for i, o in enumerate(outputs):
+                o["formatted_prompt"] = prompts[i]
         return outputs

     async def aapply(
@@ -171,15 +183,20 @@ class LLMChain(Chain):
             {"input_list": input_list},
         )
         try:
-            response = await self.agenerate(input_list, run_manager=run_manager)
+            prompts, response = await self.agenerate(
+                input_list, run_manager=run_manager
+            )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_chain_error(e)
             raise e
         outputs = self.create_outputs(response)
         await run_manager.on_chain_end({"outputs": outputs})
+        if self.return_formatted_prompt:
+            for i, o in enumerate(outputs):
+                o["formatted_prompt"] = prompts[i]
         return outputs

-    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
+    def create_outputs(self, response: LLMResult) -> List[Dict[str, Any]]:
         """Create outputs from response."""
         return [
             # Get the text of the top generated string.
@@ -192,8 +209,11 @@ class LLMChain(Chain):
         inputs: Dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
-        response = await self.agenerate([inputs], run_manager=run_manager)
-        return self.create_outputs(response)[0]
+        prompts, response = await self.agenerate([inputs], run_manager=run_manager)
+        output = self.create_outputs(response)[0]
+        if self.return_formatted_prompt:
+            output["formatted_prompt"] = prompts[0]
+        return output

     def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
         """Format prompt with kwargs and pass to LLM.


@@ -52,7 +52,7 @@ class QAGenerationChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, List]:
         docs = self.text_splitter.create_documents([inputs[self.input_key]])
-        results = self.llm_chain.generate(
+        _, results = self.llm_chain.generate(
            [{"text": d.page_content} for d in docs], run_manager=run_manager
        )
         qa = [json.loads(res[0].text) for res in results.generations]
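
For context, a minimal usage sketch of the option this commit adds. It is not part of the diff: it assumes a LangChain build that includes this change, and the prompt template, model, and input values below are placeholders chosen for illustration.

from langchain import LLMChain, OpenAI, PromptTemplate

# Hypothetical example; the prompt and model are placeholders, not part of the commit.
prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(
    llm=OpenAI(temperature=0),
    prompt=prompt,
    return_formatted_prompt=True,  # option added in this commit
)

result = chain({"product": "colorful socks"})
# With the flag set, output_keys becomes ["text", "formatted_prompt"], so the
# result carries both the completion and the PromptValue sent to the LLM.
print(result["text"])
print(result["formatted_prompt"].to_string())

With the flag left at its default of False, output_keys and the returned dicts are unchanged, so existing callers of _call, _acall, apply, and aapply keep their previous behavior.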