mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
community: update Replicate to work with official models (#20633)
Description: you don't need to pass a version for Replicate official models. That was broken on LangChain until now! You can now run: ``` llm = Replicate( model="meta/meta-llama-3-8b-instruct", model_kwargs={"temperature": 0.75, "max_length": 500, "top_p": 1}, ) prompt = """ User: Answer the following yes/no question by reasoning step by step. Can a dog drive a car? Assistant: """ llm(prompt) ``` I've updated the replicate.ipynb to reflect that. twitter: @charliebholtz --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -44,7 +44,7 @@ class Replicate(LLM):
|
||||
replicate_api_token: Optional[str] = None
|
||||
prompt_key: Optional[str] = None
|
||||
version_obj: Any = Field(default=None, exclude=True)
|
||||
"""Optionally pass in the model version object during initialization to avoid
|
||||
"""Optionally pass in the model version object during initialization to avoid
|
||||
having to make an extra API call to retrieve it during streaming. NOTE: not
|
||||
serializable, is excluded from serialization.
|
||||
"""
|
||||
@@ -197,9 +197,13 @@ class Replicate(LLM):
|
||||
|
||||
# get the model and version
|
||||
if self.version_obj is None:
|
||||
model_str, version_str = self.model.split(":")
|
||||
model = replicate_python.models.get(model_str)
|
||||
self.version_obj = model.versions.get(version_str)
|
||||
if ":" in self.model:
|
||||
model_str, version_str = self.model.split(":")
|
||||
model = replicate_python.models.get(model_str)
|
||||
self.version_obj = model.versions.get(version_str)
|
||||
else:
|
||||
model = replicate_python.models.get(self.model)
|
||||
self.version_obj = model.latest_version
|
||||
|
||||
if self.prompt_key is None:
|
||||
# sort through the openapi schema to get the name of the first input
|
||||
@@ -217,6 +221,11 @@ class Replicate(LLM):
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
return replicate_python.predictions.create(
|
||||
version=self.version_obj, input=input_
|
||||
)
|
||||
|
||||
# if it's an official model
|
||||
if ":" not in self.model:
|
||||
return replicate_python.models.predictions.create(self.model, input=input_)
|
||||
else:
|
||||
return replicate_python.predictions.create(
|
||||
version=self.version_obj, input=input_
|
||||
)
|
||||
|
Reference in New Issue
Block a user