Mirror of https://github.com/hwchase17/langchain.git, synced 2025-11-04 02:03:32 +00:00.
			
		
		
		
	OpenAI LLM: update modelname_to_contextsize with new models (#2843)
				
					
				
			Token counts pulled from https://openai.com/pricing
This commit is contained in:
		@@ -476,13 +476,6 @@ class BaseOpenAI(BaseLLM):
 | 
				
			|||||||
    def modelname_to_contextsize(self, modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size (in tokens) for the given model.

        Raises:
            ValueError: If ``modelname`` is not a known OpenAI model name.

        Example:
            .. code-block:: python

                max_tokens = openai.modelname_to_contextsize("text-davinci-003")
        """
        # Token counts pulled from https://openai.com/pricing
        model_token_mapping = {
            "gpt-4": 8192,
            "gpt-4-0314": 8192,
            "gpt-4-32k": 32768,
            "gpt-4-32k-0314": 32768,
            "gpt-3.5-turbo": 4096,
            "gpt-3.5-turbo-0301": 4096,
            "text-ada-001": 2049,
            "ada": 2049,
            # Fixed: was 2040 (typo) — all first-generation GPT-3 models
            # share a 2,049-token context window per OpenAI's docs.
            "text-babbage-001": 2049,
            "babbage": 2049,
            "text-curie-001": 2049,
            "curie": 2049,
            "davinci": 2049,
            "text-davinci-003": 4097,
            "text-davinci-002": 4097,
            "code-davinci-002": 8001,
            "code-davinci-001": 8001,
            "code-cushman-002": 2048,
            "code-cushman-001": 2048,
        }

        context_size = model_token_mapping.get(modelname, None)

        if context_size is None:
            # Fixed: original implicit string concatenation produced
            # "…model name.Known models are:" with no separating space.
            raise ValueError(
                f"Unknown model: {modelname}. Please provide a valid OpenAI "
                "model name. Known models are: "
                + ", ".join(model_token_mapping.keys())
            )

        return context_size
    def max_tokens_for_prompt(self, prompt: str) -> int:
 | 
					    def max_tokens_for_prompt(self, prompt: str) -> int:
 | 
				
			||||||
        """Calculate the maximum number of tokens possible to generate for a prompt.
 | 
					        """Calculate the maximum number of tokens possible to generate for a prompt.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -211,3 +211,14 @@ async def test_openai_chat_async_streaming_callback() -> None:
 | 
				
			|||||||
    result = await llm.agenerate(["Write me a sentence with 100 words."])
 | 
					    result = await llm.agenerate(["Write me a sentence with 100 words."])
 | 
				
			||||||
    assert callback_handler.llm_streams != 0
 | 
					    assert callback_handler.llm_streams != 0
 | 
				
			||||||
    assert isinstance(result, LLMResult)
 | 
					    assert isinstance(result, LLMResult)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_openai_modelname_to_contextsize_valid() -> None:
 | 
				
			||||||
 | 
					    """Test model name to context size on a valid model."""
 | 
				
			||||||
 | 
					    assert OpenAI().modelname_to_contextsize("davinci") == 2049
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_openai_modelname_to_contextsize_invalid() -> None:
 | 
				
			||||||
 | 
					    """Test model name to context size on an invalid model."""
 | 
				
			||||||
 | 
					    with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					        OpenAI().modelname_to_contextsize("foobar")
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user