From aff44d0a98f2aefd38390c8b6cfea1712a75c4a8 Mon Sep 17 00:00:00 2001
From: Mario Kostelac
Date: Fri, 17 Mar 2023 05:55:55 +0100
Subject: [PATCH] (OpenAI) Add model_name to LLMResult.llm_output (#1713)

Given that different models have very different latencies and pricing,
it's beneficial to pass along the information about which model generated
the response. Such information allows implementing custom callback
managers that track usage and price per model.

Addresses https://github.com/hwchase17/langchain/issues/1557.
---
 langchain/llms/openai.py                    | 17 ++++++++++++-----
 tests/integration_tests/llms/test_openai.py |  8 ++++++++
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index adc0c55edd1..cce51bc443b 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -362,9 +362,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
                 for choice in sub_choices
             ]
         )
-        return LLMResult(
-            generations=generations, llm_output={"token_usage": token_usage}
-        )
+        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
         """Call OpenAI with streaming flag and return the resulting generator.
@@ -643,11 +642,15 @@ class OpenAIChat(BaseLLM, BaseModel):
             )
         else:
             full_response = completion_with_retry(self, messages=messages, **params)
+        llm_output = {
+            "token_usage": full_response["usage"],
+            "model_name": self.model_name,
+        }
         return LLMResult(
             generations=[
                 [Generation(text=full_response["choices"][0]["message"]["content"])]
             ],
-            llm_output={"token_usage": full_response["usage"]},
+            llm_output=llm_output,
         )
 
     async def _agenerate(
@@ -679,11 +682,15 @@ class OpenAIChat(BaseLLM, BaseModel):
             full_response = await acompletion_with_retry(
                 self, messages=messages, **params
             )
+        llm_output = {
+            "token_usage": full_response["usage"],
+            "model_name": self.model_name,
+        }
         return LLMResult(
             generations=[
                 [Generation(text=full_response["choices"][0]["message"]["content"])]
             ],
-            llm_output={"token_usage": full_response["usage"]},
+            llm_output=llm_output,
         )
 
     @property
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index 4f80565e1ef..1ada0ca6094 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -35,6 +35,14 @@ def test_openai_extra_kwargs() -> None:
         OpenAI(foo=3, model_kwargs={"foo": 2})
 
 
+def test_openai_llm_output_contains_model_name() -> None:
+    """Test llm_output contains model_name."""
+    llm = OpenAI(max_tokens=10)
+    llm_result = llm.generate(["Hello, how are you?"])
+    assert llm_result.llm_output is not None
+    assert llm_result.llm_output["model_name"] == llm.model_name
+
+
 def test_openai_stop_valid() -> None:
     """Test openai stop logic on valid configuration."""
     query = "write an ordered list of five items"
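
Note (not part of the patch): a minimal sketch of how a caller might consume the new
llm_output fields to track price per model, as the commit message describes. It only
assumes the keys this patch adds ("model_name" and the OpenAI "token_usage" counters);
the estimate_cost helper, the PRICE_PER_1K_TOKENS table, and the price value itself are
hypothetical placeholders, not part of langchain.

    from langchain.llms import OpenAI
    from langchain.schema import LLMResult

    # Hypothetical per-1K-token prices; replace with real, current pricing.
    PRICE_PER_1K_TOKENS = {
        "text-davinci-003": 0.02,
    }


    def estimate_cost(result: LLMResult) -> float:
        """Estimate the cost of a generate() call from its llm_output metadata."""
        assert result.llm_output is not None
        model_name = result.llm_output["model_name"]    # added by this patch
        token_usage = result.llm_output["token_usage"]  # OpenAI usage counters
        total_tokens = token_usage.get("total_tokens", 0)
        return PRICE_PER_1K_TOKENS.get(model_name, 0.0) * total_tokens / 1000


    llm = OpenAI(max_tokens=10)
    result = llm.generate(["Hello, how are you?"])
    print(result.llm_output["model_name"], estimate_cost(result))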