Allow replicate prompt key to be manually specified (#10516)

Since inference logic doesn't work for all models

Co-authored-by: Taqi Jaffri <tjaffri@gmail.com>
Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Commit eaf916f999, authored by Bagatur on 2023-09-12 15:52:13 -07:00, committed via GitHub.
(No known key found for this signature in database; GPG Key ID: 4AEE18F83AFDEB23.)

View File

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -33,6 +33,7 @@ class Replicate(LLM):
     input: Dict[str, Any] = Field(default_factory=dict)
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     replicate_api_token: Optional[str] = None
+    prompt_key: Optional[str] = None

     streaming: bool = Field(default=False)
     """Whether to stream the results."""
@@ -81,7 +82,7 @@ class Replicate(LLM):
         return values

     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
@@ -114,6 +115,7 @@ class Replicate(LLM):
         model = replicate_python.models.get(model_str)
         version = model.versions.get(version_str)

+        if not self.prompt_key:
             # sort through the openapi schema to get the name of the first input
             input_properties = sorted(
                 version.openapi_schema["components"]["schemas"]["Input"][
@@ -121,8 +123,10 @@ class Replicate(LLM):
                 ].items(),
                 key=lambda item: item[1].get("x-order", 0),
             )
-        first_input_name = input_properties[0][0]
-        inputs = {first_input_name: prompt, **self.input}
+            self.prompt_key = input_properties[0][0]
+
+        inputs: Dict = {self.prompt_key: prompt, **self.input}

         prediction = replicate_python.predictions.create(
             version=version, input={**inputs, **kwargs}