community: update Replicate to work with official models (#20633)

Description: you don't need to pass a version for Replicate official
models. That was broken on LangChain until now!

You can now run: 

```
llm = Replicate(
    model="meta/meta-llama-3-8b-instruct",
    model_kwargs={"temperature": 0.75, "max_length": 500, "top_p": 1},
)
prompt = """
User: Answer the following yes/no question by reasoning step by step. Can a dog drive a car?
Assistant:
"""
llm(prompt)
```

I've updated the replicate.ipynb to reflect that.

twitter: @charliebholtz

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Charlie Holtz
2024-04-18 18:43:40 -07:00
committed by GitHub
parent dd5139e304
commit 1cbab0ebda
2 changed files with 41 additions and 28 deletions

View File

@@ -44,7 +44,7 @@ class Replicate(LLM):
replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None
version_obj: Any = Field(default=None, exclude=True)
"""Optionally pass in the model version object during initialization to avoid
"""Optionally pass in the model version object during initialization to avoid
having to make an extra API call to retrieve it during streaming. NOTE: not
serializable, is excluded from serialization.
"""
@@ -197,9 +197,13 @@ class Replicate(LLM):
# get the model and version
if self.version_obj is None:
model_str, version_str = self.model.split(":")
model = replicate_python.models.get(model_str)
self.version_obj = model.versions.get(version_str)
if ":" in self.model:
model_str, version_str = self.model.split(":")
model = replicate_python.models.get(model_str)
self.version_obj = model.versions.get(version_str)
else:
model = replicate_python.models.get(self.model)
self.version_obj = model.latest_version
if self.prompt_key is None:
# sort through the openapi schema to get the name of the first input
@@ -217,6 +221,11 @@ class Replicate(LLM):
**self.model_kwargs,
**kwargs,
}
return replicate_python.predictions.create(
version=self.version_obj, input=input_
)
# if it's an official model
if ":" not in self.model:
return replicate_python.models.predictions.create(self.model, input=input_)
else:
return replicate_python.predictions.create(
version=self.version_obj, input=input_
)