diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index a454b40c9bb..22840dcb5b2 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -164,7 +164,16 @@ class BaseOpenAI(BaseLLM, BaseModel):
         for i, prompt in enumerate(prompts):
             sub_choices = choices[i * self.n : (i + 1) * self.n]
             generations.append(
-                [Generation(text=choice["text"]) for choice in sub_choices]
+                [
+                    Generation(
+                        text=choice["text"],
+                        generation_info=dict(
+                            finish_reason=choice["finish_reason"],
+                            logprobs=choice["logprobs"],
+                        ),
+                    )
+                    for choice in sub_choices
+                ]
             )
         return LLMResult(
             generations=generations, llm_output={"token_usage": token_usage}
diff --git a/langchain/schema.py b/langchain/schema.py
index a4b4e6267ce..6bb53eb5443 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -1,6 +1,6 @@
 """Common schema objects."""
 
-from typing import List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional
 
 
 class AgentAction(NamedTuple):
@@ -23,6 +23,10 @@ class Generation(NamedTuple):
 
     text: str
     """Generated text output."""
+
+    generation_info: Optional[Dict[str, Any]] = None
+    """Raw generation info response from the provider.
+    May include things like reason for finishing (e.g. in OpenAI)."""
     # TODO: add log probs
 
 
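A minimal usage sketch (not part of the patch) of how the new `generation_info` field surfaces to callers. The fake `choices` payload below simply mirrors the OpenAI completion-response shape consumed in `openai.py` above, so no API key or network call is needed; it assumes `LLMResult` also lives in `langchain.schema`.

```python
# Sketch only: reproduces the list comprehension from the patched
# BaseOpenAI._generate against a hand-written `choices` payload.
from langchain.schema import Generation, LLMResult

# Two fake completion choices in the OpenAI response shape.
choices = [
    {"text": "Paris is the capital of France.", "finish_reason": "stop", "logprobs": None},
    {"text": "The capital of France is", "finish_reason": "length", "logprobs": None},
]

result = LLMResult(
    generations=[
        [
            Generation(
                text=choice["text"],
                generation_info=dict(
                    finish_reason=choice["finish_reason"],
                    logprobs=choice["logprobs"],
                ),
            )
            for choice in choices
        ]
    ],
    llm_output={"token_usage": {}},
)

# Callers can now inspect why each generation stopped.
for generation in result.generations[0]:
    info = generation.generation_info or {}
    print(repr(generation.text), "->", info.get("finish_reason"))
# 'Paris is the capital of France.' -> stop
# 'The capital of France is' -> length
```

Since `generation_info` defaults to `None`, existing callers that construct `Generation(text=...)` directly keep working unchanged.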