mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 02:58:06 +00:00
Allow replicate prompt key to be manually specified (#10516)
Since inference logic doesn't work for all models Co-authored-by: Taqi Jaffri <tjaffri@gmail.com> Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in:
commit
eaf916f999
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
@ -33,6 +33,7 @@ class Replicate(LLM):
|
||||
input: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
replicate_api_token: Optional[str] = None
|
||||
prompt_key: Optional[str] = None
|
||||
|
||||
streaming: bool = Field(default=False)
|
||||
"""Whether to stream the results."""
|
||||
@ -81,7 +82,7 @@ class Replicate(LLM):
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model": self.model,
|
||||
@ -114,15 +115,18 @@ class Replicate(LLM):
|
||||
model = replicate_python.models.get(model_str)
|
||||
version = model.versions.get(version_str)
|
||||
|
||||
# sort through the openapi schema to get the name of the first input
|
||||
input_properties = sorted(
|
||||
version.openapi_schema["components"]["schemas"]["Input"][
|
||||
"properties"
|
||||
].items(),
|
||||
key=lambda item: item[1].get("x-order", 0),
|
||||
)
|
||||
first_input_name = input_properties[0][0]
|
||||
inputs = {first_input_name: prompt, **self.input}
|
||||
if not self.prompt_key:
|
||||
# sort through the openapi schema to get the name of the first input
|
||||
input_properties = sorted(
|
||||
version.openapi_schema["components"]["schemas"]["Input"][
|
||||
"properties"
|
||||
].items(),
|
||||
key=lambda item: item[1].get("x-order", 0),
|
||||
)
|
||||
|
||||
self.prompt_key = input_properties[0][0]
|
||||
|
||||
inputs: Dict = {self.prompt_key: prompt, **self.input}
|
||||
|
||||
prediction = replicate_python.predictions.create(
|
||||
version=version, input={**inputs, **kwargs}
|
||||
|
Loading…
Reference in New Issue
Block a user