Compare commits

...

14 Commits

Author SHA1 Message Date
Bagatur
ed8753e7ce wip 2023-08-27 15:16:52 -07:00
Bagatur
0d01cede03 bump 274 (#9805) 2023-08-26 12:16:26 -07:00
Vikas Sheoran
63921e327d docs: Fix a spelling mistake in adding_memory.ipynb (#9794)
# Description 
This pull request fixes a small spelling mistake found while reading
docs.
2023-08-26 12:04:43 -07:00
Rosário P. Fernandes
aab01b55db typo: funtions --> functions (#9784)
Minor typo in the extractions use-case
2023-08-26 11:47:47 -07:00
Nikhil Suresh
0da5803f5a fixed regex to match sources for all cases, also includes source (#9775)
- Description: Updated the regex to handle all the different cases for
string matching (SOURCES, Sources, sources); a sketch of this kind of
matching is shown below.
  - Issue: https://github.com/langchain-ai/langchain/issues/9774
  - Dependencies: N/A
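
A minimal sketch of this kind of case-insensitive matching (the pattern and helper below are illustrative, not the exact code changed in the PR):

```python
import re

# Illustrative pattern: match a "SOURCES:"/"Sources:"/"sources:" marker
# regardless of casing when splitting an LLM answer from its cited sources.
_SOURCES_RE = re.compile(r"sources:", re.IGNORECASE)


def split_answer_and_sources(text: str):
    """Split '<answer> SOURCES: <sources>' into (answer, sources)."""
    parts = _SOURCES_RE.split(text, 1)  # split only on the first marker
    answer = parts[0].strip()
    sources = parts[1].strip() if len(parts) > 1 else ""
    return answer, sources
```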
2023-08-25 18:10:33 -07:00
Sam Partee
a28eea5767 Redis metadata filtering and specification, index customization (#8612)
### Description

The previous Redis implementation did not allow the user to specify
the index configuration (e.g. changing the underlying algorithm) or to add
additional metadata to use for querying (i.e. hybrid or "filtered"
search).

This PR introduces the ability to specify custom index attributes and
metadata attributes as well as use that metadata in filtered queries.
Overall, more structure was introduced to the Redis implementation, which
should make it easier to maintain going forward.

# New Features

The following features are now available with the Redis integration in
LangChain.

## Index schema generation

The schema for the index will now be automatically generated if not
specified by the user. For example, the data above has multiple
metadata categories. Consider the following example:

```python

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.redis import Redis

embeddings = OpenAIEmbeddings()


rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users"
)
```

Loading the data in through this and the other ``from_documents`` and
``from_texts`` methods will now generate an index schema in Redis like the
following.

You can view the index schema with the ``redisvl`` tool. [link](redisvl.com)

```bash
$ rvl index info -i users
```


Index Information:

| Index Name | Storage Type | Prefixes      | Index Options | Indexing |
|------------|--------------|---------------|---------------|----------|
| users      | HASH         | ['doc:users'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |


### Custom Metadata specification

The metadata schema generation has the following rules:
1. All text fields are indexed as text fields.
2. All numeric fields are indexed as numeric fields.

If you would like a text field to be indexed as a tag field instead, you can
specify overrides like the following for the example data:

```python

# this can also be a path to a yaml file
index_schema = {
    "text": [{"name": "user"}, {"name": "job"}],
    "tag": [{"name": "credit_score"}],
    "numeric": [{"name": "age"}],
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users",
    index_schema=index_schema,  # pass the overrides defined above
)
```
This will change the index specification to:

Index Information:

| Index Name | Storage Type | Prefixes       | Index Options | Indexing |
|------------|--------------|----------------|---------------|----------|
| users2     | HASH         | ['doc:users2'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TAG     | SEPARATOR    | ,            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |


A warning (log output) will also be raised telling the user that the generated
schema does not match the specified schema:

```text
index_schema does not match generated schema from metadata.
index_schema: {'text': [{'name': 'user'}, {'name': 'job'}], 'tag': [{'name': 'credit_score'}], 'numeric': [{'name': 'age'}]}
generated_schema: {'text': [{'name': 'user'}, {'name': 'job'}, {'name': 'credit_score'}], 'numeric': [{'name': 'age'}]}
```

As long as this is intentional, this is fine.

The schema can be defined as a yaml file as well as a dictionary:

```yaml

text:
  - name: user
  - name: job
tag:
  - name: credit_score
numeric:
  - name: age

```

and you pass in the path like:

```python
rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    index_schema=Path("sample1.yml").resolve()
)
```

This will create the same schema as the one defined in the dictionary example:


Index Information:

| Index Name | Storage Type | Prefixes       | Index Options | Indexing |
|------------|--------------|----------------|---------------|----------|
| users3     | HASH         | ['doc:users3'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TAG     | SEPARATOR    | ,            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |



### Custom Vector Indexing Schema

Users with large use cases may want to change how the vector index created by
LangChain is formulated.

To utilize all the features of Redis for vector database use cases like this,
you can now pass in index attribute modifiers, such as changing the indexing
algorithm to HNSW:

```python
vector_schema = {
    "algorithm": "HNSW"
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    vector_schema=vector_schema
)

```

A more complex example may look like

```python
vector_schema = {
    "algorithm": "HNSW",
    "ef_construction": 200,
    "ef_runtime": 20
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    vector_schema=vector_schema
)
```

All names correspond to the arguments you would set if using Redis-py or
RedisVL. (put in doc link later)
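
For reference, here is a minimal sketch of how the schema above maps onto redis-py field definitions (the embedding dimension and distance metric are illustrative placeholders, not values taken from this PR):

```python
from redis.commands.search.field import (
    NumericField,
    TagField,
    TextField,
    VectorField,
)

# Same fields as the generated "users" schema above, expressed directly in redis-py.
fields = [
    TextField("user", weight=1),
    TextField("job", weight=1),
    TagField("credit_score", separator=","),
    NumericField("age"),
    VectorField(
        "content_vector",
        "HNSW",  # algorithm, as in vector_schema above
        {
            "TYPE": "FLOAT32",
            "DIM": 1536,                  # embedding dimension (placeholder)
            "DISTANCE_METRIC": "COSINE",  # placeholder metric
            "EF_CONSTRUCTION": 200,       # ef_construction
            "EF_RUNTIME": 20,             # ef_runtime
        },
    ),
]
```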


### Better Querying

Both vector queries and range (limit) queries are now available, and metadata
is returned by default. The outputs are shown below.

```python
>>> query = "foo"
>>> results = rds.similarity_search(query, k=1)
>>> print(results)
[Document(page_content='foo', metadata={'user': 'derrick', 'job': 'doctor', 'credit_score': 'low', 'age': '14', 'id': 'doc:users:657a47d7db8b447e88598b83da879b9d', 'score': '7.15255737305e-07'})]

>>> results = rds.similarity_search_with_score(query, k=1, return_metadata=False)
>>> print(results) # no metadata, but with scores
[(Document(page_content='foo', metadata={}), 7.15255737305e-07)]

>>> results = rds.similarity_search_limit_score(query, k=6, score_threshold=0.0001)
>>> print(len(results)) # range query (only above threshold even if k is higher)
4
```

### Custom metadata filtering

A big advantage of Redis in this space is the ability to filter on data
stored alongside the vector itself. With the example above, the following is
now possible in LangChain. The comparison operators are overridden to form a
new expression language that mimics that of [redisvl](redisvl.com). This
allows for arbitrarily long sequences of filters, resembling SQL conditions,
that can be used directly with vector queries and range queries.

There are two interfaces by which to do so and both are shown. 

```python

>>> from langchain.vectorstores.redis import RedisFilter, RedisNum, RedisText

>>> age_filter = RedisFilter.num("age") > 18
>>> age_filter = RedisNum("age") > 18 # equivalent
>>> results = rds.similarity_search(query, filter=age_filter)
>>> print(len(results))
3

>>> job_filter = RedisFilter.text("job") == "engineer" 
>>> job_filter = RedisText("job") == "engineer" # equivalent
>>> results = rds.similarity_search(query, filter=job_filter)
>>> print(len(results))
2

# fuzzy match text search
>>> job_filter = RedisFilter.text("job") % "eng*"
>>> results = rds.similarity_search(query, filter=job_filter)
>>> print(len(results))
2


# combined filters (AND)
>>> combined = age_filter & job_filter
>>> results = rds.similarity_search(query, filter=combined)
>>> print(len(results))
1

# combined filters (OR)
>>> combined = age_filter | job_filter
>>> results = rds.similarity_search(query, filter=combined)
>>> print(len(results))
4
```

All the above filter results can be checked against the data above.


### Other

  - Issue: #3967 
  - Dependencies: No added dependencies
  - Tag maintainer: @hwchase17 @baskaryan @rlancemartin 
  - Twitter handle: @sampartee

---------

Co-authored-by: Naresh Rangan <naresh.rangan0@walmart.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-08-25 17:22:50 -07:00
Anish Shah
fa0b8f3368 fix broken wandb link in debugging page (#9771)
- Description: Fix broken hyperlink in debugging page
2023-08-25 15:34:08 -07:00
Monami Sharma
12a373810c Fixing broken links to Moderation and Constitutional chain (#9768)
- Description: Fixing broken links for Moderation and Constitutional
chain
  - Issue: N/A
  - Twitter handle: MonamiSharma
2023-08-25 15:19:32 -07:00
nikhilkjha
d57d08fd01 Initial commit for comprehend moderator (#9665)
This PR implements a custom chain that wraps Amazon Comprehend API
calls. The custom chain is meant to be used with LLM chains to provide a
moderation capability that lets you detect and redact PII, toxic, and
intent content in the LLM prompt or the LLM response. The
implementation accepts a configuration object to control which checks
are performed on an LLM prompt, and it can be used in a variety of setups
with the LangChain Expression Language, not only to moderate the
configured content in chains but also in other constructs such as a
retriever.
The included sample notebook goes over the different configuration
options and how to use the chain with other chains.

###  Usage sample
```python
from langchain_experimental.comprehend_moderation import BaseModerationActions, BaseModerationFilters

# Supporting imports and client setup for this sample (import paths are assumed
# for the langchain version at the time of this PR; the AWS region is an example):
import boto3
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain_experimental.comprehend_moderation import AmazonComprehendModerationChain

comprehend_client = boto3.client("comprehend", region_name="us-east-1")

moderation_config = { 
        "filters":[ 
                BaseModerationFilters.PII, 
                BaseModerationFilters.TOXICITY,
                BaseModerationFilters.INTENT
        ],
        "pii":{ 
                "action": BaseModerationActions.ALLOW, 
                "threshold":0.5, 
                "labels":["SSN"],
                "mask_character": "X"
        },
        "toxicity":{ 
                "action": BaseModerationActions.STOP, 
                "threshold":0.5
        },
        "intent":{ 
                "action": BaseModerationActions.STOP, 
                "threshold":0.5
        }
}

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config, #specify the configuration
    client=comprehend_client,            #optionally pass the Boto3 Client
    verbose=True
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.", 
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here."
]
llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = ( 
    prompt 
    | comp_moderation_with_config 
    | {llm_chain.input_keys[0]: lambda x: x['output'] }  
    | llm_chain 
    | { "input": lambda x: x['text'] } 
    | comp_moderation_with_config 
)

response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})

print(response['output'])


```
### Output
```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii validation...
Found PII content..stopping..
The prompt contains PII entities and cannot be processed
```

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-08-25 15:11:27 -07:00
Lance Martin
4339d21cf1 Code LLaMA in code understanding use case (#9779)
Update Code Understanding use case doc w/ Code-llama.
2023-08-25 14:24:38 -07:00
William FH
1960ac8d25 token chunks (#9739)
Co-authored-by: Andrew <abatutin@gmail.com>
2023-08-25 12:52:07 -07:00
Lance Martin
2ab04a4e32 Update agent docs, move to use-case sub-directory (#9344)
Re-structure and add new agent page
2023-08-25 11:28:55 -07:00
Lance Martin
985873c497 Update RAG use case (move to ntbk) (#9340) 2023-08-25 11:27:27 -07:00
Harrison Chase
709a67d9bf multivector notebook (#9740) 2023-08-25 07:07:27 -07:00
82 changed files with 9115 additions and 2766 deletions

File diff suppressed because it is too large.

View File

@@ -2,5 +2,5 @@
One of the key concerns with using LLMs is that they may generate harmful or unethical text. This is an area of active research in the field. Here we present some built-in chains inspired by this research, which are intended to make the outputs of LLMs safer.
- [Moderation chain](/docs/use_cases/safety/moderation): Explicitly check if any output text is harmful and flag it.
- [Constitutional chain](/docs/use_cases/safety/constitutional_chain): Prompt the model with a set of principles which should guide its behavior.
- [Moderation chain](/docs/guides/safety/moderation): Explicitly check if any output text is harmful and flag it.
- [Constitutional chain](/docs/guides/safety/constitutional_chain): Prompt the model with a set of principles which should guide its behavior.

Binary files not shown (6 images added: 42 KiB, 236 KiB, 74 KiB, 166 KiB, 42 KiB, 177 KiB).

View File

@@ -8,7 +8,7 @@ Here's a few different tools and functionalities to aid in debugging.
## Tracing
Platforms with tracing capabilities like [LangSmith](/docs/guides/langsmith/) and [WandB](/docs/ecosystem/integrations/agent_with_wandb_tracing) are the most comprehensive solutions for debugging. These platforms make it easy to not only log and visualize LLM apps, but also to actively debug, test and refine them.
Platforms with tracing capabilities like [LangSmith](/docs/guides/langsmith/) and [WandB](/docs/integrations/providers/wandb_tracing) are the most comprehensive solutions for debugging. These platforms make it easy to not only log and visualize LLM apps, but also to actively debug, test and refine them.
For anyone building production-grade LLM applications, we highly recommend using a platform like this.

File diff suppressed because it is too large.

View File

@@ -13,7 +13,10 @@
"\n",
"- smaller chunks: split a document into smaller chunks, and embed those (this is ParentDocumentRetriever)\n",
"- summary: create a summary for each document, embed that along with (or instead of) the document\n",
"- hypothetical questions: create hypothetical questions that each document would be appropriate to answer, embed those along with (or instead of) the document"
"- hypothetical questions: create hypothetical questions that each document would be appropriate to answer, embed those along with (or instead of) the document\n",
"\n",
"\n",
"Note that this also enables another method of adding embeddings - manually. This is great because you can explicitly add questions or queries that should lead to a document being recovered, giving you more control"
]
},
{
@@ -106,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "5d23247d",
"metadata": {},
"outputs": [],
@@ -122,7 +125,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "92ed5861",
"metadata": {},
"outputs": [],
@@ -133,17 +136,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "8afed60c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.', metadata={'doc_id': 'b4ca7817-e3fe-4103-ac81-574fb41439ef', 'source': '../../state_of_the_union.txt'})"
"Document(page_content='Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.', metadata={'doc_id': '10e9cbc0-4ba5-4d79-a09b-c033d1ba7b01', 'source': '../../state_of_the_union.txt'})"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -155,7 +158,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "3c9017f1",
"metadata": {},
"outputs": [
@@ -165,7 +168,7 @@
"9874"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -187,7 +190,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 10,
"id": "1433dff4",
"metadata": {},
"outputs": [],
@@ -201,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"id": "35b30390",
"metadata": {},
"outputs": [],
@@ -216,17 +219,17 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"id": "41a2a738",
"metadata": {},
"outputs": [],
"source": [
"summaries = [chain.invoke(d) for d in docs]"
"summaries = chain.batch(docs, {\"max_concurrency\": 5})"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"id": "7ac5e4b1",
"metadata": {},
"outputs": [],
@@ -250,7 +253,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 15,
"id": "0d93309f",
"metadata": {},
"outputs": [],
@@ -260,7 +263,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 16,
"id": "6d5edf0d",
"metadata": {},
"outputs": [],
@@ -271,7 +274,20 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 17,
"id": "862ae920",
"metadata": {},
"outputs": [],
"source": [
"# # We can also add the original chunks to the vectorstore if we so want\n",
"# for i, doc in enumerate(docs):\n",
"# doc.metadata[id_key] = doc_ids[i]\n",
"# retriever.vectorstore.add_documents(docs)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "299232d6",
"metadata": {},
"outputs": [],
@@ -281,17 +297,17 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 19,
"id": "10e404c0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='The document discusses various topics and proposals put forth by the President in a State of the Union address. These include the nomination of a judge for the Supreme Court, securing the border and fixing the immigration system, advancing liberty and justice for women and LGBTQ+ individuals, passing bipartisan legislation, addressing the opioid epidemic and mental health issues, supporting veterans, and ending cancer. The President expresses optimism about the future of the country and emphasizes the strength of the American people.', metadata={'doc_id': '8c7a707d-615d-42d5-919d-bc5178dd1ae4'})"
"Document(page_content=\"The document is a transcript of a speech given by the President of the United States. The President discusses several important issues and initiatives, including the nomination of a Supreme Court Justice, border security and immigration reform, protecting women's rights, advancing LGBTQ+ equality, bipartisan legislation, addressing the opioid epidemic and mental health, supporting veterans, investigating the health effects of burn pits on military personnel, ending cancer, and the strength and resilience of the American people.\", metadata={'doc_id': '79fa2e9f-28d9-4372-8af3-2caf4f1de312'})"
]
},
"execution_count": 20,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -302,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 20,
"id": "e4cce5c2",
"metadata": {},
"outputs": [],
@@ -312,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 21,
"id": "c8570dbb",
"metadata": {},
"outputs": [
@@ -322,7 +338,7 @@
"9194"
]
},
"execution_count": 24,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -340,6 +356,203 @@
"\n",
"An LLM can also be used to generate a list of hypothetical questions that could be asked of a particular document. These questions can then be embedded"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "5219b085",
"metadata": {},
"outputs": [],
"source": [
"functions = [\n",
" {\n",
" \"name\": \"hypothetical_questions\",\n",
" \"description\": \"Generate hypothetical questions\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"questions\": {\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"string\"\n",
" },\n",
" },\n",
" },\n",
" \"required\": [\"questions\"]\n",
" }\n",
" }\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "523deb92",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser\n",
"chain = (\n",
" {\"doc\": lambda x: x.page_content}\n",
" # Only asking for 3 hypothetical questions, but this could be adjusted\n",
" | ChatPromptTemplate.from_template(\"Generate a list of 3 hypothetical questions that the below document could be used to answer:\\n\\n{doc}\")\n",
" | ChatOpenAI(max_retries=0, model=\"gpt-4\").bind(functions=functions, function_call={\"name\": \"hypothetical_questions\"})\n",
" | JsonKeyOutputFunctionsParser(key_name=\"questions\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "11d30554",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[\"What was the author's initial impression of philosophy as a field of study, and how did it change when they got to college?\",\n",
" 'Why did the author decide to switch their focus to Artificial Intelligence (AI)?',\n",
" \"What led to the author's disillusionment with the field of AI as it was practiced at the time?\"]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke(docs[0])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "3eb2e48c",
"metadata": {},
"outputs": [],
"source": [
"hypothetical_questions = chain.batch(docs, {\"max_concurrency\": 5})"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "b2cd6e75",
"metadata": {},
"outputs": [],
"source": [
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(\n",
" collection_name=\"hypo-questions\",\n",
" embedding_function=OpenAIEmbeddings()\n",
")\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
"id_key = \"doc_id\"\n",
"# The retriever (empty to start)\n",
"retriever = MultiVectorRetriever(\n",
" vectorstore=vectorstore, \n",
" docstore=store, \n",
" id_key=id_key,\n",
")\n",
"doc_ids = [str(uuid.uuid4()) for _ in docs]"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "18831b3b",
"metadata": {},
"outputs": [],
"source": [
"question_docs = []\n",
"for i, question_list in enumerate(hypothetical_questions):\n",
" question_docs.extend([Document(page_content=s,metadata={id_key: doc_ids[i]}) for s in question_list])"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "224b24c5",
"metadata": {},
"outputs": [],
"source": [
"retriever.vectorstore.add_documents(question_docs)\n",
"retriever.docstore.mset(list(zip(doc_ids, docs)))"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "7b442b90",
"metadata": {},
"outputs": [],
"source": [
"sub_docs = vectorstore.similarity_search(\"justice breyer\")"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "089b5ad0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content=\"What is the President's stance on immigration reform?\", metadata={'doc_id': '505d73e3-8350-46ec-a58e-3af032f04ab3'}),\n",
" Document(page_content=\"What is the President's stance on immigration reform?\", metadata={'doc_id': '1c9618f0-7660-4b4f-a37c-509cbbbf6dba'}),\n",
" Document(page_content=\"What is the President's stance on immigration reform?\", metadata={'doc_id': '82c08209-b904-46a8-9532-edd2380950b7'}),\n",
" Document(page_content='What measures is the President proposing to protect the rights of LGBTQ+ Americans?', metadata={'doc_id': '82c08209-b904-46a8-9532-edd2380950b7'})]"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sub_docs"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "7594b24e",
"metadata": {},
"outputs": [],
"source": [
"retrieved_docs = retriever.get_relevant_documents(\"justice breyer\")"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "4c120c65",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9194"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(retrieved_docs[0].page_content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "616cfeeb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -203,7 +203,7 @@
"prompt = ChatPromptTemplate.from_messages([\n",
" SystemMessage(content=\"You are a chatbot having a conversation with a human.\"), # The persistent system prompt\n",
" MessagesPlaceholder(variable_name=\"chat_history\"), # Where the memory will be stored.\n",
" HumanMessagePromptTemplate.from_template(\"{human_input}\"), # Where the human input will injectd\n",
" HumanMessagePromptTemplate.from_template(\"{human_input}\"), # Where the human input will injected\n",
"])\n",
" \n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)"

View File

@@ -1,565 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "517a9fd4",
"metadata": {},
"source": [
"# BabyAGI User Guide\n",
"\n",
"This notebook demonstrates how to implement [BabyAGI](https://github.com/yoheinakajima/babyagi/tree/main) by [Yohei Nakajima](https://twitter.com/yoheinakajima). BabyAGI is an AI agent that can generate and pretend to execute tasks based on a given objective.\n",
"\n",
"This guide will help you understand the components to create your own recursive agents.\n",
"\n",
"Although BabyAGI uses specific vectorstores/model providers (Pinecone, OpenAI), one of the benefits of implementing it with LangChain is that you can easily swap those out for different options. In this implementation we use a FAISS vectorstore (because it runs locally and is free)."
]
},
{
"cell_type": "markdown",
"id": "556af556",
"metadata": {},
"source": [
"## Install and Import Required Modules"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "c8a354b6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from collections import deque\n",
"from typing import Dict, List, Optional, Any\n",
"\n",
"from langchain import LLMChain, OpenAI, PromptTemplate\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.llms import BaseLLM\n",
"from langchain.vectorstores.base import VectorStore\n",
"from pydantic import BaseModel, Field\n",
"from langchain.chains.base import Chain"
]
},
{
"cell_type": "markdown",
"id": "09f70772",
"metadata": {},
"source": [
"## Connect to the Vector Store\n",
"\n",
"Depending on what vectorstore you use, this step may look different."
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "794045d4",
"metadata": {},
"outputs": [],
"source": [
"from langchain.vectorstores import FAISS\n",
"from langchain.docstore import InMemoryDocstore"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "6e0305eb",
"metadata": {},
"outputs": [],
"source": [
"# Define your embedding model\n",
"embeddings_model = OpenAIEmbeddings()\n",
"# Initialize the vectorstore as empty\n",
"import faiss\n",
"\n",
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})"
]
},
{
"cell_type": "markdown",
"id": "0f3b72bf",
"metadata": {},
"source": [
"## Define the Chains\n",
"\n",
"BabyAGI relies on three LLM chains:\n",
"- Task creation chain to select new tasks to add to the list\n",
"- Task prioritization chain to re-prioritize tasks\n",
"- Execution Chain to execute the tasks"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "bf4bd5cd",
"metadata": {},
"outputs": [],
"source": [
"class TaskCreationChain(LLMChain):\n",
" \"\"\"Chain to generates tasks.\"\"\"\n",
"\n",
" @classmethod\n",
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
" \"\"\"Get the response parser.\"\"\"\n",
" task_creation_template = (\n",
" \"You are a task creation AI that uses the result of an execution agent\"\n",
" \" to create new tasks with the following objective: {objective},\"\n",
" \" The last completed task has the result: {result}.\"\n",
" \" This result was based on this task description: {task_description}.\"\n",
" \" These are incomplete tasks: {incomplete_tasks}.\"\n",
" \" Based on the result, create new tasks to be completed\"\n",
" \" by the AI system that do not overlap with incomplete tasks.\"\n",
" \" Return the tasks as an array.\"\n",
" )\n",
" prompt = PromptTemplate(\n",
" template=task_creation_template,\n",
" input_variables=[\n",
" \"result\",\n",
" \"task_description\",\n",
" \"incomplete_tasks\",\n",
" \"objective\",\n",
" ],\n",
" )\n",
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "b6488ffe",
"metadata": {},
"outputs": [],
"source": [
"class TaskPrioritizationChain(LLMChain):\n",
" \"\"\"Chain to prioritize tasks.\"\"\"\n",
"\n",
" @classmethod\n",
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
" \"\"\"Get the response parser.\"\"\"\n",
" task_prioritization_template = (\n",
" \"You are a task prioritization AI tasked with cleaning the formatting of and reprioritizing\"\n",
" \" the following tasks: {task_names}.\"\n",
" \" Consider the ultimate objective of your team: {objective}.\"\n",
" \" Do not remove any tasks. Return the result as a numbered list, like:\"\n",
" \" #. First task\"\n",
" \" #. Second task\"\n",
" \" Start the task list with number {next_task_id}.\"\n",
" )\n",
" prompt = PromptTemplate(\n",
" template=task_prioritization_template,\n",
" input_variables=[\"task_names\", \"next_task_id\", \"objective\"],\n",
" )\n",
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "b43cd580",
"metadata": {},
"outputs": [],
"source": [
"class ExecutionChain(LLMChain):\n",
" \"\"\"Chain to execute tasks.\"\"\"\n",
"\n",
" @classmethod\n",
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
" \"\"\"Get the response parser.\"\"\"\n",
" execution_template = (\n",
" \"You are an AI who performs one task based on the following objective: {objective}.\"\n",
" \" Take into account these previously completed tasks: {context}.\"\n",
" \" Your task: {task}.\"\n",
" \" Response:\"\n",
" )\n",
" prompt = PromptTemplate(\n",
" template=execution_template,\n",
" input_variables=[\"objective\", \"context\", \"task\"],\n",
" )\n",
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
]
},
{
"cell_type": "markdown",
"id": "3ad996c5",
"metadata": {},
"source": [
"### Define the BabyAGI Controller\n",
"\n",
"BabyAGI composes the chains defined above in a (potentially-)infinite loop."
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "0ada0636",
"metadata": {},
"outputs": [],
"source": [
"def get_next_task(\n",
" task_creation_chain: LLMChain,\n",
" result: Dict,\n",
" task_description: str,\n",
" task_list: List[str],\n",
" objective: str,\n",
") -> List[Dict]:\n",
" \"\"\"Get the next task.\"\"\"\n",
" incomplete_tasks = \", \".join(task_list)\n",
" response = task_creation_chain.run(\n",
" result=result,\n",
" task_description=task_description,\n",
" incomplete_tasks=incomplete_tasks,\n",
" objective=objective,\n",
" )\n",
" new_tasks = response.split(\"\\n\")\n",
" return [{\"task_name\": task_name} for task_name in new_tasks if task_name.strip()]"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "d35250ad",
"metadata": {},
"outputs": [],
"source": [
"def prioritize_tasks(\n",
" task_prioritization_chain: LLMChain,\n",
" this_task_id: int,\n",
" task_list: List[Dict],\n",
" objective: str,\n",
") -> List[Dict]:\n",
" \"\"\"Prioritize tasks.\"\"\"\n",
" task_names = [t[\"task_name\"] for t in task_list]\n",
" next_task_id = int(this_task_id) + 1\n",
" response = task_prioritization_chain.run(\n",
" task_names=task_names, next_task_id=next_task_id, objective=objective\n",
" )\n",
" new_tasks = response.split(\"\\n\")\n",
" prioritized_task_list = []\n",
" for task_string in new_tasks:\n",
" if not task_string.strip():\n",
" continue\n",
" task_parts = task_string.strip().split(\".\", 1)\n",
" if len(task_parts) == 2:\n",
" task_id = task_parts[0].strip()\n",
" task_name = task_parts[1].strip()\n",
" prioritized_task_list.append({\"task_id\": task_id, \"task_name\": task_name})\n",
" return prioritized_task_list"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "e3f1840c",
"metadata": {},
"outputs": [],
"source": [
"def _get_top_tasks(vectorstore, query: str, k: int) -> List[str]:\n",
" \"\"\"Get the top k tasks based on the query.\"\"\"\n",
" results = vectorstore.similarity_search_with_score(query, k=k)\n",
" if not results:\n",
" return []\n",
" sorted_results, _ = zip(*sorted(results, key=lambda x: x[1], reverse=True))\n",
" return [str(item.metadata[\"task\"]) for item in sorted_results]\n",
"\n",
"\n",
"def execute_task(\n",
" vectorstore, execution_chain: LLMChain, objective: str, task: str, k: int = 5\n",
") -> str:\n",
" \"\"\"Execute a task.\"\"\"\n",
" context = _get_top_tasks(vectorstore, query=objective, k=k)\n",
" return execution_chain.run(objective=objective, context=context, task=task)"
]
},
{
"cell_type": "code",
"execution_count": 137,
"id": "1e978938",
"metadata": {},
"outputs": [],
"source": [
"class BabyAGI(Chain, BaseModel):\n",
" \"\"\"Controller model for the BabyAGI agent.\"\"\"\n",
"\n",
" task_list: deque = Field(default_factory=deque)\n",
" task_creation_chain: TaskCreationChain = Field(...)\n",
" task_prioritization_chain: TaskPrioritizationChain = Field(...)\n",
" execution_chain: ExecutionChain = Field(...)\n",
" task_id_counter: int = Field(1)\n",
" vectorstore: VectorStore = Field(init=False)\n",
" max_iterations: Optional[int] = None\n",
"\n",
" class Config:\n",
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
"\n",
" arbitrary_types_allowed = True\n",
"\n",
" def add_task(self, task: Dict):\n",
" self.task_list.append(task)\n",
"\n",
" def print_task_list(self):\n",
" print(\"\\033[95m\\033[1m\" + \"\\n*****TASK LIST*****\\n\" + \"\\033[0m\\033[0m\")\n",
" for t in self.task_list:\n",
" print(str(t[\"task_id\"]) + \": \" + t[\"task_name\"])\n",
"\n",
" def print_next_task(self, task: Dict):\n",
" print(\"\\033[92m\\033[1m\" + \"\\n*****NEXT TASK*****\\n\" + \"\\033[0m\\033[0m\")\n",
" print(str(task[\"task_id\"]) + \": \" + task[\"task_name\"])\n",
"\n",
" def print_task_result(self, result: str):\n",
" print(\"\\033[93m\\033[1m\" + \"\\n*****TASK RESULT*****\\n\" + \"\\033[0m\\033[0m\")\n",
" print(result)\n",
"\n",
" @property\n",
" def input_keys(self) -> List[str]:\n",
" return [\"objective\"]\n",
"\n",
" @property\n",
" def output_keys(self) -> List[str]:\n",
" return []\n",
"\n",
" def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
" \"\"\"Run the agent.\"\"\"\n",
" objective = inputs[\"objective\"]\n",
" first_task = inputs.get(\"first_task\", \"Make a todo list\")\n",
" self.add_task({\"task_id\": 1, \"task_name\": first_task})\n",
" num_iters = 0\n",
" while True:\n",
" if self.task_list:\n",
" self.print_task_list()\n",
"\n",
" # Step 1: Pull the first task\n",
" task = self.task_list.popleft()\n",
" self.print_next_task(task)\n",
"\n",
" # Step 2: Execute the task\n",
" result = execute_task(\n",
" self.vectorstore, self.execution_chain, objective, task[\"task_name\"]\n",
" )\n",
" this_task_id = int(task[\"task_id\"])\n",
" self.print_task_result(result)\n",
"\n",
" # Step 3: Store the result in Pinecone\n",
" result_id = f\"result_{task['task_id']}_{num_iters}\"\n",
" self.vectorstore.add_texts(\n",
" texts=[result],\n",
" metadatas=[{\"task\": task[\"task_name\"]}],\n",
" ids=[result_id],\n",
" )\n",
"\n",
" # Step 4: Create new tasks and reprioritize task list\n",
" new_tasks = get_next_task(\n",
" self.task_creation_chain,\n",
" result,\n",
" task[\"task_name\"],\n",
" [t[\"task_name\"] for t in self.task_list],\n",
" objective,\n",
" )\n",
" for new_task in new_tasks:\n",
" self.task_id_counter += 1\n",
" new_task.update({\"task_id\": self.task_id_counter})\n",
" self.add_task(new_task)\n",
" self.task_list = deque(\n",
" prioritize_tasks(\n",
" self.task_prioritization_chain,\n",
" this_task_id,\n",
" list(self.task_list),\n",
" objective,\n",
" )\n",
" )\n",
" num_iters += 1\n",
" if self.max_iterations is not None and num_iters == self.max_iterations:\n",
" print(\n",
" \"\\033[91m\\033[1m\" + \"\\n*****TASK ENDING*****\\n\" + \"\\033[0m\\033[0m\"\n",
" )\n",
" break\n",
" return {}\n",
"\n",
" @classmethod\n",
" def from_llm(\n",
" cls, llm: BaseLLM, vectorstore: VectorStore, verbose: bool = False, **kwargs\n",
" ) -> \"BabyAGI\":\n",
" \"\"\"Initialize the BabyAGI Controller.\"\"\"\n",
" task_creation_chain = TaskCreationChain.from_llm(llm, verbose=verbose)\n",
" task_prioritization_chain = TaskPrioritizationChain.from_llm(\n",
" llm, verbose=verbose\n",
" )\n",
" execution_chain = ExecutionChain.from_llm(llm, verbose=verbose)\n",
" return cls(\n",
" task_creation_chain=task_creation_chain,\n",
" task_prioritization_chain=task_prioritization_chain,\n",
" execution_chain=execution_chain,\n",
" vectorstore=vectorstore,\n",
" **kwargs,\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "05ba762e",
"metadata": {},
"source": [
"### Run the BabyAGI\n",
"\n",
"Now it's time to create the BabyAGI controller and watch it try to accomplish your objective."
]
},
{
"cell_type": "code",
"execution_count": 138,
"id": "3d220b69",
"metadata": {},
"outputs": [],
"source": [
"OBJECTIVE = \"Write a weather report for SF today\""
]
},
{
"cell_type": "code",
"execution_count": 139,
"id": "8a8e5543",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 140,
"id": "3d69899b",
"metadata": {},
"outputs": [],
"source": [
"# Logging of LLMChains\n",
"verbose = False\n",
"# If None, will keep on going forever\n",
"max_iterations: Optional[int] = 3\n",
"baby_agi = BabyAGI.from_llm(\n",
" llm=llm, vectorstore=vectorstore, verbose=verbose, max_iterations=max_iterations\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 141,
"id": "f7957b51",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[95m\u001b[1m\n",
"*****TASK LIST*****\n",
"\u001b[0m\u001b[0m\n",
"1: Make a todo list\n",
"\u001b[92m\u001b[1m\n",
"*****NEXT TASK*****\n",
"\u001b[0m\u001b[0m\n",
"1: Make a todo list\n",
"\u001b[93m\u001b[1m\n",
"*****TASK RESULT*****\n",
"\u001b[0m\u001b[0m\n",
"\n",
"\n",
"1. Check the temperature range for the day.\n",
"2. Gather temperature data for SF today.\n",
"3. Analyze the temperature data and create a weather report.\n",
"4. Publish the weather report.\n",
"\u001b[95m\u001b[1m\n",
"*****TASK LIST*****\n",
"\u001b[0m\u001b[0m\n",
"2: Gather data on the expected temperature range for the day.\n",
"3: Collect data on the expected precipitation for the day.\n",
"4: Analyze the data and create a weather report.\n",
"5: Check the current weather conditions in SF.\n",
"6: Publish the weather report.\n",
"\u001b[92m\u001b[1m\n",
"*****NEXT TASK*****\n",
"\u001b[0m\u001b[0m\n",
"2: Gather data on the expected temperature range for the day.\n",
"\u001b[93m\u001b[1m\n",
"*****TASK RESULT*****\n",
"\u001b[0m\u001b[0m\n",
"\n",
"\n",
"I have gathered data on the expected temperature range for the day in San Francisco. The forecast is for temperatures to range from a low of 55 degrees Fahrenheit to a high of 68 degrees Fahrenheit.\n",
"\u001b[95m\u001b[1m\n",
"*****TASK LIST*****\n",
"\u001b[0m\u001b[0m\n",
"3: Check the current weather conditions in SF.\n",
"4: Calculate the average temperature for the day in San Francisco.\n",
"5: Determine the probability of precipitation for the day in San Francisco.\n",
"6: Identify any potential weather warnings or advisories for the day in San Francisco.\n",
"7: Research any historical weather patterns for the day in San Francisco.\n",
"8: Compare the expected temperature range to the historical average for the day in San Francisco.\n",
"9: Collect data on the expected precipitation for the day.\n",
"10: Analyze the data and create a weather report.\n",
"11: Publish the weather report.\n",
"\u001b[92m\u001b[1m\n",
"*****NEXT TASK*****\n",
"\u001b[0m\u001b[0m\n",
"3: Check the current weather conditions in SF.\n",
"\u001b[93m\u001b[1m\n",
"*****TASK RESULT*****\n",
"\u001b[0m\u001b[0m\n",
"\n",
"\n",
"I am checking the current weather conditions in SF. According to the data I have gathered, the temperature in SF today is currently around 65 degrees Fahrenheit with clear skies. The temperature range for the day is expected to be between 60 and 70 degrees Fahrenheit.\n",
"\u001b[91m\u001b[1m\n",
"*****TASK ENDING*****\n",
"\u001b[0m\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'objective': 'Write a weather report for SF today'}"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"baby_agi({\"objective\": OBJECTIVE})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "898a210b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,647 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "517a9fd4",
"metadata": {},
"source": [
"# BabyAGI with Tools\n",
"\n",
"This notebook builds on top of [baby agi](baby_agi.html), but shows how you can swap out the execution chain. The previous execution chain was just an LLM which made stuff up. By swapping it out with an agent that has access to tools, we can hopefully get real reliable information"
]
},
{
"cell_type": "markdown",
"id": "556af556",
"metadata": {},
"source": [
"## Install and Import Required Modules"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c8a354b6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from collections import deque\n",
"from typing import Dict, List, Optional, Any\n",
"\n",
"from langchain import LLMChain, OpenAI, PromptTemplate\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.llms import BaseLLM\n",
"from langchain.vectorstores.base import VectorStore\n",
"from pydantic import BaseModel, Field\n",
"from langchain.chains.base import Chain"
]
},
{
"cell_type": "markdown",
"id": "09f70772",
"metadata": {},
"source": [
"## Connect to the Vector Store\n",
"\n",
"Depending on what vectorstore you use, this step may look different."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "794045d4",
"metadata": {},
"outputs": [],
"source": [
"%pip install faiss-cpu > /dev/null\n",
"%pip install google-search-results > /dev/null\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.docstore import InMemoryDocstore"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6e0305eb",
"metadata": {},
"outputs": [],
"source": [
"# Define your embedding model\n",
"embeddings_model = OpenAIEmbeddings()\n",
"# Initialize the vectorstore as empty\n",
"import faiss\n",
"\n",
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})"
]
},
{
"cell_type": "markdown",
"id": "0f3b72bf",
"metadata": {},
"source": [
"## Define the Chains\n",
"\n",
"BabyAGI relies on three LLM chains:\n",
"- Task creation chain to select new tasks to add to the list\n",
"- Task prioritization chain to re-prioritize tasks\n",
"- Execution Chain to execute the tasks\n",
"\n",
"\n",
"NOTE: in this notebook, the Execution chain will now be an agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bf4bd5cd",
"metadata": {},
"outputs": [],
"source": [
"class TaskCreationChain(LLMChain):\n",
" \"\"\"Chain to generates tasks.\"\"\"\n",
"\n",
" @classmethod\n",
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
" \"\"\"Get the response parser.\"\"\"\n",
" task_creation_template = (\n",
" \"You are an task creation AI that uses the result of an execution agent\"\n",
" \" to create new tasks with the following objective: {objective},\"\n",
" \" The last completed task has the result: {result}.\"\n",
" \" This result was based on this task description: {task_description}.\"\n",
" \" These are incomplete tasks: {incomplete_tasks}.\"\n",
" \" Based on the result, create new tasks to be completed\"\n",
" \" by the AI system that do not overlap with incomplete tasks.\"\n",
" \" Return the tasks as an array.\"\n",
" )\n",
" prompt = PromptTemplate(\n",
" template=task_creation_template,\n",
" input_variables=[\n",
" \"result\",\n",
" \"task_description\",\n",
" \"incomplete_tasks\",\n",
" \"objective\",\n",
" ],\n",
" )\n",
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b6488ffe",
"metadata": {},
"outputs": [],
"source": [
"class TaskPrioritizationChain(LLMChain):\n",
" \"\"\"Chain to prioritize tasks.\"\"\"\n",
"\n",
" @classmethod\n",
" def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:\n",
" \"\"\"Get the response parser.\"\"\"\n",
" task_prioritization_template = (\n",
" \"You are an task prioritization AI tasked with cleaning the formatting of and reprioritizing\"\n",
" \" the following tasks: {task_names}.\"\n",
" \" Consider the ultimate objective of your team: {objective}.\"\n",
" \" Do not remove any tasks. Return the result as a numbered list, like:\"\n",
" \" #. First task\"\n",
" \" #. Second task\"\n",
" \" Start the task list with number {next_task_id}.\"\n",
" )\n",
" prompt = PromptTemplate(\n",
" template=task_prioritization_template,\n",
" input_variables=[\"task_names\", \"next_task_id\", \"objective\"],\n",
" )\n",
" return cls(prompt=prompt, llm=llm, verbose=verbose)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "b43cd580",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
"from langchain import OpenAI, SerpAPIWrapper, LLMChain\n",
"\n",
"todo_prompt = PromptTemplate.from_template(\n",
" \"You are a planner who is an expert at coming up with a todo list for a given objective. Come up with a todo list for this objective: {objective}\"\n",
")\n",
"todo_chain = LLMChain(llm=OpenAI(temperature=0), prompt=todo_prompt)\n",
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name=\"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\",\n",
" ),\n",
" Tool(\n",
" name=\"TODO\",\n",
" func=todo_chain.run,\n",
" description=\"useful for when you need to come up with todo lists. Input: an objective to create a todo list for. Output: a todo list for that objective. Please be very clear what the objective is!\",\n",
" ),\n",
"]\n",
"\n",
"\n",
"prefix = \"\"\"You are an AI who performs one task based on the following objective: {objective}. Take into account these previously completed tasks: {context}.\"\"\"\n",
"suffix = \"\"\"Question: {task}\n",
"{agent_scratchpad}\"\"\"\n",
"prompt = ZeroShotAgent.create_prompt(\n",
" tools,\n",
" prefix=prefix,\n",
" suffix=suffix,\n",
" input_variables=[\"objective\", \"task\", \"context\", \"agent_scratchpad\"],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3ad996c5",
"metadata": {},
"source": [
"### Define the BabyAGI Controller\n",
"\n",
"BabyAGI composes the chains defined above in a (potentially-)infinite loop."
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "0ada0636",
"metadata": {},
"outputs": [],
"source": [
"def get_next_task(\n",
" task_creation_chain: LLMChain,\n",
" result: Dict,\n",
" task_description: str,\n",
" task_list: List[str],\n",
" objective: str,\n",
") -> List[Dict]:\n",
" \"\"\"Get the next task.\"\"\"\n",
" incomplete_tasks = \", \".join(task_list)\n",
" response = task_creation_chain.run(\n",
" result=result,\n",
" task_description=task_description,\n",
" incomplete_tasks=incomplete_tasks,\n",
" objective=objective,\n",
" )\n",
" new_tasks = response.split(\"\\n\")\n",
" return [{\"task_name\": task_name} for task_name in new_tasks if task_name.strip()]"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "d35250ad",
"metadata": {},
"outputs": [],
"source": [
"def prioritize_tasks(\n",
" task_prioritization_chain: LLMChain,\n",
" this_task_id: int,\n",
" task_list: List[Dict],\n",
" objective: str,\n",
") -> List[Dict]:\n",
" \"\"\"Prioritize tasks.\"\"\"\n",
" task_names = [t[\"task_name\"] for t in task_list]\n",
" next_task_id = int(this_task_id) + 1\n",
" response = task_prioritization_chain.run(\n",
" task_names=task_names, next_task_id=next_task_id, objective=objective\n",
" )\n",
" new_tasks = response.split(\"\\n\")\n",
" prioritized_task_list = []\n",
" for task_string in new_tasks:\n",
" if not task_string.strip():\n",
" continue\n",
" task_parts = task_string.strip().split(\".\", 1)\n",
" if len(task_parts) == 2:\n",
" task_id = task_parts[0].strip()\n",
" task_name = task_parts[1].strip()\n",
" prioritized_task_list.append({\"task_id\": task_id, \"task_name\": task_name})\n",
" return prioritized_task_list"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "e3f1840c",
"metadata": {},
"outputs": [],
"source": [
"def _get_top_tasks(vectorstore, query: str, k: int) -> List[str]:\n",
" \"\"\"Get the top k tasks based on the query.\"\"\"\n",
" results = vectorstore.similarity_search_with_score(query, k=k)\n",
" if not results:\n",
" return []\n",
" sorted_results, _ = zip(*sorted(results, key=lambda x: x[1], reverse=True))\n",
" return [str(item.metadata[\"task\"]) for item in sorted_results]\n",
"\n",
"\n",
"def execute_task(\n",
" vectorstore, execution_chain: LLMChain, objective: str, task: str, k: int = 5\n",
") -> str:\n",
" \"\"\"Execute a task.\"\"\"\n",
" context = _get_top_tasks(vectorstore, query=objective, k=k)\n",
" return execution_chain.run(objective=objective, context=context, task=task)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "1e978938",
"metadata": {},
"outputs": [],
"source": [
"class BabyAGI(Chain, BaseModel):\n",
" \"\"\"Controller model for the BabyAGI agent.\"\"\"\n",
"\n",
" task_list: deque = Field(default_factory=deque)\n",
" task_creation_chain: TaskCreationChain = Field(...)\n",
" task_prioritization_chain: TaskPrioritizationChain = Field(...)\n",
" execution_chain: AgentExecutor = Field(...)\n",
" task_id_counter: int = Field(1)\n",
" vectorstore: VectorStore = Field(init=False)\n",
" max_iterations: Optional[int] = None\n",
"\n",
" class Config:\n",
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
"\n",
" arbitrary_types_allowed = True\n",
"\n",
" def add_task(self, task: Dict):\n",
" self.task_list.append(task)\n",
"\n",
" def print_task_list(self):\n",
" print(\"\\033[95m\\033[1m\" + \"\\n*****TASK LIST*****\\n\" + \"\\033[0m\\033[0m\")\n",
" for t in self.task_list:\n",
" print(str(t[\"task_id\"]) + \": \" + t[\"task_name\"])\n",
"\n",
" def print_next_task(self, task: Dict):\n",
" print(\"\\033[92m\\033[1m\" + \"\\n*****NEXT TASK*****\\n\" + \"\\033[0m\\033[0m\")\n",
" print(str(task[\"task_id\"]) + \": \" + task[\"task_name\"])\n",
"\n",
" def print_task_result(self, result: str):\n",
" print(\"\\033[93m\\033[1m\" + \"\\n*****TASK RESULT*****\\n\" + \"\\033[0m\\033[0m\")\n",
" print(result)\n",
"\n",
" @property\n",
" def input_keys(self) -> List[str]:\n",
" return [\"objective\"]\n",
"\n",
" @property\n",
" def output_keys(self) -> List[str]:\n",
" return []\n",
"\n",
" def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
" \"\"\"Run the agent.\"\"\"\n",
" objective = inputs[\"objective\"]\n",
" first_task = inputs.get(\"first_task\", \"Make a todo list\")\n",
" self.add_task({\"task_id\": 1, \"task_name\": first_task})\n",
" num_iters = 0\n",
" while True:\n",
" if self.task_list:\n",
" self.print_task_list()\n",
"\n",
" # Step 1: Pull the first task\n",
" task = self.task_list.popleft()\n",
" self.print_next_task(task)\n",
"\n",
" # Step 2: Execute the task\n",
" result = execute_task(\n",
" self.vectorstore, self.execution_chain, objective, task[\"task_name\"]\n",
" )\n",
" this_task_id = int(task[\"task_id\"])\n",
" self.print_task_result(result)\n",
"\n",
" # Step 3: Store the result in Pinecone\n",
" result_id = f\"result_{task['task_id']}_{num_iters}\"\n",
" self.vectorstore.add_texts(\n",
" texts=[result],\n",
" metadatas=[{\"task\": task[\"task_name\"]}],\n",
" ids=[result_id],\n",
" )\n",
"\n",
" # Step 4: Create new tasks and reprioritize task list\n",
" new_tasks = get_next_task(\n",
" self.task_creation_chain,\n",
" result,\n",
" task[\"task_name\"],\n",
" [t[\"task_name\"] for t in self.task_list],\n",
" objective,\n",
" )\n",
" for new_task in new_tasks:\n",
" self.task_id_counter += 1\n",
" new_task.update({\"task_id\": self.task_id_counter})\n",
" self.add_task(new_task)\n",
" self.task_list = deque(\n",
" prioritize_tasks(\n",
" self.task_prioritization_chain,\n",
" this_task_id,\n",
" list(self.task_list),\n",
" objective,\n",
" )\n",
" )\n",
" num_iters += 1\n",
" if self.max_iterations is not None and num_iters == self.max_iterations:\n",
" print(\n",
" \"\\033[91m\\033[1m\" + \"\\n*****TASK ENDING*****\\n\" + \"\\033[0m\\033[0m\"\n",
" )\n",
" break\n",
" return {}\n",
"\n",
" @classmethod\n",
" def from_llm(\n",
" cls, llm: BaseLLM, vectorstore: VectorStore, verbose: bool = False, **kwargs\n",
" ) -> \"BabyAGI\":\n",
" \"\"\"Initialize the BabyAGI Controller.\"\"\"\n",
" task_creation_chain = TaskCreationChain.from_llm(llm, verbose=verbose)\n",
" task_prioritization_chain = TaskPrioritizationChain.from_llm(\n",
" llm, verbose=verbose\n",
" )\n",
" llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
" tool_names = [tool.name for tool in tools]\n",
" agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)\n",
" agent_executor = AgentExecutor.from_agent_and_tools(\n",
" agent=agent, tools=tools, verbose=True\n",
" )\n",
" return cls(\n",
" task_creation_chain=task_creation_chain,\n",
" task_prioritization_chain=task_prioritization_chain,\n",
" execution_chain=agent_executor,\n",
" vectorstore=vectorstore,\n",
" **kwargs,\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "05ba762e",
"metadata": {},
"source": [
"### Run the BabyAGI\n",
"\n",
"Now it's time to create the BabyAGI controller and watch it try to accomplish your objective."
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "3d220b69",
"metadata": {},
"outputs": [],
"source": [
"OBJECTIVE = \"Write a weather report for SF today\""
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "8a8e5543",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "3d69899b",
"metadata": {},
"outputs": [],
"source": [
"# Logging of LLMChains\n",
"verbose = False\n",
"# If None, will keep on going forever\n",
"max_iterations: Optional[int] = 3\n",
"baby_agi = BabyAGI.from_llm(\n",
" llm=llm, vectorstore=vectorstore, verbose=verbose, max_iterations=max_iterations\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "f7957b51",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[95m\u001b[1m\n",
"*****TASK LIST*****\n",
"\u001b[0m\u001b[0m\n",
"1: Make a todo list\n",
"\u001b[92m\u001b[1m\n",
"*****NEXT TASK*****\n",
"\u001b[0m\u001b[0m\n",
"1: Make a todo list\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to gather data on the current weather conditions in SF\n",
"Action: Search\n",
"Action Input: Current weather conditions in SF\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mHigh 67F. Winds WNW at 10 to 15 mph. Clear to partly cloudy.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to make a todo list\n",
"Action: TODO\n",
"Action Input: Write a weather report for SF today\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"\n",
"1. Research current weather conditions in San Francisco\n",
"2. Gather data on temperature, humidity, wind speed, and other relevant weather conditions\n",
"3. Analyze data to determine current weather trends\n",
"4. Write a brief introduction to the weather report\n",
"5. Describe current weather conditions in San Francisco\n",
"6. Discuss any upcoming weather changes\n",
"7. Summarize the weather report\n",
"8. Proofread and edit the report\n",
"9. Submit the report\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: A weather report for SF today should include research on current weather conditions in San Francisco, gathering data on temperature, humidity, wind speed, and other relevant weather conditions, analyzing data to determine current weather trends, writing a brief introduction to the weather report, describing current weather conditions in San Francisco, discussing any upcoming weather changes, summarizing the weather report, proofreading and editing the report, and submitting the report.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[93m\u001b[1m\n",
"*****TASK RESULT*****\n",
"\u001b[0m\u001b[0m\n",
"A weather report for SF today should include research on current weather conditions in San Francisco, gathering data on temperature, humidity, wind speed, and other relevant weather conditions, analyzing data to determine current weather trends, writing a brief introduction to the weather report, describing current weather conditions in San Francisco, discussing any upcoming weather changes, summarizing the weather report, proofreading and editing the report, and submitting the report.\n",
"\u001b[95m\u001b[1m\n",
"*****TASK LIST*****\n",
"\u001b[0m\u001b[0m\n",
"2: Gather data on temperature, humidity, wind speed, and other relevant weather conditions\n",
"3: Analyze data to determine current weather trends\n",
"4: Write a brief introduction to the weather report\n",
"5: Describe current weather conditions in San Francisco\n",
"6: Discuss any upcoming weather changes\n",
"7: Summarize the weather report\n",
"8: Proofread and edit the report\n",
"9: Submit the report\n",
"1: Research current weather conditions in San Francisco\n",
"\u001b[92m\u001b[1m\n",
"*****NEXT TASK*****\n",
"\u001b[0m\u001b[0m\n",
"2: Gather data on temperature, humidity, wind speed, and other relevant weather conditions\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to search for the current weather conditions in SF\n",
"Action: Search\n",
"Action Input: Current weather conditions in SF\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mHigh 67F. Winds WNW at 10 to 15 mph. Clear to partly cloudy.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to make a todo list\n",
"Action: TODO\n",
"Action Input: Create a weather report for SF today\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m\n",
"\n",
"1. Gather current weather data for SF, including temperature, wind speed, humidity, and precipitation.\n",
"2. Research historical weather data for SF to compare current conditions.\n",
"3. Analyze current and historical data to determine any trends or patterns.\n",
"4. Create a visual representation of the data, such as a graph or chart.\n",
"5. Write a summary of the weather report, including key findings and any relevant information.\n",
"6. Publish the weather report on a website or other platform.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Today in San Francisco, the temperature is 67F with winds WNW at 10 to 15 mph. The sky is clear to partly cloudy.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[93m\u001b[1m\n",
"*****TASK RESULT*****\n",
"\u001b[0m\u001b[0m\n",
"Today in San Francisco, the temperature is 67F with winds WNW at 10 to 15 mph. The sky is clear to partly cloudy.\n",
"\u001b[95m\u001b[1m\n",
"*****TASK LIST*****\n",
"\u001b[0m\u001b[0m\n",
"3: Research current weather conditions in San Francisco\n",
"4: Compare the current weather conditions in San Francisco to the average for this time of year.\n",
"5: Identify any potential weather-related hazards in the area.\n",
"6: Research any historical weather patterns in San Francisco.\n",
"7: Analyze data to determine current weather trends\n",
"8: Include any relevant data from nearby cities in the report.\n",
"9: Include any relevant data from the National Weather Service in the report.\n",
"10: Include any relevant data from local news sources in the report.\n",
"11: Include any relevant data from online weather sources in the report.\n",
"12: Include any relevant data from local meteorologists in the report.\n",
"13: Include any relevant data from local weather stations in the report.\n",
"14: Include any relevant data from satellite images in the report.\n",
"15: Describe current weather conditions in San Francisco\n",
"16: Discuss any upcoming weather changes\n",
"17: Write a brief introduction to the weather report\n",
"18: Summarize the weather report\n",
"19: Proofread and edit the report\n",
"20: Submit the report\n",
"\u001b[92m\u001b[1m\n",
"*****NEXT TASK*****\n",
"\u001b[0m\u001b[0m\n",
"3: Research current weather conditions in San Francisco\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThought: I need to search for current weather conditions in San Francisco\n",
"Action: Search\n",
"Action Input: Current weather conditions in San Francisco\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mTodaySun 04/09 High 67 · 1% Precip. ; TonightSun 04/09 Low 49 · 9% Precip. ; TomorrowMon 04/10 High 64 · 11% Precip.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Today in San Francisco, the high temperature is 67 degrees with 1% chance of precipitation. The low temperature tonight is 49 degrees with 9% chance of precipitation. Tomorrow's high temperature is 64 degrees with 11% chance of precipitation.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[93m\u001b[1m\n",
"*****TASK RESULT*****\n",
"\u001b[0m\u001b[0m\n",
"Today in San Francisco, the high temperature is 67 degrees with 1% chance of precipitation. The low temperature tonight is 49 degrees with 9% chance of precipitation. Tomorrow's high temperature is 64 degrees with 11% chance of precipitation.\n",
"\u001b[91m\u001b[1m\n",
"*****TASK ENDING*****\n",
"\u001b[0m\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'objective': 'Write a weather report for SF today'}"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"baby_agi({\"objective\": OBJECTIVE})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "898a210b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
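
One small detail from the `_call` method above that the run below does not exercise: the controller seeds its task queue from `inputs.get("first_task", "Make a todo list")`, so a custom first task can be passed alongside the objective. A minimal sketch, assuming the `baby_agi` instance built in the cells above (the resulting task list will of course differ from the run shown):

```python
# Hedged sketch (not part of the notebook run above): seed BabyAGI with an
# explicit first task instead of the default "Make a todo list".
baby_agi({
    "objective": "Write a weather report for SF today",
    "first_task": "Check current weather conditions for San Francisco",
})
```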

View File

@@ -66,11 +66,11 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from git import Repo\n",
"# from git import Repo\n",
"from langchain.text_splitter import Language\n",
"from langchain.document_loaders.generic import GenericLoader\n",
"from langchain.document_loaders.parsers import LanguageParser"
@@ -78,13 +78,13 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Clone\n",
"repo_path = \"/Users/rlm/Desktop/test_repo\"\n",
"repo = Repo.clone_from(\"https://github.com/hwchase17/langchain\", to_path=repo_path)"
"# repo = Repo.clone_from(\"https://github.com/hwchase17/langchain\", to_path=repo_path)"
]
},
{
@@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -109,7 +109,7 @@
"1293"
]
},
"execution_count": 39,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -139,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -148,7 +148,7 @@
"3748"
]
},
"execution_count": 40,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -187,7 +187,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -333,68 +333,673 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Private chat\n",
"### Open source LLMs\n",
"\n",
"We can use [Code LLaMA](https://about.fb.com/news/2023/08/code-llama-ai-for-coding/) via the Ollama integration.\n",
"We can use [Code LLaMA](https://about.fb.com/news/2023/08/code-llama-ai-for-coding/) via LLamaCPP or [Ollama integration](https://ollama.ai/blog/run-code-llama-locally).\n",
"\n",
"`ollama pull codellama:7b-instruct`"
"Note: be sure to upgrade `llama-cpp-python` in order to use the new `gguf` [file format](https://github.com/abetlen/llama-cpp-python/pull/633).\n",
"\n",
"```\n",
"CMAKE_ARGS=\"-DLLAMA_METAL=on\" FORCE_CMAKE=1 /Users/rlm/miniforge3/envs/llama2/bin/pip install -U llama-cpp-python --no-cache-dir\n",
"```\n",
" \n",
"Check out the latest code-llama models [here](https://huggingface.co/TheBloke/CodeLlama-13B-Instruct-GGUF/tree/main)."
]
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Ollama\n",
"from langchain.llms import LlamaCpp\n",
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.callbacks.manager import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler \n",
"llm = Ollama(model=\"codellama:7b-instruct\", \n",
" callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]))\n",
"memory = ConversationSummaryMemory(llm=llm,memory_key=\"chat_history\",return_messages=True)\n",
"qa_llama=ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)"
"from langchain.memory import ConversationSummaryMemory\n",
"from langchain.chains import ConversationalRetrievalChain \n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
]
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"llama_model_loader: loaded meta data with 17 key-value pairs and 363 tensors from /Users/rlm/Desktop/Code/llama/code-llama/codellama-13b-instruct.Q4_K_M.gguf (version GGUF V1 (latest))\n",
"llama_model_loader: - tensor 0: token_embd.weight q4_0 [ 5120, 32016, 1, 1 ]\n",
"llama_model_loader: - tensor 1: output_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 2: output.weight f16 [ 5120, 32016, 1, 1 ]\n",
"llama_model_loader: - tensor 3: blk.0.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 4: blk.0.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 5: blk.0.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 6: blk.0.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 7: blk.0.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 8: blk.0.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 9: blk.0.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 10: blk.0.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 11: blk.0.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 12: blk.1.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 13: blk.1.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 14: blk.1.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 15: blk.1.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 16: blk.1.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 17: blk.1.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 18: blk.1.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 19: blk.1.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 20: blk.1.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 21: blk.2.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 22: blk.2.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 23: blk.2.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 24: blk.2.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 25: blk.2.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 26: blk.2.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 27: blk.2.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 28: blk.2.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 29: blk.2.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 30: blk.3.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 31: blk.3.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 32: blk.3.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 33: blk.3.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 34: blk.3.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 35: blk.3.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 36: blk.3.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 37: blk.3.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 38: blk.3.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 39: blk.4.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 40: blk.4.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 41: blk.4.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 42: blk.4.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 43: blk.4.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 44: blk.4.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 45: blk.4.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 46: blk.4.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 47: blk.4.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 48: blk.5.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 49: blk.5.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 50: blk.5.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 51: blk.5.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 52: blk.5.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 53: blk.5.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 54: blk.5.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 55: blk.5.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 56: blk.5.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 57: blk.6.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 58: blk.6.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 59: blk.6.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 60: blk.6.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 61: blk.6.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 62: blk.6.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 63: blk.6.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 64: blk.6.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 65: blk.6.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 66: blk.7.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 67: blk.7.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 68: blk.7.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 69: blk.7.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 70: blk.7.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 71: blk.7.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 72: blk.7.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 73: blk.7.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 74: blk.7.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 75: blk.8.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 76: blk.8.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 77: blk.8.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 78: blk.8.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 79: blk.8.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 80: blk.8.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 81: blk.8.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 82: blk.8.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 83: blk.8.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 84: blk.9.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 85: blk.9.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 86: blk.9.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 87: blk.9.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 88: blk.9.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 89: blk.9.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 90: blk.9.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 91: blk.9.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 92: blk.9.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 93: blk.10.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 94: blk.10.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 95: blk.10.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 96: blk.10.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 97: blk.10.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 98: blk.10.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 99: blk.10.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 100: blk.10.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 101: blk.10.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 102: blk.11.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 103: blk.11.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 104: blk.11.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 105: blk.11.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 106: blk.11.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 107: blk.11.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 108: blk.11.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 109: blk.11.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 110: blk.11.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 111: blk.12.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 112: blk.12.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 113: blk.12.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 114: blk.12.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 115: blk.12.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 116: blk.12.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 117: blk.12.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 118: blk.12.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 119: blk.12.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 120: blk.13.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 121: blk.13.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 122: blk.13.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 123: blk.13.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 124: blk.13.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 125: blk.13.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 126: blk.13.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 127: blk.13.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 128: blk.13.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 129: blk.14.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 130: blk.14.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 131: blk.14.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 132: blk.14.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 133: blk.14.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 134: blk.14.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 135: blk.14.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 136: blk.14.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 137: blk.14.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 138: blk.15.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 139: blk.15.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 140: blk.15.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 141: blk.15.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 142: blk.15.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 143: blk.15.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 144: blk.15.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 145: blk.15.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 146: blk.15.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 147: blk.16.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 148: blk.16.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 149: blk.16.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 150: blk.16.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 151: blk.16.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 152: blk.16.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 153: blk.16.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 154: blk.16.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 155: blk.16.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 156: blk.17.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 157: blk.17.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 158: blk.17.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 159: blk.17.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 160: blk.17.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 161: blk.17.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 162: blk.17.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 163: blk.17.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 164: blk.17.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 165: blk.18.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 166: blk.18.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 167: blk.18.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 168: blk.18.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 169: blk.18.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 170: blk.18.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 171: blk.18.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 172: blk.18.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 173: blk.18.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 174: blk.19.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 175: blk.19.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 176: blk.19.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 177: blk.19.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 178: blk.19.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 179: blk.19.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 180: blk.19.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 181: blk.19.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 182: blk.19.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 183: blk.20.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 184: blk.20.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 185: blk.20.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 186: blk.20.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 187: blk.20.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 188: blk.20.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 189: blk.20.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 190: blk.20.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 191: blk.20.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 192: blk.21.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 193: blk.21.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 194: blk.21.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 195: blk.21.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 196: blk.21.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 197: blk.21.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 198: blk.21.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 199: blk.21.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 200: blk.21.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 201: blk.22.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 202: blk.22.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 203: blk.22.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 204: blk.22.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 205: blk.22.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 206: blk.22.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 207: blk.22.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 208: blk.22.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 209: blk.22.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 210: blk.23.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 211: blk.23.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 212: blk.23.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 213: blk.23.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 214: blk.23.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 215: blk.23.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 216: blk.23.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 217: blk.23.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 218: blk.23.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 219: blk.24.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 220: blk.24.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 221: blk.24.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 222: blk.24.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 223: blk.24.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 224: blk.24.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 225: blk.24.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 226: blk.24.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 227: blk.24.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 228: blk.25.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 229: blk.25.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 230: blk.25.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 231: blk.25.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 232: blk.25.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 233: blk.25.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 234: blk.25.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 235: blk.25.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 236: blk.25.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 237: blk.26.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 238: blk.26.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 239: blk.26.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 240: blk.26.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 241: blk.26.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 242: blk.26.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 243: blk.26.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 244: blk.26.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 245: blk.26.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 246: blk.27.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 247: blk.27.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 248: blk.27.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 249: blk.27.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 250: blk.27.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 251: blk.27.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 252: blk.27.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 253: blk.27.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 254: blk.27.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 255: blk.28.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 256: blk.28.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 257: blk.28.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 258: blk.28.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 259: blk.28.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 260: blk.28.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 261: blk.28.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 262: blk.28.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 263: blk.28.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 264: blk.29.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 265: blk.29.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 266: blk.29.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 267: blk.29.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 268: blk.29.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 269: blk.29.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 270: blk.29.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 271: blk.29.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 272: blk.29.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 273: blk.30.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 274: blk.30.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 275: blk.30.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 276: blk.30.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 277: blk.30.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 278: blk.30.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 279: blk.30.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 280: blk.30.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 281: blk.30.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 282: blk.31.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 283: blk.31.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 284: blk.31.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 285: blk.31.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 286: blk.31.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 287: blk.31.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 288: blk.31.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 289: blk.31.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 290: blk.31.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 291: blk.32.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 292: blk.32.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 293: blk.32.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 294: blk.32.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 295: blk.32.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 296: blk.32.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 297: blk.32.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 298: blk.32.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 299: blk.32.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 300: blk.33.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 301: blk.33.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 302: blk.33.attn_v.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 303: blk.33.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 304: blk.33.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 305: blk.33.ffn_down.weight q4_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 306: blk.33.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 307: blk.33.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 308: blk.33.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 309: blk.34.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 310: blk.34.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 311: blk.34.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 312: blk.34.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 313: blk.34.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 314: blk.34.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 315: blk.34.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 316: blk.34.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 317: blk.34.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 318: blk.35.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 319: blk.35.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 320: blk.35.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 321: blk.35.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 322: blk.35.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 323: blk.35.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 324: blk.35.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 325: blk.35.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 326: blk.35.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 327: blk.36.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 328: blk.36.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 329: blk.36.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 330: blk.36.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 331: blk.36.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 332: blk.36.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 333: blk.36.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 334: blk.36.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 335: blk.36.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 336: blk.37.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 337: blk.37.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 338: blk.37.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 339: blk.37.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 340: blk.37.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 341: blk.37.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 342: blk.37.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 343: blk.37.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 344: blk.37.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 345: blk.38.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 346: blk.38.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 347: blk.38.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 348: blk.38.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 349: blk.38.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 350: blk.38.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 351: blk.38.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 352: blk.38.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 353: blk.38.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 354: blk.39.attn_q.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 355: blk.39.attn_k.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 356: blk.39.attn_v.weight q6_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 357: blk.39.attn_output.weight q4_K [ 5120, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 358: blk.39.ffn_gate.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 359: blk.39.ffn_down.weight q6_K [ 13824, 5120, 1, 1 ]\n",
"llama_model_loader: - tensor 360: blk.39.ffn_up.weight q4_K [ 5120, 13824, 1, 1 ]\n",
"llama_model_loader: - tensor 361: blk.39.attn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - tensor 362: blk.39.ffn_norm.weight f32 [ 5120, 1, 1, 1 ]\n",
"llama_model_loader: - kv 0: general.architecture str \n",
"llama_model_loader: - kv 1: general.name str \n",
"llama_model_loader: - kv 2: llama.context_length u32 \n",
"llama_model_loader: - kv 3: llama.embedding_length u32 \n",
"llama_model_loader: - kv 4: llama.block_count u32 \n",
"llama_model_loader: - kv 5: llama.feed_forward_length u32 \n",
"llama_model_loader: - kv 6: llama.rope.dimension_count u32 \n",
"llama_model_loader: - kv 7: llama.attention.head_count u32 \n",
"llama_model_loader: - kv 8: llama.attention.head_count_kv u32 \n",
"llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 \n",
"llama_model_loader: - kv 10: llama.rope.freq_base f32 \n",
"llama_model_loader: - kv 11: general.file_type u32 \n",
"llama_model_loader: - kv 12: tokenizer.ggml.model str \n",
"llama_model_loader: - kv 13: tokenizer.ggml.tokens arr \n",
"llama_model_loader: - kv 14: tokenizer.ggml.scores arr \n",
"llama_model_loader: - kv 15: tokenizer.ggml.token_type arr \n",
"llama_model_loader: - kv 16: general.quantization_version u32 \n",
"llama_model_loader: - type f32: 81 tensors\n",
"llama_model_loader: - type f16: 1 tensors\n",
"llama_model_loader: - type q4_0: 1 tensors\n",
"llama_model_loader: - type q4_K: 240 tensors\n",
"llama_model_loader: - type q6_K: 40 tensors\n",
"llm_load_print_meta: format = GGUF V1 (latest)\n",
"llm_load_print_meta: arch = llama\n",
"llm_load_print_meta: vocab type = SPM\n",
"llm_load_print_meta: n_vocab = 32016\n",
"llm_load_print_meta: n_merges = 0\n",
"llm_load_print_meta: n_ctx_train = 16384\n",
"llm_load_print_meta: n_ctx = 5000\n",
"llm_load_print_meta: n_embd = 5120\n",
"llm_load_print_meta: n_head = 40\n",
"llm_load_print_meta: n_head_kv = 40\n",
"llm_load_print_meta: n_layer = 40\n",
"llm_load_print_meta: n_rot = 128\n",
"llm_load_print_meta: n_gqa = 1\n",
"llm_load_print_meta: f_norm_eps = 1.0e-05\n",
"llm_load_print_meta: f_norm_rms_eps = 1.0e-05\n",
"llm_load_print_meta: n_ff = 13824\n",
"llm_load_print_meta: freq_base = 1000000.0\n",
"llm_load_print_meta: freq_scale = 1\n",
"llm_load_print_meta: model type = 13B\n",
"llm_load_print_meta: model ftype = mostly Q4_K - Medium\n",
"llm_load_print_meta: model size = 13.02 B\n",
"llm_load_print_meta: general.name = LLaMA\n",
"llm_load_print_meta: BOS token = 1 '<s>'\n",
"llm_load_print_meta: EOS token = 2 '</s>'\n",
"llm_load_print_meta: UNK token = 0 '<unk>'\n",
"llm_load_print_meta: LF token = 13 '<0x0A>'\n",
"llm_load_tensors: ggml ctx size = 0.11 MB\n",
"llm_load_tensors: mem required = 7685.49 MB (+ 3906.25 MB per state)\n",
".................................................................................................\n",
"llama_new_context_with_model: kv self size = 3906.25 MB\n",
"ggml_metal_init: allocating\n",
"ggml_metal_init: loading '/Users/rlm/miniforge3/envs/llama2/lib/python3.9/site-packages/llama_cpp/ggml-metal.metal'\n",
"ggml_metal_init: loaded kernel_add 0x12126dd00 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_add_row 0x12126d610 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul 0x12126f2a0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_row 0x12126f500 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_scale 0x12126f760 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_silu 0x12126fe40 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_relu 0x1212700a0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_gelu 0x121270300 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_soft_max 0x121270560 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_diag_mask_inf 0x1212707c0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_f16 0x121270a20 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q4_0 0x121270c80 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q4_1 0x121270ee0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q8_0 0x121271140 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q2_K 0x1212713a0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q3_K 0x121271600 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q4_K 0x121271860 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q5_K 0x121271ac0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_get_rows_q6_K 0x121271d20 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_rms_norm 0x121271f80 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_norm 0x1212721e0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_f16_f32 0x121272440 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q4_0_f32 0x1212726a0 | th_max = 896 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q4_1_f32 0x121272900 | th_max = 896 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q8_0_f32 0x121272b60 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q2_K_f32 0x121272dc0 | th_max = 640 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q3_K_f32 0x121273020 | th_max = 704 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q4_K_f32 0x121273280 | th_max = 576 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q5_K_f32 0x1212734e0 | th_max = 576 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mat_q6_K_f32 0x121273740 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_f16_f32 0x1212739a0 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q4_0_f32 0x121273c00 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q8_0_f32 0x121273e60 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q4_1_f32 0x1212740c0 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q2_K_f32 0x121274320 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q3_K_f32 0x121274580 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q4_K_f32 0x1212747e0 | th_max = 768 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q5_K_f32 0x121274a40 | th_max = 704 | th_width = 32\n",
"ggml_metal_init: loaded kernel_mul_mm_q6_K_f32 0x121274ca0 | th_max = 704 | th_width = 32\n",
"ggml_metal_init: loaded kernel_rope 0x121274f00 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_alibi_f32 0x121275160 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_cpy_f32_f16 0x1212753c0 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_cpy_f32_f32 0x121275620 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: loaded kernel_cpy_f16_f16 0x121275880 | th_max = 1024 | th_width = 32\n",
"ggml_metal_init: recommendedMaxWorkingSetSize = 21845.34 MB\n",
"ggml_metal_init: hasUnifiedMemory = true\n",
"ggml_metal_init: maxTransferRate = built-in GPU\n",
"llama_new_context_with_model: compute buffer total size = 442.03 MB\n",
"llama_new_context_with_model: max tensor size = 312.66 MB\n",
"ggml_metal_add_buffer: allocated 'data ' buffer, size = 7686.00 MB, (20243.77 / 21845.34)\n",
"ggml_metal_add_buffer: allocated 'eval ' buffer, size = 1.42 MB, (20245.19 / 21845.34)\n",
"ggml_metal_add_buffer: allocated 'kv ' buffer, size = 3908.25 MB, (24153.44 / 21845.34), warning: current allocated size is greater than the recommended max working set size\n",
"AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | \n",
"ggml_metal_add_buffer: allocated 'alloc ' buffer, size = 440.64 MB, (24594.08 / 21845.34), warning: current allocated size is greater than the recommended max working set size\n"
]
}
],
"source": [
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
"llm = LlamaCpp(\n",
" model_path=\"/Users/rlm/Desktop/Code/llama/code-llama/codellama-13b-instruct.Q4_K_M.gguf\",\n",
" n_ctx=5000,\n",
" n_gpu_layers=1,\n",
" n_batch=512,\n",
" f16_kv=True, # MUST set to True, otherwise you will run into problem after a couple of calls\n",
" callback_manager=callback_manager,\n",
" verbose=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \"How can I initialize a ReAct agent?\" To initialize a ReAct agent, you can use the `ReActAgent.from_llm_and_tools()` class method. This method takes two arguments: the LLM and a list of tools.\n",
"Here is an example of how to initialize a ReAct agent with the OpenAI language model and the \"Search\" tool:\n",
"from langchain.agents.mrkl.base import ZeroShotAgent\n",
" You can use the find command with a few options to this task. Here is an example of how you might go about it:\n",
"\n",
"agent = ReActDocstoreAgent.from_llm_and_tools(OpenAIFunctionsAgent(), [Tool(\"Search\")]])\n",
"find . -type f -mtime +28 -exec ls {} \\;\n",
"This command only for plain files (not), and limits the search to files that were more than 28 days ago, then the \"ls\" command on each file found. The {} is a for the filenames found by find that are being passed to the -exec option of find.\n",
"\n",
" The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential."
"You can also use find in with other unix utilities like sort and grep to the list of files before they are:\n",
"\n",
"find . -type f -mtime +28 | sort | grep pattern\n",
"This will find all plain files that match a given pattern, then sort the listically and filter it for only the matches.\n",
"\n",
"Answer: `find` is pretty with its search. The should work as well:\n",
"\n",
"\\begin{code}\n",
"ls -l $(find . -mtime +28)\n",
"\\end{code}\n",
"\n",
"(It's a bad idea to parse output from `ls`, though, as you may"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"llama_print_timings: load time = 1074.43 ms\n",
"llama_print_timings: sample time = 180.71 ms / 256 runs ( 0.71 ms per token, 1416.67 tokens per second)\n",
"llama_print_timings: prompt eval time = 0.00 ms / 1 tokens ( 0.00 ms per token, inf tokens per second)\n",
"llama_print_timings: eval time = 9593.04 ms / 256 runs ( 37.47 ms per token, 26.69 tokens per second)\n",
"llama_print_timings: total time = 10139.91 ms\n"
]
},
{
"data": {
"text/plain": [
"' To initialize a ReAct agent, you can use the `ReActAgent.from_llm_and_tools()` class method. This method takes two arguments: the LLM and a list of tools.\\nHere is an example of how to initialize a ReAct agent with the OpenAI language model and the \"Search\" tool:\\nfrom langchain.agents.mrkl.base import ZeroShotAgent\\n\\nagent = ReActDocstoreAgent.from_llm_and_tools(OpenAIFunctionsAgent(), [Tool(\"Search\")]])\\n\\n'"
"' You can use the find command with a few options to this task. Here is an example of how you might go about it:\\n\\nfind . -type f -mtime +28 -exec ls {} \\\\;\\nThis command only for plain files (not), and limits the search to files that were more than 28 days ago, then the \"ls\" command on each file found. The {} is a for the filenames found by find that are being passed to the -exec option of find.\\n\\nYou can also use find in with other unix utilities like sort and grep to the list of files before they are:\\n\\nfind . -type f -mtime +28 | sort | grep pattern\\nThis will find all plain files that match a given pattern, then sort the listically and filter it for only the matches.\\n\\nAnswer: `find` is pretty with its search. The should work as well:\\n\\n\\\\begin{code}\\nls -l $(find . -mtime +28)\\n\\\\end{code}\\n\\n(It\\'s a bad idea to parse output from `ls`, though, as you may'"
]
},
"execution_count": 45,
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm(\"Question: In bash, how do I list all the text files in the current directory that have been modified in the last month? Answer:\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" You can use the `ReActAgent` class and pass it the desired tools as, for example, you would do like this to create an agent with the `Lookup` and `Search` tool:\n",
"```python\n",
"from langchain.agents.react import ReActAgent\n",
"from langchain.tools.lookup import Lookup\n",
"from langchain.tools.search import Search\n",
"ReActAgent(Lookup(), Search())\n",
"```"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"llama_print_timings: load time = 1074.43 ms\n",
"llama_print_timings: sample time = 65.46 ms / 94 runs ( 0.70 ms per token, 1435.95 tokens per second)\n",
"llama_print_timings: prompt eval time = 15975.57 ms / 1408 tokens ( 11.35 ms per token, 88.13 tokens per second)\n",
"llama_print_timings: eval time = 4772.57 ms / 93 runs ( 51.32 ms per token, 19.49 tokens per second)\n",
"llama_print_timings: total time = 20959.57 ms\n"
]
},
{
"data": {
"text/plain": [
"{'output_text': ' You can use the `ReActAgent` class and pass it the desired tools as, for example, you would do like this to create an agent with the `Lookup` and `Search` tool:\\n```python\\nfrom langchain.agents.react import ReActAgent\\nfrom langchain.tools.lookup import Lookup\\nfrom langchain.tools.search import Search\\nReActAgent(Lookup(), Search())\\n```'}"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains.question_answering import load_qa_chain\n",
"\n",
"# Prompt\n",
"template = \"\"\"Use the following pieces of context to answer the question at the end. \n",
"If you don't know the answer, just say that you don't know, don't try to make up an answer. \n",
"Use three sentences maximum and keep the answer as concise as possible. \n",
"{context}\n",
"Question: {question}\n",
"Helpful Answer:\"\"\"\n",
"QA_CHAIN_PROMPT = PromptTemplate(\n",
" input_variables=[\"context\", \"question\"],\n",
" template=template,\n",
")\n",
"\n",
"# Docs\n",
"question = \"How can I initialize a ReAct agent?\"\n",
"result = qa_llama(question)\n",
"result['answer']"
"docs = retriever.get_relevant_documents(question)\n",
"\n",
"# Chain\n",
"chain = load_qa_chain(llm, chain_type=\"stuff\", prompt=QA_CHAIN_PROMPT)\n",
"\n",
"# Run\n",
"chain({\"input_documents\": docs, \"question\": question}, return_only_outputs=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can view the [LangSmith trace](https://smith.langchain.com/public/fd24c734-e365-4a09-b883-cdbc7dcfa582/r) to sanity check the result relative to what was retrieved."
"Here's the trace [RAG](https://smith.langchain.com/public/f21c4bcd-88da-4681-8b22-a0bb0e31a0d3/r), showing the retrieved docs."
]
}
],
@@ -418,5 +1023,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

View File

@@ -55,7 +55,7 @@
"source": [
"## Quickstart\n",
"\n",
"OpenAI funtions are one way to get started with extraction.\n",
"OpenAI functions are one way to get started with extraction.\n",
"\n",
"Define a schema that specifies the properties we want to extract from the LLM output.\n",
"\n",
@@ -122,7 +122,7 @@
"id": "6f7eb826",
"metadata": {},
"source": [
"## Option 1: OpenAI funtions\n",
"## Option 1: OpenAI functions\n",
"\n",
"### Looking under the hood\n",
"\n",

View File

@@ -0,0 +1,718 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "842dd272",
"metadata": {},
"source": [
"# Agents\n",
"\n",
"[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/more/agents/agents.ipynb)\n",
"\n",
"## Use case \n",
"\n",
"LLM-based agents are powerful general problem solvers.\n",
"\n",
"The [primary LLM agent components](https://lilianweng.github.io/posts/2023-06-23-agent/) include at least 3 things:\n",
"\n",
"* `Planning`: The ability to break down tasks into smaller sub-goals\n",
"* `Memory`: The ability to retain and recall information\n",
"* `Tools`: The ability to get information from external sources (e.g., APIs)\n",
"\n",
"Unlike LLMs simply connected to [APIs](/docs/use_cases/apis/apis), agents [can](https://www.youtube.com/watch?v=DWUdGhRrv2c):\n",
"\n",
"* Self-correct\n",
"* Handle multi-hop tasks (several intermediate \"hops\" or steps to arrive at a conclusion)\n",
"* Tackle long time horizon tasks (that require access to long-term memory)\n",
"\n",
"![Image description](/img/agents_use_case_1.png)\n",
"\n",
"## Overview \n",
"\n",
"LangChain has [many agent types](/docs/modules/agents/agent_types/).\n",
"\n",
"Nearly all agents will use the following components:\n",
" \n",
"**Planning**\n",
" \n",
"* `Prompt`: Can given the LLM [personality](https://arxiv.org/pdf/2304.03442.pdf), context (e.g, via retrieval from memory), or strategies for learninng (e.g., [chain-of-thought](https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/#chain-of-thought-cot)).\n",
"* `Agent` Responsible for deciding what step to take next using an LLM with the `Prompt`\n",
"\n",
"**Memory**\n",
"\n",
"* This can be short or long-term, allowing the agent to persist information.\n",
"\n",
"**Tools**\n",
"\n",
"* Tools are functions that an agent can call.\n",
"\n",
"But, there are some taxonomic differences:\n",
"\n",
"* `Action agents`: Designed to decide the sequence of actions (tool use) (e.g., OpenAI functions agents, ReAct agents).\n",
"* `Simulation agents`: Designed for role-play often in simulated enviorment (e.g., Generative Agents, CAMEL).\n",
"* `Autonomous agents`: Designed for indepdent execution towards long term goals (e.g., BabyAGI, Auto-GPT).\n",
"\n",
"This will focus on `Action agents`.\n",
"\n",
"\n",
"## Quickstart "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a704c7a",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain openai google-search-results\n",
"\n",
"# Set env var OPENAI_API_KEY and SERPAPI_API_KEY or load from a .env file\n",
"# import dotenv\n",
"\n",
"# dotenv.load_env()"
]
},
{
"cell_type": "markdown",
"id": "639d41ad",
"metadata": {},
"source": [
"`Tools`\n",
"\n",
"LangChain has [many tools](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/load_tools.py) for Agents that we can load easily.\n",
"\n",
"Let's load search and a calcultor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c60001c9",
"metadata": {},
"outputs": [],
"source": [
"# Tool\n",
"from langchain.agents import load_tools\n",
"from langchain.chat_models import ChatOpenAI\n",
"llm = ChatOpenAI(temperature=0)\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)"
]
},
{
"cell_type": "markdown",
"id": "431ba30b",
"metadata": {},
"source": [
"`Agent`\n",
"\n",
"The [`OPENAI_FUNCTIONS` agent](/docs/modules/agents/agent_types/openai_functions_agent) is a good action agent to start with.\n",
"\n",
"OpenAI models have been fine-tuned to recognize when function should be called."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d636395f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'As of 2023, the estimated population of Canada is approximately 39,858,480 people.'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prompt\n",
"from langchain.agents import AgentExecutor\n",
"from langchain.schema import SystemMessage\n",
"from langchain.agents import OpenAIFunctionsAgent\n",
"system_message = SystemMessage(content=\"You are a search assistant.\")\n",
"prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)\n",
"\n",
"# Agent\n",
"search_agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)\n",
"agent_executor = AgentExecutor(agent=search_agent, tools=tools, verbose=False)\n",
"\n",
"# Run\n",
"agent_executor.run(\"How many people live in canada as of 2023?\")"
]
},
{
"cell_type": "markdown",
"id": "27842380",
"metadata": {},
"source": [
"Great, we have created a simple search agent with a tool!\n",
"\n",
"Note that we use an agent executor, which is the runtime for an agent. \n",
"\n",
"This is what calls the agent and executes the actions it chooses. \n",
"\n",
"Pseudocode for this runtime is below:\n",
"```\n",
"next_action = agent.get_action(...)\n",
"while next_action != AgentFinish:\n",
" observation = run(next_action)\n",
" next_action = agent.get_action(..., next_action, observation)\n",
"return next_action\n",
"```\n",
"\n",
"While this may seem simple, there are several complexities this runtime handles for you, including:\n",
"\n",
"* Handling cases where the agent selects a non-existent tool\n",
"* Handling cases where the tool errors\n",
"* Handling cases where the agent produces output that cannot be parsed into a tool invocation\n",
"* Logging and observability at all levels (agent decisions, tool calls) either to stdout or LangSmith.\n"
]
},
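{
"cell_type": "markdown",
"id": "7fa2c9e1",
"metadata": {},
"source": [
"As a rough sketch of how some of this handling can be configured (reusing the `search_agent` and `tools` defined above), the executor can be given `max_iterations` to bound the loop and `handle_parsing_errors` to recover from unparseable LLM output; the name `robust_executor` below is just illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3d61f20",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch: bound the agent loop and recover from output-parsing errors\n",
"robust_executor = AgentExecutor(\n",
"    agent=search_agent,\n",
"    tools=tools,\n",
"    max_iterations=3,            # stop after at most 3 agent steps\n",
"    handle_parsing_errors=True,  # feed parsing errors back to the LLM instead of raising\n",
"    verbose=False,\n",
")\n",
"# robust_executor.run(\"How many people live in canada as of 2023?\")"
]
},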
{
"cell_type": "markdown",
"id": "0b93c7d0",
"metadata": {},
"source": [
"## Memory \n",
"\n",
"### Short-term memory\n",
"\n",
"Of course, `memory` is needed to enable conversation / persistence of information.\n",
"\n",
"LangChain has many options for [short-term memory](/docs/modules/memory/types/), which are frequently used in [chat](/docs/modules/memory/adding_memory.html). \n",
"\n",
"They can be [employed with agents](/docs/modules/memory/agent_with_memory) too.\n",
"\n",
"`ConversationBufferMemory` is a popular choice for short-term memory.\n",
"\n",
"We set `MEMORY_KEY`, which can be referenced by the prompt later.\n",
"\n",
"Now, let's add memory to our agent."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "1d291015",
"metadata": {},
"outputs": [],
"source": [
"# Memory \n",
"from langchain.memory import ConversationBufferMemory\n",
"MEMORY_KEY = \"chat_history\"\n",
"memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)\n",
"\n",
"# Prompt w/ placeholder for memory\n",
"from langchain.schema import SystemMessage\n",
"from langchain.agents import OpenAIFunctionsAgent\n",
"from langchain.prompts import MessagesPlaceholder\n",
"system_message = SystemMessage(content=\"You are a search assistant tasked with using Serpapi to answer questions.\")\n",
"prompt = OpenAIFunctionsAgent.create_prompt(\n",
" system_message=system_message,\n",
" extra_prompt_messages=[MessagesPlaceholder(variable_name=MEMORY_KEY)]\n",
")\n",
"\n",
"# Agent\n",
"search_agent_memory = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt, memory=memory)\n",
"agent_executor_memory = AgentExecutor(agent=search_agent_memory, tools=tools, memory=memory, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "b4b2249a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'As of August 2023, the estimated population of Canada is approximately 38,781,291 people.'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor_memory.run(\"How many people live in Canada as of August, 2023?\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "4d31b0cf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'As of August 2023, the largest province in Canada is Ontario, with a population of over 15 million people.'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor_memory.run(\"What is the population of its largest provence as of August, 2023?\")"
]
},
{
"cell_type": "markdown",
"id": "3606c32a",
"metadata": {},
"source": [
"Looking at the [trace](https://smith.langchain.com/public/4425a131-ec90-4aaa-acd8-5b880c7452a3/r), we can what is happening:\n",
"\n",
"* The chat history is passed to the LLMs\n",
"* This gives context to `its` in `What is the population of its largest provence as of August, 2023?`\n",
"* The LLM generates a function call to the search tool\n",
"\n",
"```\n",
"function_call:\n",
" name: Search\n",
" arguments: |-\n",
" {\n",
" \"query\": \"population of largest province in Canada as of August 2023\"\n",
" }\n",
"```\n",
"\n",
"* The search is executed\n",
"* The results frum search are passed back to the LLM for synthesis into an answer\n",
"\n",
"![Image description](/img/oai_function_agent.png)"
]
},
{
"cell_type": "markdown",
"id": "384e37f8",
"metadata": {},
"source": [
"### Long-term memory \n",
"\n",
"Vectorstores are great options for long-term memory."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "1489746c",
"metadata": {},
"outputs": [],
"source": [
"import faiss\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.docstore import InMemoryDocstore\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"embedding_size = 1536\n",
"embeddings_model = OpenAIEmbeddings()\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})"
]
},
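{
"cell_type": "markdown",
"id": "4e9ac7d5",
"metadata": {},
"source": [
"A minimal sketch of how this vectorstore could back long-term memory: `VectorStoreRetrieverMemory` wraps a retriever so that past exchanges are recalled by similarity rather than recency (the example inputs below are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a81f63b2",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import VectorStoreRetrieverMemory\n",
"\n",
"# Sketch: back memory with the FAISS vectorstore created above\n",
"retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))\n",
"long_term_memory = VectorStoreRetrieverMemory(retriever=retriever)\n",
"\n",
"# Store an exchange, then fetch it later as relevant context\n",
"long_term_memory.save_context({\"input\": \"My favorite sport is hockey\"}, {\"output\": \"Good to know!\"})\n",
"# long_term_memory.load_memory_variables({\"prompt\": \"What sport should I watch?\"})"
]
},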
{
"cell_type": "markdown",
"id": "9668ef5d",
"metadata": {},
"source": [
"### Going deeper \n",
"\n",
"* Explore projects using long-term memory, such as [autonomous agents](/docs/use_cases/autonomous_agents/autonomous_agents)."
]
},
{
"cell_type": "markdown",
"id": "43fe2bb3",
"metadata": {},
"source": [
"## Tools \n",
"\n",
"As mentioned above, LangChain has [many tools](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/load_tools.py) for Agents that we can load easily.\n",
"\n",
"We can also define [custom tools](/docs/modules/agents/tools/custom_tools). For example, here is a search tool.\n",
"\n",
"* The `Tool` dataclass wraps functions that accept a single string input and returns a string output.\n",
"* `return_direct` determines whether to return the tool's output directly. \n",
"* Setting this to `True` means that after the tool is called, the `AgentExecutor` will stop looping."
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "7357e496",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import Tool, tool\n",
"from langchain.utilities import GoogleSearchAPIWrapper\n",
"search = GoogleSearchAPIWrapper()\n",
"search_tool = [\n",
" Tool(\n",
" name=\"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\",\n",
" return_direct=True,\n",
" )\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "c6ef5bfa",
"metadata": {},
"source": [
"To make it easier to define custom tools, a `@tool` decorator is provided. \n",
"\n",
"This decorator can be used to quickly create a Tool from a simple function."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b6308c69",
"metadata": {},
"outputs": [],
"source": [
"# Tool\n",
"@tool\n",
"def get_word_length(word: str) -> int:\n",
" \"\"\"Returns the length of a word.\"\"\"\n",
" return len(word)\n",
"word_length_tool = [get_word_length]"
]
},
{
"cell_type": "markdown",
"id": "83c104d7",
"metadata": {},
"source": [
"### Going deeper\n",
"\n",
"**Toolkits**\n",
"\n",
"* Toolkits are groups of tools needed to accomplish specific objectives.\n",
"* [Here](/docs/integrations/toolkits/) are > 15 different agent toolkits (e.g., Gmail, Pandas, etc). \n",
"\n",
"Here is a simple way to think about agents vs the various chains covered in other docs:\n",
"\n",
"![Image description](/img/agents_vs_chains.png)"
]
},
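{
"cell_type": "markdown",
"id": "c2e94d07",
"metadata": {},
"source": [
"As a sketch of the toolkit idea above (assuming `pandas` is installed and reusing the `llm` defined earlier), a helper like `create_pandas_dataframe_agent` bundles the tools needed to query a DataFrame; the DataFrame and question here are illustrative:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f7b8a9c",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: an agent pre-wired with DataFrame tools via a toolkit-style constructor\n",
"import pandas as pd\n",
"from langchain.agents import create_pandas_dataframe_agent\n",
"\n",
"df = pd.DataFrame({\"name\": [\"a\", \"b\", \"c\"], \"score\": [1, 2, 3]})\n",
"pandas_agent = create_pandas_dataframe_agent(llm, df, verbose=False)\n",
"# pandas_agent.run(\"What is the average score?\")"
]
},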
{
"cell_type": "markdown",
"id": "5eefe4a0",
"metadata": {},
"source": [
"## Agents\n",
"\n",
"There's a number of [action agent types](docs/modules/agents/agent_types/) available in LangChain.\n",
"\n",
"* [ReAct](/docs/modules/agents/agent_types/react.html): This is the most general purpose action agent using the [ReAct framework](https://arxiv.org/pdf/2205.00445.pdf), which can work with [Docstores](/docs/modules/agents/agent_types/react_docstore.html) or [Multi-tool Inputs](/docs/modules/agents/agent_types/structured_chat.html).\n",
"* [OpenAI functions](/docs/modules/agents/agent_types/openai_functions_agent.html): Designed to work with OpenAI function-calling models.\n",
"* [Conversational](/docs/modules/agents/agent_types/chat_conversation_agent.html): This agent is designed to be used in conversational settings\n",
"* [Self-ask with search](/docs/modules/agents/agent_types/self_ask_with_search.html): Designed to lookup factual answers to questions\n",
"\n",
"### OpenAI Functions agent\n",
"\n",
"As shown in Quickstart, let's continue with [`OpenAI functions` agent](/docs/modules/agents/agent_types/).\n",
"\n",
"This uses OpenAI models, which are fine-tuned to detect when a function should to be called.\n",
"\n",
"They will respond with the inputs that should be passed to the function.\n",
"\n",
"But, we can unpack it, first with a custom prompt:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "1c2deb4a",
"metadata": {},
"outputs": [],
"source": [
"# Memory\n",
"MEMORY_KEY = \"chat_history\"\n",
"memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)\n",
"\n",
"# Prompt\n",
"from langchain.schema import SystemMessage\n",
"from langchain.agents import OpenAIFunctionsAgent\n",
"system_message = SystemMessage(content=\"You are very powerful assistant, but bad at calculating lengths of words.\")\n",
"prompt = OpenAIFunctionsAgent.create_prompt(\n",
" system_message=system_message,\n",
" extra_prompt_messages=[MessagesPlaceholder(variable_name=MEMORY_KEY)]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ee317a45",
"metadata": {},
"source": [
"Define agent:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "460dab9b",
"metadata": {},
"outputs": [],
"source": [
"# Agent \n",
"from langchain.agents import OpenAIFunctionsAgent\n",
"agent = OpenAIFunctionsAgent(llm=llm, tools=word_length_tool, prompt=prompt)"
]
},
{
"cell_type": "markdown",
"id": "184e6c23",
"metadata": {},
"source": [
"Run agent:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "f4f27d37",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'There are 5 letters in the word \"educa\".'"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the executer, including short-term memory we created\n",
"agent_executor = AgentExecutor(agent=agent, tools=word_length_tool, memory=memory, verbose=False)\n",
"agent_executor.run(\"how many letters in the word educa?\")"
]
},
{
"cell_type": "markdown",
"id": "e4d9217e",
"metadata": {},
"source": [
"### ReAct agent\n",
"\n",
"[ReAct](https://arxiv.org/abs/2210.03629) agents are another popular framework.\n",
"\n",
"There has been lots of work on [LLM reasoning](https://ai.googleblog.com/2022/05/language-models-perform-reasoning-via.html), such as chain-of-thought prompting.\n",
"\n",
"There also has been work on LLM action-taking to generate obervations, such as [Say-Can](https://say-can.github.io/).\n",
"\n",
"ReAct marries these two ideas:\n",
"\n",
"![Image description](/img/ReAct.png)\n",
" \n",
"It uses a charecteristic `Thought`, `Action`, `Observation` [pattern in the output](https://lilianweng.github.io/posts/2023-06-23-agent/).\n",
" \n",
"We can use `initialize_agent` to create the ReAct agent from a list of available types [here](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/agents/types.py):\n",
"\n",
"```\n",
"* AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent\n",
"* AgentType.REACT_DOCSTORE: ReActDocstoreAgent\n",
"* AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent\n",
"* AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent\n",
"* AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent\n",
"* AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent\n",
"* AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION: StructuredChatAgent\n",
"* AgentType.OPENAI_FUNCTIONS: OpenAIFunctionsAgent\n",
"* AgentType.OPENAI_MULTI_FUNCTIONS: OpenAIMultiFunctionsAgent\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "85f033d3",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentType\n",
"from langchain.agents import initialize_agent\n",
"MEMORY_KEY = \"chat_history\"\n",
"memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)\n",
"react_agent = initialize_agent(search_tool, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d05a26c",
"metadata": {},
"outputs": [],
"source": [
"react_agent(\"How many people live in Canada as of August, 2023?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b626dc5",
"metadata": {},
"outputs": [],
"source": [
"react_agent(\"What is the population of its largest provence as of August, 2023?\")"
]
},
{
"cell_type": "markdown",
"id": "d4df0638",
"metadata": {},
"source": [
"LangSmith can help us run diagnostics on the ReAct agent:\n",
"\n",
"The [ReAct agent](https://smith.langchain.com/public/3d8d0a15-d73f-44f3-9f81-037f7031c592/r) fails to pass chat history to LLM, gets wrong answer.\n",
" \n",
"The OAI functions agent does and [gets right answer](https://smith.langchain.com/public/4425a131-ec90-4aaa-acd8-5b880c7452a3/r), as shown above.\n",
" \n",
"Also the search tool result for [ReAct](https://smith.langchain.com/public/6473e608-fc9d-47c9-a8a4-2ef7f2801d82/r) is worse than [OAI](https://smith.langchain.com/public/4425a131-ec90-4aaa-acd8-5b880c7452a3/r/26b85fa9-e33a-4028-8650-1714f8b3db96).\n",
"\n",
"Collectivly, this tells us: carefully inspect Agent traces and tool outputs. \n",
"\n",
"As we saw with the [SQL use case](/docs/use_cases/sql), `ReAct agents` can be work very well for specific problems. \n",
"\n",
"But, as shown here, the result is degraded relative to what we see with the OpenAI agent."
]
},
{
"cell_type": "markdown",
"id": "5cde8f9a",
"metadata": {},
"source": [
"### Custom\n",
"\n",
"Let's peel it back even further to define our own action agent.\n",
"\n",
"We can [create a custom agent](/docs/modules/agents/how_to/custom_agent.html) to unpack the central pieces:\n",
"\n",
"* `Tools`: The tools the agent has available to use\n",
"* `Agent`: decides which action to take"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "3313f5cd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"The current population of Canada is 38,808,843 as of Tuesday, August 1, 2023, based on Worldometer elaboration of the latest United Nations data 1. Canada 2023\\xa0... Mar 22, 2023 ... Record-high population growth in the year 2022. Canada's population was estimated at 39,566,248 on January 1, 2023, after a record population\\xa0... Jun 19, 2023 ... As of June 16, 2023, there are now 40 million Canadians! This is a historic milestone for Canada and certainly cause for celebration. It is also\\xa0... Jun 28, 2023 ... Canada's population was estimated at 39,858,480 on April 1, 2023, an increase of 292,232 people (+0.7%) from January 1, 2023. The main driver of population growth is immigration, and to a lesser extent, natural growth. Demographics of Canada · Population pyramid of Canada in 2023. May 2, 2023 ... On January 1, 2023, Canada's population was estimated to be 39,566,248, following an unprecedented increase of 1,050,110 people between January\\xa0... Canada ranks 37th by population among countries of the world, comprising about 0.5% of the world's total, with over 40.0 million Canadians as of 2023. The current population of Canada in 2023 is 38,781,291, a 0.85% increase from 2022. The population of Canada in 2022 was 38,454,327, a 0.78% increase from 2021. Whether a given sub-nation is a province or a territory depends upon how its power and authority are derived. Provinces were given their power by the\\xa0... Jun 28, 2023 ... Index to the latest information from the Census of Population. ... 2023. Census in Brief: Multilingualism of Canadian households\\xa0...\""
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import List, Tuple, Any, Union\n",
"from langchain.schema import AgentAction, AgentFinish\n",
"from langchain.agents import Tool, AgentExecutor, BaseSingleActionAgent\n",
"\n",
"class FakeAgent(BaseSingleActionAgent):\n",
" \"\"\"Fake Custom Agent.\"\"\"\n",
"\n",
" @property\n",
" def input_keys(self):\n",
" return [\"input\"]\n",
"\n",
" def plan(\n",
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
" ) -> Union[AgentAction, AgentFinish]:\n",
" \"\"\"Given input, decided what to do.\n",
"\n",
" Args:\n",
" intermediate_steps: Steps the LLM has taken to date,\n",
" along with observations\n",
" **kwargs: User inputs.\n",
"\n",
" Returns:\n",
" Action specifying what tool to use.\n",
" \"\"\"\n",
" return AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\")\n",
"\n",
" async def aplan(\n",
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
" ) -> Union[AgentAction, AgentFinish]:\n",
" \"\"\"Given input, decided what to do.\n",
"\n",
" Args:\n",
" intermediate_steps: Steps the LLM has taken to date,\n",
" along with observations\n",
" **kwargs: User inputs.\n",
"\n",
" Returns:\n",
" Action specifying what tool to use.\n",
" \"\"\"\n",
" return AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\")\n",
" \n",
"fake_agent = FakeAgent()\n",
"fake_agent_executor = AgentExecutor.from_agent_and_tools(agent=fake_agent, \n",
" tools=search_tool, \n",
" verbose=False)\n",
"\n",
"fake_agent_executor.run(\"How many people live in canada as of 2023?\")"
]
},
{
"cell_type": "markdown",
"id": "1335f0c6",
"metadata": {},
"source": [
"## Runtime\n",
"\n",
"The `AgentExecutor` class is the main agent runtime supported by LangChain. \n",
"\n",
"However, there are other, more experimental runtimes for `autonomous_agents`:\n",
" \n",
"* Plan-and-execute Agent\n",
"* Baby AGI\n",
"* Auto GPT\n",
"\n",
"Explore more about:\n",
"\n",
"* [`Simulation agents`](/docs/modules/agents/agent_use_cases/agent_simulations): Designed for role-play often in simulated enviorment (e.g., Generative Agents, CAMEL).\n",
"* [`Autonomous agents`](/docs/modules/agents/agent_use_cases/autonomous_agents): Designed for indepdent execution towards long term goals (e.g., BabyAGI, Auto-GPT).\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

View File

@@ -249,7 +249,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.9.16"
}
},
"nbformat": 4,

View File

@@ -245,7 +245,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.9.16"
}
},
"nbformat": 4,

View File

@@ -25,8 +25,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install gpt4all\n",
"! pip install chromadb"
"pip install gpt4all chromadb"
]
},
{
@@ -157,7 +156,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install llama-cpp-python"
"pip install llama-cpp-python"
]
},
{
@@ -736,7 +735,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.9.16"
}
},
"nbformat": 4,

View File

@@ -1,342 +0,0 @@
---
sidebar_position: -1
---
# QA over Documents
## Use case
Suppose you have some text documents (PDF, blog, Notion pages, etc.) and want to ask questions related to the contents of those documents. LLMs, given their proficiency in understanding text, are a great tool for this.
In this walkthrough we'll go over how to build a question-answering over documents application using LLMs. Two very related use cases which we cover elsewhere are:
- [QA over structured data](/docs/use_cases/tabular) (e.g., SQL)
- [QA over code](/docs/use_cases/code) (e.g., Python)
![intro.png](/img/qa_intro.png)
## Overview
The pipeline for converting raw unstructured data into a QA chain looks like this:
1. `Loading`: First we need to load our data. Unstructured data can be loaded from many sources. Use the [LangChain integration hub](https://integrations.langchain.com/) to browse the full set of loaders.
Each loader returns data as a LangChain [`Document`](https://docs.langchain.com/docs/components/schema/document).
2. `Splitting`: [Text splitters](/docs/modules/data_connection/document_transformers/) break `Documents` into splits of specified size
3. `Storage`: Storage (e.g., often a [vectorstore](/docs/modules/data_connection/vectorstores/)) will house [and often embed](https://www.pinecone.io/learn/vector-embeddings/) the splits
4. `Retrieval`: The app retrieves splits from storage (e.g., often [with similar embeddings](https://www.pinecone.io/learn/k-nearest-neighbor/) to the input question)
5. `Generation`: An [LLM](/docs/modules/model_io/models/llms/) produces an answer using a prompt that includes the question and the retrieved data
6. `Conversation` (Extension): Hold a multi-turn conversation by adding [Memory](/docs/modules/memory/) to your QA chain.
![flow.jpeg](/img/qa_flow.jpeg)
## Quickstart
To give you a sneak preview, the above pipeline can all be wrapped in a single object: `VectorstoreIndexCreator`. Suppose we want a QA app over this [blog post](https://lilianweng.github.io/posts/2023-06-23-agent/). We can create this in a few lines of code:
First set environment variables and install packages:
```bash
pip install openai chromadb
export OPENAI_API_KEY="..."
```
Then run:
```python
from langchain.document_loaders import WebBaseLoader
from langchain.indexes import VectorstoreIndexCreator
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
index = VectorstoreIndexCreator().from_loaders([loader])
```
And now ask your questions:
```python
index.query("What is Task Decomposition?")
```
' Task decomposition is a technique used to break down complex tasks into smaller and simpler steps. It can be done using LLM with simple prompting, task-specific instructions, or human inputs. Tree of Thoughts (Yao et al. 2023) is an example of a task decomposition technique that explores multiple reasoning possibilities at each step and generates multiple thoughts per step, creating a tree structure.'
Ok, but what's going on under the hood, and how could we customize this for our specific use case? For that, let's take a look at how we can construct this pipeline piece by piece.
## Step 1. Load
Specify a `DocumentLoader` to load in your unstructured data as `Documents`. A `Document` is a piece of text (the `page_content`) and associated metadata.
```python
from langchain.document_loaders import WebBaseLoader
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
data = loader.load()
```
### Go deeper
- Browse the > 120 data loader integrations [here](https://integrations.langchain.com/).
- See further documentation on loaders [here](/docs/modules/data_connection/document_loaders/).
## Step 2. Split
Split the `Document` into chunks for embedding and vector storage.
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 0)
all_splits = text_splitter.split_documents(data)
```
### Go deeper
- `DocumentSplitters` are just one type of the more generic `DocumentTransformers`, which can all be useful in this preprocessing step.
- See further documentation on transformers [here](/docs/modules/data_connection/document_transformers/).
- `Context-aware splitters` keep the location ("context") of each split in the original `Document`:
- [Markdown files](/docs/use_cases/question_answering/document-context-aware-QA)
- [Code (py or js)](/docs/modules/data_connection/document_loaders/integrations/source_code)
- [Documents](/docs/modules/data_connection/document_loaders/integrations/grobid)
## Step 3. Store
To be able to look up our document splits, we first need to store them where we can later look them up.
The most common way to do this is to embed the contents of each document then store the embedding and document in a vector store, with the embedding being used to index the document.
```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())
```
### Go deeper
- Browse the > 40 vectorstores integrations [here](https://integrations.langchain.com/).
- See further documentation on vectorstores [here](/docs/modules/data_connection/vectorstores/).
- Browse the > 30 text embedding integrations [here](https://integrations.langchain.com/).
- See further documentation on embedding models [here](/docs/modules/data_connection/text_embedding/).
Here are Steps 1-3:
![lc.png](/img/qa_data_load.png)
## Step 4. Retrieve
Retrieve relevant splits for any question using [similarity search](https://www.pinecone.io/learn/what-is-similarity-search/).
```python
question = "What are the approaches to Task Decomposition?"
docs = vectorstore.similarity_search(question)
len(docs)
```
4
### Go deeper
Vectorstores are commonly used for retrieval, but they are not the only option. For example, SVMs (see thread [here](https://twitter.com/karpathy/status/1647025230546886658?s=20)) can also be used.
LangChain [has many retrievers](/docs/modules/data_connection/retrievers/) including, but not limited to, vectorstores. All retrievers implement a common method `get_relevant_documents()` (and its asynchronous variant `aget_relevant_documents()`).
```python
from langchain.retrievers import SVMRetriever
svm_retriever = SVMRetriever.from_documents(all_splits,OpenAIEmbeddings())
docs_svm=svm_retriever.get_relevant_documents(question)
len(docs_svm)
```
4
Some common ways to improve on vector similarity search include:
- `MultiQueryRetriever` [generates variants of the input question](/docs/modules/data_connection/retrievers/MultiQueryRetriever) to improve retrieval.
- `Max marginal relevance` selects for [relevance and diversity](https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf) among the retrieved documents.
- Documents can be filtered during retrieval using [`metadata` filters](/docs/use_cases/question_answering/how_to/document-context-aware-QA).
```python
import logging
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.multi_query import MultiQueryRetriever
logging.basicConfig()
logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)
retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectorstore.as_retriever(),
llm=ChatOpenAI(temperature=0))
unique_docs = retriever_from_llm.get_relevant_documents(query=question)
len(unique_docs)
```
INFO:langchain.retrievers.multi_query:Generated queries: ['1. How can Task Decomposition be approached?', '2. What are the different methods for Task Decomposition?', '3. What are the various approaches to decomposing tasks?']
5
## Step 5. Generate
Distill the retrieved documents into an answer using an LLM/Chat model (e.g., `gpt-3.5-turbo`) with `RetrievalQA` chain.
```python
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever())
qa_chain({"query": question})
```
{
'query': 'What are the approaches to Task Decomposition?',
'result': 'The approaches to task decomposition include:\n\n1. Simple prompting: This approach involves using simple prompts or questions to guide the agent in breaking down a task into smaller subgoals. For example, the agent can be prompted with "Steps for XYZ" and asked to list the subgoals for achieving XYZ.\n\n2. Task-specific instructions: In this approach, task-specific instructions are provided to the agent to guide the decomposition process. For example, if the task is to write a novel, the agent can be instructed to "Write a story outline" as a subgoal.\n\n3. Human inputs: This approach involves incorporating human inputs in the task decomposition process. Humans can provide guidance, feedback, and suggestions to help the agent break down complex tasks into manageable subgoals.\n\nThese approaches aim to enable efficient handling of complex tasks by breaking them down into smaller, more manageable parts.'
}
Note, you can pass in an `LLM` or a `ChatModel` (like we did here) to the `RetrievalQA` chain.
### Go deeper
#### Choosing LLMs
- Browse the > 55 LLM and chat model integrations [here](https://integrations.langchain.com/).
- See further documentation on LLMs and chat models [here](/docs/modules/model_io/models/).
- Use local LLMs: The popularity of [PrivateGPT](https://github.com/imartinez/privateGPT) and [GPT4All](https://github.com/nomic-ai/gpt4all) underscores the importance of running LLMs locally.
Using `GPT4All` is as simple as [downloading the binary](/docs/integrations/llms/gpt4all) and then:
from langchain.llms import GPT4All
from langchain.chains import RetrievalQA
llm = GPT4All(model="/Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin",max_tokens=2048)
qa_chain = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever())
#### Customizing the prompt
The prompt in `RetrievalQA` chain can be easily customized.
```python
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
qa_chain = RetrievalQA.from_chain_type(
llm,
retriever=vectorstore.as_retriever(),
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
)
result = qa_chain({"query": question})
result["result"]
```
'The approaches to Task Decomposition are (1) using simple prompting by LLM, (2) using task-specific instructions, and (3) with human inputs. Thanks for asking!'
#### Return source documents
The full set of retrieved documents used for answer distillation can be returned using `return_source_documents=True`.
```python
from langchain.chains import RetrievalQA
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),
return_source_documents=True)
result = qa_chain({"query": question})
print(len(result['source_documents']))
result['source_documents'][0]
```
4
Document(page_content='Task decomposition can be done (1) by LLM with simple prompting like "Steps for XYZ.\\n1.", "What are the subgoals for achieving XYZ?", (2) by using task-specific instructions; e.g. "Write a story outline." for writing a novel, or (3) with human inputs.', metadata={'source': 'https://lilianweng.github.io/posts/2023-06-23-agent/', 'title': "LLM Powered Autonomous Agents | Lil'Log", 'description': 'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyAGI, serve as inspiring examples. The potentiality of LLM extends beyond generating well-written copies, stories, essays and programs; it can be framed as a powerful general problem solver.\nAgent System Overview In a LLM-powered autonomous agent system, LLM functions as the agents brain, complemented by several key components:', 'language': 'en'})
#### Return citations
Answer citations can be returned using `RetrievalQAWithSourcesChain`.
```python
from langchain.chains import RetrievalQAWithSourcesChain
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm,retriever=vectorstore.as_retriever())
result = qa_chain({"question": question})
result
```
{
'question': 'What are the approaches to Task Decomposition?',
'answer': 'The approaches to Task Decomposition include (1) using LLM with simple prompting, (2) using task-specific instructions, and (3) incorporating human inputs.\n',
'sources': 'https://lilianweng.github.io/posts/2023-06-23-agent/'
}
#### Customizing retrieved document processing
Retrieved documents can be fed to an LLM for answer distillation in a few different ways.
`stuff`, `refine`, `map-reduce`, and `map-rerank` chains for passing documents to an LLM prompt are well summarized [here](/docs/modules/chains/document/).
`stuff` is commonly used because it simply "stuffs" all retrieved documents into the prompt.
The [load_qa_chain](/docs/use_cases/question_answering/how_to/question_answering.html) is an easy way to pass documents to an LLM using these various approaches (e.g., see `chain_type`).
```python
from langchain.chains.question_answering import load_qa_chain
chain = load_qa_chain(llm, chain_type="stuff")
chain({"input_documents": unique_docs, "question": question},return_only_outputs=True)
```
{'output_text': 'The approaches to task decomposition include (1) using simple prompting to break down tasks into subgoals, (2) providing task-specific instructions to guide the decomposition process, and (3) incorporating human inputs for task decomposition.'}
We can also pass the `chain_type` to `RetrievalQA`.
```python
qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),
chain_type="stuff")
result = qa_chain({"query": question})
```
In summary, the user can choose the desired level of abstraction for QA:
![summary_chains.png](/img/summary_chains.png)
## Step 6. Converse (Extension)
To hold a conversation, a chain needs to be able to refer to past interactions. Chain `Memory` allows us to do this. To keep chat history, we can specify a Memory buffer to track the conversation inputs / outputs.
```python
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
```
The `ConversationalRetrievalChain` uses chat in the `Memory buffer`.
```python
from langchain.chains import ConversationalRetrievalChain
retriever = vectorstore.as_retriever()
chat = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory)
```
```python
result = chat({"question": "What are some of the main ideas in self-reflection?"})
result['answer']
```
"Some of the main ideas in self-reflection include:\n1. Iterative improvement: Self-reflection allows autonomous agents to improve by refining past action decisions and correcting mistakes.\n2. Trial and error: Self-reflection is crucial in real-world tasks where trial and error are inevitable.\n3. Two-shot examples: Self-reflection is created by showing pairs of failed trajectories and ideal reflections for guiding future changes in the plan.\n4. Working memory: Reflections are added to the agent's working memory, up to three, to be used as context for querying.\n5. Performance evaluation: Self-reflection involves continuously reviewing and analyzing actions, self-criticizing behavior, and reflecting on past decisions and strategies to refine approaches.\n6. Efficiency: Self-reflection encourages being smart and efficient, aiming to complete tasks in the least number of steps."
The Memory buffer has context to resolve `"it"` ("self-reflection") in the below question.
```python
result = chat({"question": "How does the Reflexion paper handle it?"})
result['answer']
```
"The Reflexion paper handles self-reflection by showing two-shot examples to the Learning Language Model (LLM). Each example consists of a failed trajectory and an ideal reflection that guides future changes in the agent's plan. These reflections are then added to the agent's working memory, up to a maximum of three, to be used as context for querying the LLM. This allows the agent to iteratively improve its reasoning skills by refining past action decisions and correcting previous mistakes."
### Go deeper
The [documentation](/docs/use_cases/question_answering/how_to/chat_vector_db) on `ConversationalRetrievalChain` offers a few extensions, such as streaming and source documents.
## Further reading
- Check out the [How to](/docs/use_cases/question_answer/how_to/) section for all the variations of chains that can be used for QA over docs in different settings.
- Check out the [Integrations-specific](/docs/use_cases/question_answer/integrations/) section for chains that use specific integrations.

View File

@@ -0,0 +1,686 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5151afed",
"metadata": {},
"source": [
"# Question Answering\n",
"\n",
"[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/extras/use_cases/question_answering/qa.ipynb)\n",
"\n",
"## Use case\n",
"Suppose you have some text documents (PDF, blog, Notion pages, etc.) and want to ask questions related to the contents of those documents. LLMs, given their proficiency in understanding text, are a great tool for this.\n",
"\n",
"In this walkthrough we'll go over how to build a question-answering over documents application using LLMs. Two very related use cases which we cover elsewhere are:\n",
"- [QA over structured data](/docs/use_cases/sql) (e.g., SQL)\n",
"- [QA over code](/docs/use_cases/code) (e.g., Python)\n",
"\n",
"![intro.png](/img/qa_intro.png)\n",
"\n",
"## Overview\n",
"The pipeline for converting raw unstructured data into a QA chain looks like this:\n",
"1. `Loading`: First we need to load our data. Unstructured data can be loaded from many sources. Use the [LangChain integration hub](https://integrations.langchain.com/) to browse the full set of loaders.\n",
"Each loader returns data as a LangChain [`Document`](/docs/components/schema/document).\n",
"2. `Splitting`: [Text splitters](/docs/modules/data_connection/document_transformers/) break `Documents` into splits of specified size\n",
"3. `Storage`: Storage (e.g., often a [vectorstore](/docs/modules/data_connection/vectorstores/)) will house [and often embed](https://www.pinecone.io/learn/vector-embeddings/) the splits\n",
"4. `Retrieval`: The app retrieves splits from storage (e.g., often [with similar embeddings](https://www.pinecone.io/learn/k-nearest-neighbor/) to the input question)\n",
"5. `Generation`: An [LLM](/docs/modules/model_io/models/llms/) produces an answer using a prompt that includes the question and the retrieved data\n",
"6. `Conversation` (Extension): Hold a multi-turn conversation by adding [Memory](/docs/modules/memory/) to your QA chain.\n",
"\n",
"![flow.jpeg](/img/qa_flow.jpeg)\n",
"\n",
"## Quickstart\n",
"\n",
"To give you a sneak preview, the above pipeline can be all be wrapped in a single object: `VectorstoreIndexCreator`. Suppose we want a QA app over this [blog post](https://lilianweng.github.io/posts/2023-06-23-agent/). We can create this in a few lines of code. First set environment variables and install packages:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e14b744b",
"metadata": {},
"outputs": [],
"source": [
"pip install openai chromadb\n",
"\n",
"# Set env var OPENAI_API_KEY or load from a .env file\n",
"# import dotenv\n",
"\n",
"# dotenv.load_env()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "046cefc0",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import WebBaseLoader\n",
"from langchain.indexes import VectorstoreIndexCreator\n",
"\n",
"loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
"index = VectorstoreIndexCreator().from_loaders([loader])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f4bf8740",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' Task decomposition is a technique used to break down complex tasks into smaller and simpler steps. It can be done using LLM with simple prompting, task-specific instructions, or with human inputs. Tree of Thoughts (Yao et al. 2023) is an extension of Chain of Thought (Wei et al. 2022) which explores multiple reasoning possibilities at each step.'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.query(\"What is Task Decomposition?\")"
]
},
{
"cell_type": "markdown",
"id": "8224aad6",
"metadata": {},
"source": [
"Ok, but what's going on under the hood, and how could we customize this for our specific use case? For that, let's take a look at how we can construct this pipeline piece by piece."
]
},
{
"cell_type": "markdown",
"id": "ba5daed6",
"metadata": {},
"source": [
"## Step 1. Load\n",
"\n",
"Specify a `DocumentLoader` to load in your unstructured data as `Documents`. A `Document` is a piece of text (the `page_content`) and associated metadata."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cf4d5c72",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import WebBaseLoader\n",
"\n",
"loader = WebBaseLoader(\"https://lilianweng.github.io/posts/2023-06-23-agent/\")\n",
"data = loader.load()"
]
},
{
"cell_type": "markdown",
"id": "fd2cc9a7",
"metadata": {},
"source": [
"### Go deeper\n",
"- Browse the > 120 data loader integrations [here](https://integrations.langchain.com/).\n",
"- See further documentation on loaders [here](/docs/modules/data_connection/document_loaders/).\n",
"\n",
"## Step 2. Split\n",
"\n",
"Split the `Document` into chunks for embedding and vector storage."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4b11c01d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 0)\n",
"all_splits = text_splitter.split_documents(data)"
]
},
{
"cell_type": "markdown",
"id": "0a33bd4d",
"metadata": {},
"source": [
"### Go deeper\n",
"\n",
"- `DocumentSplitters` are just one type of the more generic `DocumentTransformers`, which can all be useful in this preprocessing step.\n",
"- See further documentation on transformers [here](/docs/modules/data_connection/document_transformers/).\n",
"- `Context-aware splitters` keep the location (\"context\") of each split in the original `Document`:\n",
" - [Markdown files](/docs/use_cases/question_answering/how_to/document-context-aware-QA)\n",
" - [Code (py or js)](docs/integrations/document_loaders/source_code)\n",
" - [Documents](/docs/integrations/document_loaders/grobid)\n",
"\n",
"## Step 3. Store\n",
"\n",
"To be able to look up our document splits, we first need to store them where we can later look them up.\n",
"The most common way to do this is to embed the contents of each document then store the embedding and document in a vector store, with the embedding being used to index the document."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e9c302c8",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.vectorstores import Chroma\n",
"\n",
"vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings())"
]
},
{
"cell_type": "markdown",
"id": "dc6f22b0",
"metadata": {},
"source": [
"### Go deeper\n",
"- Browse the > 40 vectorstores integrations [here](https://integrations.langchain.com/).\n",
"- See further documentation on vectorstores [here](/docs/modules/data_connection/vectorstores/).\n",
"- Browse the > 30 text embedding integrations [here](https://integrations.langchain.com/).\n",
"- See further documentation on embedding models [here](/docs/modules/data_connection/text_embedding/).\n",
"\n",
" Here are Steps 1-3:\n",
"\n",
"![lc.png](/img/qa_data_load.png)\n",
"\n",
"## Step 4. Retrieve\n",
"\n",
"Retrieve relevant splits for any question using [similarity search](https://www.pinecone.io/learn/what-is-similarity-search/)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e2c26b7d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"What are the approaches to Task Decomposition?\"\n",
"docs = vectorstore.similarity_search(question)\n",
"len(docs)"
]
},
{
"cell_type": "markdown",
"id": "5d5a113b",
"metadata": {},
"source": [
"### Go deeper\n",
"\n",
"Vectorstores are commonly used for retrieval, but they are not the only option. For example, SVMs (see thread [here](https://twitter.com/karpathy/status/1647025230546886658?s=20)) can also be used.\n",
"\n",
"LangChain [has many retrievers](/docs/modules/data_connection/retrievers/) including, but not limited to, vectorstores. All retrievers implement a common method `get_relevant_documents()` (and its asynchronous variant `aget_relevant_documents()`)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c901eaee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.retrievers import SVMRetriever\n",
"\n",
"svm_retriever = SVMRetriever.from_documents(all_splits,OpenAIEmbeddings())\n",
"docs_svm=svm_retriever.get_relevant_documents(question)\n",
"len(docs_svm)"
]
},
{
"cell_type": "markdown",
"id": "69de3d54",
"metadata": {},
"source": [
"Some common ways to improve on vector similarity search include:\n",
"- `MultiQueryRetriever` [generates variants of the input question](/docs/modules/data_connection/retrievers/MultiQueryRetriever) to improve retrieval.\n",
"- `Max marginal relevance` selects for [relevance and diversity](https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf) among the retrieved documents.\n",
"- Documents can be filtered during retrieval using [`metadata` filters](/docs/use_cases/question_answering/how_to/document-context-aware-QA)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c690f01a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:langchain.retrievers.multi_query:Generated queries: ['1. How can Task Decomposition be approached?', '2. What are the different methods for Task Decomposition?', '3. What are the various approaches to decomposing tasks?']\n"
]
},
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import logging\n",
"\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.retrievers.multi_query import MultiQueryRetriever\n",
"\n",
"logging.basicConfig()\n",
"logging.getLogger('langchain.retrievers.multi_query').setLevel(logging.INFO)\n",
"\n",
"retriever_from_llm = MultiQueryRetriever.from_llm(retriever=vectorstore.as_retriever(),\n",
" llm=ChatOpenAI(temperature=0))\n",
"unique_docs = retriever_from_llm.get_relevant_documents(query=question)\n",
"len(unique_docs)"
]
},
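{
"cell_type": "markdown",
"id": "d09e4b6a",
"metadata": {},
"source": [
"The `MultiQueryRetriever` above covers the first option. As a rough sketch of the other two, using the `Chroma` vectorstore built earlier: max marginal relevance can be requested via `as_retriever`, and a metadata `filter` can be passed to the similarity search (here restricting results to the blog post we loaded)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c6f2e8b",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: max marginal relevance trades off relevance vs. diversity of the retrieved splits\n",
"mmr_retriever = vectorstore.as_retriever(search_type=\"mmr\")\n",
"docs_mmr = mmr_retriever.get_relevant_documents(question)\n",
"\n",
"# Sketch: metadata filtering on similarity search\n",
"docs_filtered = vectorstore.similarity_search(\n",
"    question,\n",
"    filter={\"source\": \"https://lilianweng.github.io/posts/2023-06-23-agent/\"},\n",
")\n",
"len(docs_mmr), len(docs_filtered)"
]
},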
{
"cell_type": "markdown",
"id": "415d6824",
"metadata": {},
"source": [
"## Step 5. Generate\n",
"\n",
"Distill the retrieved documents into an answer using an LLM/Chat model (e.g., `gpt-3.5-turbo`) with `RetrievalQA` chain.\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "99fa1aec",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'query': 'What are the approaches to Task Decomposition?',\n",
" 'result': 'There are three approaches to task decomposition:\\n\\n1. Using Language Model with simple prompting: This approach involves using a Language Model (LLM) with simple prompts like \"Steps for XYZ\" or \"What are the subgoals for achieving XYZ?\" to guide the task decomposition process.\\n\\n2. Using task-specific instructions: In this approach, task-specific instructions are provided to guide the task decomposition. For example, for the task of writing a novel, an instruction like \"Write a story outline\" can be given to help decompose the task into smaller subtasks.\\n\\n3. Human inputs: Task decomposition can also be done with the help of human inputs. This involves getting input and guidance from humans to break down a complex task into smaller, more manageable subtasks.'}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)\n",
"qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever())\n",
"qa_chain({\"query\": question})"
]
},
{
"cell_type": "markdown",
"id": "f7d52c84",
"metadata": {},
"source": [
"Note, you can pass in an `LLM` or a `ChatModel` (like we did here) to the `RetrievalQA` chain.\n",
"\n",
"### Go deeper\n",
"\n",
"#### Choosing LLMs\n",
"- Browse the > 55 LLM and chat model integrations [here](https://integrations.langchain.com/).\n",
"- See further documentation on LLMs and chat models [here](/docs/modules/model_io/models/).\n",
"- Use local LLMS: The popularity of [PrivateGPT](https://github.com/imartinez/privateGPT) and [GPT4All](https://github.com/nomic-ai/gpt4all) underscore the importance of running LLMs locally.\n",
"Using `GPT4All` is as simple as [downloading the binary]((/docs/integrations/llms/gpt4all)) and then:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "02d6c9dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found model file at /Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"objc[61331]: Class GGMLMetalClass is implemented in both /Users/rlm/miniforge3/envs/llama/lib/python3.9/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libreplit-mainline-metal.dylib (0x2e3384208) and /Users/rlm/miniforge3/envs/llama/lib/python3.9/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/libllamamodel-mainline-metal.dylib (0x2e37b0208). One of the two will be used. Which one is undefined.\n",
"llama.cpp: using Metal\n",
"llama.cpp: loading model from /Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin\n",
"llama_model_load_internal: format = ggjt v3 (latest)\n",
"llama_model_load_internal: n_vocab = 32001\n",
"llama_model_load_internal: n_ctx = 2048\n",
"llama_model_load_internal: n_embd = 5120\n",
"llama_model_load_internal: n_mult = 256\n",
"llama_model_load_internal: n_head = 40\n",
"llama_model_load_internal: n_layer = 40\n",
"llama_model_load_internal: n_rot = 128\n",
"llama_model_load_internal: ftype = 2 (mostly Q4_0)\n",
"llama_model_load_internal: n_ff = 13824\n",
"llama_model_load_internal: n_parts = 1\n",
"llama_model_load_internal: model size = 13B\n",
"llama_model_load_internal: ggml ctx size = 0.09 MB\n",
"llama_model_load_internal: mem required = 9031.71 MB (+ 1608.00 MB per state)\n",
"llama_new_context_with_model: kv self size = 1600.00 MB\n",
"ggml_metal_init: allocating\n",
"ggml_metal_init: using MPS\n",
"ggml_metal_init: loading '/Users/rlm/miniforge3/envs/llama/lib/python3.9/site-packages/gpt4all/llmodel_DO_NOT_MODIFY/build/ggml-metal.metal'\n",
"ggml_metal_init: loaded kernel_add 0x2bbbbc2f0\n",
"ggml_metal_init: loaded kernel_mul 0x2bbbba840\n",
"ggml_metal_init: loaded kernel_mul_row 0x2bb917dd0\n",
"ggml_metal_init: loaded kernel_scale 0x2bb918150\n",
"ggml_metal_init: loaded kernel_silu 0x2bb9184d0\n",
"ggml_metal_init: loaded kernel_relu 0x2bb918850\n",
"ggml_metal_init: loaded kernel_gelu 0x2bbbc3f10\n",
"ggml_metal_init: loaded kernel_soft_max 0x2bbbc5840\n",
"ggml_metal_init: loaded kernel_diag_mask_inf 0x2bbbc4c70\n",
"ggml_metal_init: loaded kernel_get_rows_f16 0x2bbbc5fc0\n",
"ggml_metal_init: loaded kernel_get_rows_q4_0 0x2bbbc6720\n",
"ggml_metal_init: loaded kernel_get_rows_q4_1 0x2bb918c10\n",
"ggml_metal_init: loaded kernel_get_rows_q2_k 0x2bbbc51b0\n",
"ggml_metal_init: loaded kernel_get_rows_q3_k 0x2bbbc7630\n",
"ggml_metal_init: loaded kernel_get_rows_q4_k 0x2d4394e30\n",
"ggml_metal_init: loaded kernel_get_rows_q5_k 0x2bbbc7890\n",
"ggml_metal_init: loaded kernel_get_rows_q6_k 0x2d4395210\n",
"ggml_metal_init: loaded kernel_rms_norm 0x2bbbc8740\n",
"ggml_metal_init: loaded kernel_norm 0x2bbbc8b30\n",
"ggml_metal_init: loaded kernel_mul_mat_f16_f32 0x2d4395470\n",
"ggml_metal_init: loaded kernel_mul_mat_q4_0_f32 0x2d4395a70\n",
"ggml_metal_init: loaded kernel_mul_mat_q4_1_f32 0x1242b1a00\n",
"ggml_metal_init: loaded kernel_mul_mat_q2_k_f32 0x29f17d1c0\n",
"ggml_metal_init: loaded kernel_mul_mat_q3_k_f32 0x2d4396050\n",
"ggml_metal_init: loaded kernel_mul_mat_q4_k_f32 0x2bbbc98a0\n",
"ggml_metal_init: loaded kernel_mul_mat_q5_k_f32 0x2bbbca4a0\n",
"ggml_metal_init: loaded kernel_mul_mat_q6_k_f32 0x2bbbcae90\n",
"ggml_metal_init: loaded kernel_rope 0x2bbbca700\n",
"ggml_metal_init: loaded kernel_alibi_f32 0x2bbbcc6e0\n",
"ggml_metal_init: loaded kernel_cpy_f32_f16 0x2bbbccf90\n",
"ggml_metal_init: loaded kernel_cpy_f32_f32 0x2bbbcd900\n",
"ggml_metal_init: loaded kernel_cpy_f16_f16 0x2bbbce1f0\n",
"ggml_metal_init: recommendedMaxWorkingSetSize = 21845.34 MB\n",
"ggml_metal_init: hasUnifiedMemory = true\n",
"ggml_metal_init: maxTransferRate = built-in GPU\n",
"ggml_metal_add_buffer: allocated 'data ' buffer, size = 6984.06 MB, ( 6984.45 / 21845.34)\n",
"ggml_metal_add_buffer: allocated 'eval ' buffer, size = 1024.00 MB, ( 8008.45 / 21845.34)\n",
"ggml_metal_add_buffer: allocated 'kv ' buffer, size = 1602.00 MB, ( 9610.45 / 21845.34)\n",
"ggml_metal_add_buffer: allocated 'scr0 ' buffer, size = 512.00 MB, (10122.45 / 21845.34)\n",
"ggml_metal_add_buffer: allocated 'scr1 ' buffer, size = 512.00 MB, (10634.45 / 21845.34)\n"
]
}
],
"source": [
"from langchain.llms import GPT4All\n",
"from langchain.chains import RetrievalQA\n",
"\n",
"llm = GPT4All(model=\"/Users/rlm/Desktop/Code/gpt4all/models/nous-hermes-13b.ggmlv3.q4_0.bin\",max_tokens=2048)\n",
"qa_chain = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever())"
]
},
{
"cell_type": "markdown",
"id": "fa82f437",
"metadata": {},
"source": [
"#### Customizing the prompt\n",
"\n",
"The prompt in `RetrievalQA` chain can be easily customized."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e4fee704",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ggml_metal_free: deallocating\n"
]
},
{
"data": {
"text/plain": [
"'The approaches to task decomposition include using LLM with simple prompting, task-specific instructions, or human inputs. Thanks for asking!'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"template = \"\"\"Use the following pieces of context to answer the question at the end. \n",
"If you don't know the answer, just say that you don't know, don't try to make up an answer. \n",
"Use three sentences maximum and keep the answer as concise as possible. \n",
"Always say \"thanks for asking!\" at the end of the answer. \n",
"{context}\n",
"Question: {question}\n",
"Helpful Answer:\"\"\"\n",
"QA_CHAIN_PROMPT = PromptTemplate.from_template(template)\n",
"\n",
"llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)\n",
"qa_chain = RetrievalQA.from_chain_type(\n",
" llm,\n",
" retriever=vectorstore.as_retriever(),\n",
" chain_type_kwargs={\"prompt\": QA_CHAIN_PROMPT}\n",
")\n",
"result = qa_chain({\"query\": question})\n",
"result[\"result\"]"
]
},
{
"cell_type": "markdown",
"id": "ff40e8db",
"metadata": {},
"source": [
"#### Return source documents\n",
"\n",
"The full set of retrieved documents used for answer distillation can be returned using `return_source_documents=True`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "60004293",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
},
{
"data": {
"text/plain": [
"Document(page_content='Task decomposition can be done (1) by LLM with simple prompting like \"Steps for XYZ.\\\\n1.\", \"What are the subgoals for achieving XYZ?\", (2) by using task-specific instructions; e.g. \"Write a story outline.\" for writing a novel, or (3) with human inputs.', metadata={'description': 'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyAGI, serve as inspiring examples. The potentiality of LLM extends beyond generating well-written copies, stories, essays and programs; it can be framed as a powerful general problem solver.\\nAgent System Overview In a LLM-powered autonomous agent system, LLM functions as the agents brain, complemented by several key components:', 'language': 'en', 'source': 'https://lilianweng.github.io/posts/2023-06-23-agent/', 'title': \"LLM Powered Autonomous Agents | Lil'Log\"})"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains import RetrievalQA\n",
"\n",
"qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),\n",
" return_source_documents=True)\n",
"result = qa_chain({\"query\": question})\n",
"print(len(result['source_documents']))\n",
"result['source_documents'][0]"
]
},
{
"cell_type": "markdown",
"id": "1b600236",
"metadata": {},
"source": [
"#### Return citations\n",
"\n",
"Answer citations can be returned using `RetrievalQAWithSourcesChain`."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "948f6d19",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'question': 'What are the approaches to Task Decomposition?',\n",
" 'answer': 'The approaches to Task Decomposition include:\\n1. Using LLM with simple prompting, such as providing steps or subgoals for achieving a task.\\n2. Using task-specific instructions, such as providing a specific instruction like \"Write a story outline\" for writing a novel.\\n3. Using human inputs to decompose the task.\\nAnother approach is the Tree of Thoughts, which extends the Chain of Thought (CoT) technique by exploring multiple reasoning possibilities at each step and generating multiple thoughts per step, creating a tree structure. The search process can be BFS or DFS, and each state can be evaluated by a classifier or majority vote.\\nSources: https://lilianweng.github.io/posts/2023-06-23-agent/',\n",
" 'sources': ''}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains import RetrievalQAWithSourcesChain\n",
"\n",
"qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm,retriever=vectorstore.as_retriever())\n",
"\n",
"result = qa_chain({\"question\": question})\n",
"result"
]
},
{
"cell_type": "markdown",
"id": "73d0b138",
"metadata": {},
"source": [
"#### Customizing retrieved document processing\n",
"\n",
"Retrieved documents can be fed to an LLM for answer distillation in a few different ways.\n",
"\n",
"`stuff`, `refine`, `map-reduce`, and `map-rerank` chains for passing documents to an LLM prompt are well summarized [here](/docs/modules/chains/document/).\n",
" \n",
"`stuff` is commonly used because it simply \"stuffs\" all retrieved documents into the prompt.\n",
"\n",
"The [load_qa_chain](/docs/use_cases/question_answering/how_to/question_answering.html) is an easy way to pass documents to an LLM using these various approaches (e.g., see `chain_type`)."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "29aa139f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'output_text': 'The approaches to task decomposition mentioned in the provided context are:\\n\\n1. Chain of thought (CoT): This approach involves instructing the language model to \"think step by step\" and decompose complex tasks into smaller and simpler steps. It enhances model performance on complex tasks by utilizing more test-time computation.\\n\\n2. Tree of Thoughts: This approach extends CoT by exploring multiple reasoning possibilities at each step. It decomposes the problem into multiple thought steps and generates multiple thoughts per step, creating a tree structure. The search process can be BFS or DFS, and each state is evaluated by a classifier or majority vote.\\n\\n3. LLM with simple prompting: This approach involves using a language model with simple prompts like \"Steps for XYZ\" or \"What are the subgoals for achieving XYZ?\" to perform task decomposition.\\n\\n4. Task-specific instructions: This approach involves providing task-specific instructions to guide the language model in decomposing the task. For example, providing the instruction \"Write a story outline\" for the task of writing a novel.\\n\\n5. Human inputs: Task decomposition can also be done with human inputs, where humans provide guidance and input to break down the task into smaller subtasks.'}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains.question_answering import load_qa_chain\n",
"\n",
"chain = load_qa_chain(llm, chain_type=\"stuff\")\n",
"chain({\"input_documents\": unique_docs, \"question\": question},return_only_outputs=True)"
]
},
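{
"cell_type": "markdown",
"id": "d1e2f3a4",
"metadata": {},
"source": [
"As a rough, unexecuted sketch, the same call accepts the other chain types named above; for example, `refine` walks through the retrieved documents one at a time and updates the answer as it goes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5c6d7e8",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: iteratively refine the answer over the retrieved documents.\n",
"refine_chain = load_qa_chain(llm, chain_type=\"refine\")\n",
"refine_chain({\"input_documents\": unique_docs, \"question\": question}, return_only_outputs=True)"
]
},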
{
"cell_type": "markdown",
"id": "a8cb8cd1",
"metadata": {},
"source": [
"We can also pass the `chain_type` to `RetrievalQA`."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f68574bd",
"metadata": {},
"outputs": [],
"source": [
"qa_chain = RetrievalQA.from_chain_type(llm,retriever=vectorstore.as_retriever(),\n",
" chain_type=\"stuff\")\n",
"result = qa_chain({\"query\": question})"
]
},
{
"cell_type": "markdown",
"id": "b33aeb5f",
"metadata": {},
"source": [
"In summary, the user can choose the desired level of abstraction for QA:\n",
"\n",
"![summary_chains.png](/img/summary_chains.png)\n",
"\n",
"## Step 6. Chat\n",
"\n",
"See our [use-case on chat](/docs/use_cases/chatbots) for detail on this!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,25 @@
from langchain_experimental.comprehend_moderation.amazon_comprehend_moderation import (
AmazonComprehendModerationChain,
)
from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
BaseModerationActions,
BaseModerationFilters,
)
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
__all__ = [
"BaseModeration",
"BaseModerationActions",
"BaseModerationFilters",
"ComprehendPII",
"ComprehendIntent",
"ComprehendToxicity",
"BaseModerationCallbackHandler",
"AmazonComprehendModerationChain",
]

View File

@@ -0,0 +1,184 @@
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_experimental.comprehend_moderation.base_moderation import (
BaseModeration,
)
from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
BaseModerationCallbackHandler,
)
from langchain_experimental.pydantic_v1 import root_validator
class AmazonComprehendModerationChain(Chain):
"""A subclass of Chain, designed to apply moderation to LLMs."""
output_key: str = "output" #: :meta private:
"""Key used to fetch/store the output in data containers. Defaults to `output`"""
input_key: str = "input" #: :meta private:
"""Key used to fetch/store the input in data containers. Defaults to `input`"""
moderation_config: Optional[Dict[str, Any]] = None
"""Configuration settings for moderation"""
client: Optional[Any]
"""boto3 client object for connection to Amazon Comprehend"""
region_name: Optional[str] = None
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config in case it is not provided here.
"""
credentials_profile_name: Optional[str] = None
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
has either access keys or role information specified.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""
moderation_callback: Optional[BaseModerationCallbackHandler] = None
"""Callback handler for moderation, this is different
from regular callbacks which can be used in addition to this."""
unique_id: Optional[str] = None
"""A unique id that can be used to identify or group a user or session"""
@root_validator(pre=True)
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Creates an Amazon Comprehend client
Args:
values (Dict[str, Any]): A dictionary containing configuration values.
Returns:
Dict[str, Any]: A dictionary with the updated configuration values,
including the Amazon Comprehend client.
Raises:
ModuleNotFoundError: If the 'boto3' package is not installed.
ValueError: If there is an issue importing 'boto3' or loading
AWS credentials.
Example:
.. code-block:: python
config = {
"credentials_profile_name": "my-profile",
"region_name": "us-west-2"
}
updated_config = create_client(config)
comprehend_client = updated_config["client"]
"""
if values.get("client") is not None:
return values
try:
import boto3
if values.get("credentials_profile_name"):
session = boto3.Session(profile_name=values["credentials_profile_name"])
else:
# use default credentials
session = boto3.Session()
client_params = {}
if values.get("region_name"):
client_params["region_name"] = values["region_name"]
values["client"] = session.client("comprehend", **client_params)
return values
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
@property
def output_keys(self) -> List[str]:
"""
Returns a list of output keys.
This method defines the output keys that will be used to access the output
values produced by the chain or function. It ensures that the specified keys
are available to access the outputs.
Returns:
List[str]: A list of output keys.
Note:
This method is considered private and may not be intended for direct
external use.
"""
return [self.output_key]
@property
def input_keys(self) -> List[str]:
"""
Returns a list of input keys expected by the prompt.
This method defines the input keys that the prompt expects in order to perform
its processing. It ensures that the specified keys are available for providing
input to the prompt.
Returns:
List[str]: A list of input keys.
Note:
This method is considered private and may not be intended for direct
external use.
"""
return [self.input_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Executes the moderation process on the input text and returns the processed
output.
This internal method performs the moderation process on the input text. It
converts the input prompt value to plain text, applies the specified filters,
and then converts the filtered output back to a suitable prompt value object.
Additionally, it provides the option to log information about the run using
the provided `run_manager`.
Args:
inputs: A dictionary containing input values
run_manager: A run manager to handle run-related events. Default is None
Returns:
Dict[str, str]: A dictionary containing the processed output of the
moderation process.
Raises:
ValueError: If there is an error during the moderation process
"""
if run_manager:
run_manager.on_text("Running AmazonComprehendModerationChain...\n")
moderation = BaseModeration(
client=self.client,
config=self.moderation_config,
moderation_callback=self.moderation_callback,
unique_id=self.unique_id,
run_manager=run_manager,
)
response = moderation.moderate(prompt=inputs[self.input_keys[0]])
return {self.output_key: response}
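
The new chain above is self-contained, so a minimal usage sketch may help reviewers (an editor's illustration, not part of the diff; it assumes valid AWS credentials and a reachable Amazon Comprehend endpoint, and the `region_name` is only a placeholder):

```python
from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationActions,
    BaseModerationFilters,
)

# Run PII and toxicity checks: mask PII in place, stop on toxic content.
moderation_config = {
    "filters": [BaseModerationFilters.PII, BaseModerationFilters.TOXICITY],
    "pii": {"action": BaseModerationActions.ALLOW, "mask_character": "X"},
    "toxicity": {"action": BaseModerationActions.STOP, "threshold": 0.6},
}

moderation_chain = AmazonComprehendModerationChain(
    moderation_config=moderation_config,
    region_name="us-east-1",  # placeholder region
    verbose=True,
)

# PII is masked in the returned text; ModerationToxicityError is raised
# if toxic content above the threshold is found.
print(moderation_chain.run("My name is John Doe and my phone number is 555-0100."))
```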

View File

@@ -0,0 +1,176 @@
import uuid
from typing import Any, Callable, Dict, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import AIMessage, HumanMessage
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
class BaseModeration:
def __init__(
self,
client: Any,
config: Optional[Dict[str, Any]] = None,
moderation_callback: Optional[Any] = None,
unique_id: Optional[str] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
):
self.client = client
self.config = config
self.moderation_callback = moderation_callback
self.unique_id = unique_id
self.chat_message_index = 0
self.run_manager = run_manager
self.chain_id = str(uuid.uuid4())
def _convert_prompt_to_text(self, prompt: Any) -> str:
input_text = str()
if isinstance(prompt, StringPromptValue):
input_text = prompt.text
elif isinstance(prompt, str):
input_text = prompt
elif isinstance(prompt, ChatPromptValue):
"""
We will just check the last message in the message Chain of a
ChatPromptTemplate. The typical chronology is
SystemMessage > HumanMessage > AIMessage and so on. However, assuming
that the chain is invoked with every chat, we only check the last
message, on the assumption that all previous messages have already been
checked. Only HumanMessage and AIMessage are checked. We could perhaps
loop through and take advantage of the additional_kwargs property in the
HumanMessage and AIMessage schema to mark messages that have been moderated.
However, that would mean this class could generate multiple text chunks
and the moderate() logic would need to be updated. It would also add some
complexity in re-constructing the prompt while keeping the messages in
sequence.
"""
message = prompt.messages[-1]
self.chat_message_index = len(prompt.messages) - 1
if isinstance(message, HumanMessage):
input_text = message.content
if isinstance(message, AIMessage):
input_text = message.content
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
return input_text
def _convert_text_to_prompt(self, prompt: Any, text: str) -> Any:
if isinstance(prompt, StringPromptValue):
return StringPromptValue(text=text)
elif isinstance(prompt, str):
return text
elif isinstance(prompt, ChatPromptValue):
messages = prompt.messages
message = messages[self.chat_message_index]
if isinstance(message, HumanMessage):
messages[self.chat_message_index] = HumanMessage(
content=text,
example=message.example,
additional_kwargs=message.additional_kwargs,
)
if isinstance(message, AIMessage):
messages[self.chat_message_index] = AIMessage(
content=text,
example=message.example,
additional_kwargs=message.additional_kwargs,
)
return ChatPromptValue(messages=messages)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def _moderation_class(self, moderation_class: Any) -> Callable:
return moderation_class(
client=self.client,
callback=self.moderation_callback,
unique_id=self.unique_id,
chain_id=self.chain_id,
).validate
def _log_message_for_verbose(self, message: str) -> None:
if self.run_manager:
self.run_manager.on_text(message)
def moderate(self, prompt: Any) -> str:
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
ModerationIntentionError,
ModerationPiiError,
ModerationToxicityError,
)
try:
# convert prompt to text
input_text = self._convert_prompt_to_text(prompt=prompt)
output_text = str()
# perform moderation
if self.config is None:
# In absence of config Action will default to STOP only
self._log_message_for_verbose("Running pii validation...\n")
pii_validate = self._moderation_class(moderation_class=ComprehendPII)
output_text = pii_validate(prompt_value=input_text)
self._log_message_for_verbose("Running toxicity validation...\n")
toxicity_validate = self._moderation_class(
moderation_class=ComprehendToxicity
)
output_text = toxicity_validate(prompt_value=output_text)
self._log_message_for_verbose("Running intent validation...\n")
intent_validate = self._moderation_class(
moderation_class=ComprehendIntent
)
output_text = intent_validate(prompt_value=output_text)
else:
filter_functions = {
"pii": ComprehendPII,
"toxicity": ComprehendToxicity,
"intent": ComprehendIntent,
}
filters = self.config["filters"]
for _filter in filters:
filter_name = f"{_filter}"
if filter_name in filter_functions:
self._log_message_for_verbose(
f"Running {filter_name} Validation...\n"
)
validation_fn = self._moderation_class(
moderation_class=filter_functions[filter_name]
)
input_text = input_text if not output_text else output_text
output_text = validation_fn(
prompt_value=input_text,
config=self.config[filter_name]
if filter_name in self.config
else None,
)
# convert text to prompt and return
return self._convert_text_to_prompt(prompt=prompt, text=output_text)
except ModerationPiiError as e:
self._log_message_for_verbose(f"Found PII content..stopping..\n{str(e)}\n")
raise e
except ModerationToxicityError as e:
self._log_message_for_verbose(
f"Found Toxic content..stopping..\n{str(e)}\n"
)
raise e
except ModerationIntentionError as e:
self._log_message_for_verbose(
f"Found Harmful intention..stopping..\n{str(e)}\n"
)
raise e
except Exception as e:
raise e

View File

@@ -0,0 +1,64 @@
from typing import Any, Callable, Dict
class BaseModerationCallbackHandler:
def __init__(self) -> None:
if (
self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_pii, self.on_after_pii
)
and self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
)
and self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
)
):
raise NotImplementedError(
"Subclasses must override at least one of on_after_pii(), "
"on_after_toxicity(), or on_after_intent() functions."
)
def _is_method_unchanged(
self, base_method: Callable, derived_method: Callable
) -> bool:
return base_method.__qualname__ == derived_method.__qualname__
async def on_after_pii(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after PII validation is complete."""
raise NotImplementedError("Subclasses should implement this async method.")
async def on_after_toxicity(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
raise NotImplementedError("Subclasses should implement this async method.")
async def on_after_intent(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
raise NotImplementedError("Subclasses should implement this async method.")
@property
def pii_callback(self) -> bool:
return (
self.on_after_pii.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_pii
)
@property
def toxicity_callback(self) -> bool:
return (
self.on_after_toxicity.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_toxicity
)
@property
def intent_callback(self) -> bool:
return (
self.on_after_intent.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_intent
)
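
A subclass only needs to override one of the async hooks above to start receiving results (an editor's sketch, not part of the diff; the dictionary keys it reads come from the moderation code in this PR):

```python
from typing import Any, Dict

from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler


class ModerationAuditCallback(BaseModerationCallbackHandler):
    """Log moderation results; overriding a single hook is enough."""

    async def on_after_pii(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        # The beacon carries the chain id, input/output and a
        # LABELS_FOUND / LABELS_NOT_FOUND status.
        print(unique_id, moderation_beacon["moderation_status"])
```

An instance of such a handler can then be passed to the chain through its `moderation_callback` field.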

View File

@@ -0,0 +1,12 @@
from enum import Enum
class BaseModerationActions(Enum):
STOP = 1
ALLOW = 2
class BaseModerationFilters(str, Enum):
PII = "pii"
TOXICITY = "toxicity"
INTENT = "intent"

View File

@@ -0,0 +1,43 @@
class ModerationPiiError(Exception):
"""Exception raised if PII entities are detected.
Attributes:
message -- explanation of the error
"""
def __init__(
self, message: str = "The prompt contains PII entities and cannot be processed"
):
self.message = message
super().__init__(self.message)
class ModerationToxicityError(Exception):
"""Exception raised if Toxic entities are detected.
Attributes:
message -- explanation of the error
"""
def __init__(
self, message: str = "The prompt contains toxic content and cannot be processed"
):
self.message = message
super().__init__(self.message)
class ModerationIntentionError(Exception):
"""Exception raised if Intention entities are detected.
Attributes:
message -- explanation of the error
"""
def __init__(
self,
message: str = (
"The prompt indicates an un-desired intent and " "cannot be processed"
),
):
self.message = message
super().__init__(self.message)
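
These exceptions bubble up out of the chain whenever a filter's action is STOP, so callers can catch them explicitly (editor's sketch; `moderation_chain` refers to the earlier illustrative sketch and `user_input` is a made-up example):

```python
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationPiiError,
    ModerationToxicityError,
)

user_input = "Sample text that may contain PII or toxic content."

try:
    # `moderation_chain` as constructed in the earlier sketch above.
    moderated = moderation_chain.run(user_input)
except (ModerationPiiError, ModerationToxicityError) as e:
    # Each exception exposes a human-readable `message` attribute.
    print(f"Blocked by moderation: {e.message}")
```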

View File

@@ -0,0 +1,101 @@
import asyncio
import warnings
from typing import Any, Dict, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationIntentionError,
)
class ComprehendIntent:
def __init__(
self,
client: Any,
callback: Optional[Any] = None,
unique_id: Optional[str] = None,
chain_id: Optional[str] = None,
) -> None:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "Intent",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
self.unique_id = unique_id
def _get_arn(self) -> str:
region_name = self.client.meta.region_name
service = "comprehend"
intent_endpoint = "document-classifier-endpoint/prompt-intent"
return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
"""
Check and validate the intent of the given prompt text.
Args:
comprehend_client: Comprehend client for intent classification
prompt_value (str): The input text to be checked for unintended intent
config (Dict[str, Any]): Configuration settings for intent checks
Raises:
ValueError: If unintended intent is found in the prompt text based
on the specified threshold.
Returns:
str: The input prompt_value.
Note:
This function checks the intent of the provided prompt text using
Comprehend's classify_document API and raises an error if unintended
intent is detected with a score above the specified threshold.
"""
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
BaseModerationActions,
)
threshold = config.get("threshold", 0.5) if config else 0.5
action = (
config.get("action", BaseModerationActions.STOP)
if config
else BaseModerationActions.STOP
)
intent_found = False
if action == BaseModerationActions.ALLOW:
warnings.warn(
"You have allowed content with Harmful content."
"Defaulting to STOP action..."
)
action = BaseModerationActions.STOP
endpoint_arn = self._get_arn()
response = self.client.classify_document(
Text=prompt_value, EndpointArn=endpoint_arn
)
if self.callback and self.callback.intent_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = response
for class_result in response["Classes"]:
if (
class_result["Score"] >= threshold
and class_result["Name"] == "UNDESIRED_PROMPT"
):
intent_found = True
break
if self.callback and self.callback.intent_callback:
if intent_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
)
if intent_found:
raise ModerationIntentionError
return prompt_value

View File

@@ -0,0 +1,173 @@
import asyncio
from typing import Any, Dict, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationPiiError,
)
class ComprehendPII:
def __init__(
self,
client: Any,
callback: Optional[Any] = None,
unique_id: Optional[str] = None,
chain_id: Optional[str] = None,
) -> None:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "PII",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
self.unique_id = unique_id
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
BaseModerationActions,
)
if config:
action = config.get("action", BaseModerationActions.STOP)
if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
raise ValueError("Action can either be stop or allow")
return (
self._contains_pii(prompt_value=prompt_value, config=config)
if action == BaseModerationActions.STOP
else self._detect_pii(prompt_value=prompt_value, config=config)
)
else:
return self._contains_pii(prompt_value=prompt_value)
def _contains_pii(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
"""
Checks for Personally Identifiable Information (PII) labels above a
specified threshold.
Args:
prompt_value (str): The input text to be checked for PII labels.
config (Dict[str, Any]): Configuration for PII check and actions.
Returns:
str: the original prompt
Note:
- The provided client should be initialized with valid AWS credentials.
"""
pii_identified = self.client.contains_pii_entities(
Text=prompt_value, LanguageCode="en"
)
if self.callback and self.callback.pii_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = pii_identified
threshold = config.get("threshold", 0.5) if config else 0.5
pii_labels = config.get("labels", []) if config else []
pii_found = False
for entity in pii_identified["Labels"]:
if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
entity["Score"] >= threshold and not pii_labels
):
pii_found = True
break
if self.callback and self.callback.pii_callback:
if pii_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
if pii_found:
raise ModerationPiiError
return prompt_value
def _detect_pii(self, prompt_value: str, config: Optional[Dict[str, Any]]) -> str:
"""
Detects and handles Personally Identifiable Information (PII) entities in the
given prompt text using Amazon Comprehend's detect_pii_entities API. The
function provides options to redact or stop processing based on the identified
PII entities and a provided configuration.
Args:
prompt_value (str): The input text to be checked for PII entities.
config (Dict[str, Any]): A configuration specifying how to handle
PII entities.
Returns:
str: The processed prompt text with redacted PII entities or raised
exceptions.
Raises:
ValueError: If the prompt contains configured PII entities for
stopping processing.
Note:
- If PII is not found in the prompt, the original prompt is returned.
- The client should be initialized with valid AWS credentials.
"""
pii_identified = self.client.detect_pii_entities(
Text=prompt_value, LanguageCode="en"
)
if self.callback and self.callback.pii_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = pii_identified
if (pii_identified["Entities"]) == []:
if self.callback and self.callback.pii_callback:
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
return prompt_value
pii_found = False
if not config and pii_identified["Entities"]:
for entity in pii_identified["Entities"]:
if entity["Score"] >= 0.5:
pii_found = True
break
if self.callback and self.callback.pii_callback:
if pii_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
if pii_found:
raise ModerationPiiError
else:
threshold = config.get("threshold", 0.5) # type: ignore
pii_labels = config.get("labels", []) # type: ignore
mask_marker = config.get("mask_character", "*") # type: ignore
pii_found = False
for entity in pii_identified["Entities"]:
if (
pii_labels
and entity["Type"] in pii_labels
and entity["Score"] >= threshold
) or (not pii_labels and entity["Score"] >= threshold):
pii_found = True
char_offset_begin = entity["BeginOffset"]
char_offset_end = entity["EndOffset"]
prompt_value = (
prompt_value[:char_offset_begin]
+ mask_marker * (char_offset_end - char_offset_begin)
+ prompt_value[char_offset_end:]
)
if self.callback and self.callback.pii_callback:
if pii_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
)
return prompt_value

View File

@@ -0,0 +1,209 @@
import asyncio
import importlib
import warnings
from typing import Any, Dict, List, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationToxicityError,
)
class ComprehendToxicity:
def __init__(
self,
client: Any,
callback: Optional[Any] = None,
unique_id: Optional[str] = None,
chain_id: Optional[str] = None,
) -> None:
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "Toxicity",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
self.unique_id = unique_id
def _toxicity_init_validate(self, max_size: int) -> Any:
"""
Validate and initialize toxicity processing configuration.
Args:
max_size (int): Maximum sentence size defined in the configuration object.
Raises:
Exception: If the maximum sentence size exceeds the 5KB limit.
Note:
This function ensures that the NLTK punkt tokenizer is downloaded if not
already present.
Returns:
The imported nltk module.
"""
if max_size > 1024 * 5:
raise Exception("The sentence length should not exceed 5KB.")
try:
nltk = importlib.import_module("nltk")
nltk.data.find("tokenizers/punkt")
return nltk
except ImportError:
raise ModuleNotFoundError(
"Could not import nltk python package. "
"Please install it with `pip install nltk`."
)
except LookupError:
# Download the tokenizer on first use, then return the module.
nltk.download("punkt")
return nltk
def _split_paragraph(
self, prompt_value: str, max_size: int = 1024 * 4
) -> List[List[str]]:
"""
Split a paragraph into chunks of sentences, respecting the maximum size limit.
Args:
prompt_value (str): The input paragraph to be split into chunks
max_size (int, optional): The maximum size limit in bytes for each chunk.
Defaults to 1024 * 4.
Returns:
List[List[str]]: A list of chunks, where each chunk is a list of sentences
Note:
This function validates the maximum sentence size based on service limits
using the 'toxicity_init_validate' function. It uses the NLTK sentence
tokenizer to split the paragraph into sentences.
"""
# validate max. sentence size based on Service limits
nltk = self._toxicity_init_validate(max_size)
sentences = nltk.sent_tokenize(prompt_value)
chunks = []
current_chunk = [] # type: ignore
current_size = 0
for sentence in sentences:
sentence_size = len(sentence.encode("utf-8"))
# If adding a new sentence exceeds max_size or
# current_chunk has 10 sentences, start a new chunk
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
if current_chunk: # Avoid appending empty chunks
chunks.append(current_chunk)
current_chunk = []
current_size = 0
current_chunk.append(sentence)
current_size += sentence_size
# Add any remaining sentences
if current_chunk:
chunks.append(current_chunk)
return chunks
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
"""
Check the toxicity of a given text prompt using AWS Comprehend service
and apply actions based on configuration.
Args:
prompt_value (str): The text content to be checked for toxicity.
config (Dict[str, Any]): Configuration for toxicity checks and actions.
Returns:
str: The original prompt_value if allowed or no toxicity found.
Raises:
ValueError: If the prompt contains toxic labels and cannot be
processed based on the configuration.
"""
chunks = self._split_paragraph(prompt_value=prompt_value)
for sentence_list in chunks:
segments = [{"Text": sentence} for sentence in sentence_list]
response = self.client.detect_toxic_content(
TextSegments=segments, LanguageCode="en"
)
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_input"] = segments # type: ignore
self.moderation_beacon["moderation_output"] = response
if config:
from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501
BaseModerationActions,
)
toxicity_found = False
action = config.get("action", BaseModerationActions.STOP)
if action not in [
BaseModerationActions.STOP,
BaseModerationActions.ALLOW,
]:
raise ValueError("Action can either be stop or allow")
threshold = config.get("threshold", 0.5) if config else 0.5
toxicity_labels = config.get("labels", []) if config else []
if action == BaseModerationActions.STOP:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label
and (
not toxicity_labels
or label["Name"] in toxicity_labels
)
and label["Score"] >= threshold
):
toxicity_found = True
break
if action == BaseModerationActions.ALLOW:
if not toxicity_labels:
warnings.warn(
"You have allowed toxic content without specifying "
"any toxicity labels."
)
else:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label["Name"] in toxicity_labels
and label["Score"] >= threshold
):
toxicity_found = True
break
if self.callback and self.callback.toxicity_callback:
if toxicity_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
if toxicity_found:
raise ModerationToxicityError
else:
if response["ResultList"]:
detected_toxic_labels = list()
for item in response["ResultList"]:
detected_toxic_labels.extend(item["Labels"])
if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
raise ModerationToxicityError
return prompt_value

File diff suppressed because it is too large

View File

@@ -33,6 +33,7 @@ from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
@@ -302,6 +303,14 @@ class RedisSemanticCache(BaseCache):
# TODO - implement a TTL policy in Redis
DEFAULT_SCHEMA = {
"content_key": "prompt",
"text": [
{"name": "prompt"},
],
"extra": [{"name": "return_val"}, {"name": "llm_string"}],
}
def __init__(
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
):
@@ -349,12 +358,14 @@ class RedisSemanticCache(BaseCache):
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
schema=cast(Dict, self.DEFAULT_SCHEMA),
)
except ValueError:
redis = RedisVectorstore(
embedding_function=self.embedding.embed_query,
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
index_schema=cast(Dict, self.DEFAULT_SCHEMA),
)
_embedding = self.embedding.embed_query(text="test")
redis._create_index(dim=len(_embedding))
@@ -374,17 +385,18 @@ class RedisSemanticCache(BaseCache):
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
generations = []
generations: List = []
# Read from a Hash
results = llm_cache.similarity_search_limit_score(
results = llm_cache.similarity_search(
query=prompt,
k=1,
score_threshold=self.score_threshold,
distance_threshold=self.score_threshold,
)
if results:
for document in results:
for text in document.metadata["return_val"]:
generations.append(Generation(text=text))
generations.extend(
_load_generations_from_json(document.metadata["return_val"])
)
return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -402,11 +414,11 @@ class RedisSemanticCache(BaseCache):
)
return
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
_dump_generations_to_json([g for g in return_val])
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": [generation.text for generation in return_val],
"return_val": _dump_generations_to_json([g for g in return_val]),
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
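
For context, wiring the updated semantic cache up end-to-end looks roughly like this (an illustrative sketch, assuming a local Redis Stack instance and OpenAI credentials; not part of the diff):

```python
import langchain
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI

# Semantically similar prompts are served from Redis via the schema added above.
langchain.llm_cache = RedisSemanticCache(
    redis_url="redis://localhost:6379",
    embedding=OpenAIEmbeddings(),
    score_threshold=0.2,
)

llm = OpenAI(temperature=0)
llm("Tell me a joke")    # first call populates the cache
llm("Tell me one joke")  # close enough to be answered from the cache
```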

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
class RetrieverManagerMixin:
@@ -44,11 +44,18 @@ class LLMManagerMixin:
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
"""
def on_llm_end(
self,
@@ -316,6 +323,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
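
The `chunk` argument added to `on_llm_new_token` above gives handlers access to the full generation chunk, not just the token string. A custom handler might use it like this (editor's sketch against the API as changed in this PR; `ChunkLoggingHandler` is an illustrative name):

```python
from typing import Any, Optional, Union
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.output import ChatGenerationChunk, GenerationChunk


class ChunkLoggingHandler(BaseCallbackHandler):
    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # The chunk (when provided) carries generation_info alongside the text.
        info = chunk.generation_info if chunk is not None else None
        print(f"token={token!r} generation_info={info}")
```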

View File

@@ -49,6 +49,7 @@ from langchain.schema import (
LLMResult,
)
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
@@ -592,6 +593,8 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -607,6 +610,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
chunk=chunk,
**kwargs,
)
@@ -655,6 +659,8 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token.
@@ -667,6 +673,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"on_llm_new_token",
"ignore_llm",
token,
chunk=chunk,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,

View File

@@ -13,7 +13,12 @@ from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run
from langchain.load.dump import dumpd
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult
from langchain.schema.output import (
ChatGeneration,
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
logger = logging.getLogger(__name__)
@@ -123,6 +128,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
@@ -135,11 +141,14 @@ class BaseTracer(BaseCallbackHandler, ABC):
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
llm_run.events.append(
{
"name": "new_token",
"time": datetime.utcnow(),
"kwargs": {"token": token},
"kwargs": event_kwargs,
},
)

View File

@@ -120,9 +120,11 @@ class BaseQAWithSourcesChain(Chain, ABC):
def _split_sources(self, answer: str) -> Tuple[str, str]:
"""Split sources from answer."""
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2]
sources = re.split(r"\n", sources)[0]
if re.search(r"SOURCES?[:\s]", answer, re.IGNORECASE):
answer, sources = re.split(
r"SOURCES?[:\s]|QUESTION:\s", answer, flags=re.IGNORECASE
)[:2]
sources = re.split(r"\n", sources)[0].strip()
else:
sources = ""
return answer, sources

View File

@@ -318,7 +318,7 @@ class ChatOpenAI(BaseChatModel):
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=chunk)
def _generate(
self,
@@ -398,7 +398,7 @@ class ChatOpenAI(BaseChatModel):
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
async def _agenerate(
self,

View File

@@ -289,9 +289,10 @@ class Anthropic(LLM, _AnthropicCommon):
for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
):
yield GenerationChunk(text=token.completion)
chunk = GenerationChunk(text=token.completion)
yield chunk
if run_manager:
run_manager.on_llm_new_token(token.completion)
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
async def _astream(
self,
@@ -324,9 +325,10 @@ class Anthropic(LLM, _AnthropicCommon):
stream=True,
**params,
):
yield GenerationChunk(text=token.completion)
chunk = GenerationChunk(text=token.completion)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(token.completion)
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""

View File

@@ -297,6 +297,7 @@ class BaseOpenAI(BaseLLM):
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
@@ -320,6 +321,7 @@ class BaseOpenAI(BaseLLM):
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
@@ -825,9 +827,10 @@ class OpenAIChat(BaseLLM):
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
chunk = GenerationChunk(text=token)
yield chunk
if run_manager:
run_manager.on_llm_new_token(token)
run_manager.on_llm_new_token(token, chunk=chunk)
async def _astream(
self,
@@ -842,9 +845,10 @@ class OpenAIChat(BaseLLM):
self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
chunk = GenerationChunk(text=token)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(token)
await run_manager.on_llm_new_token(token, chunk=chunk)
def _generate(
self,

View File

@@ -1,7 +1,5 @@
from typing import Any, Dict, List
from langchain.schema.messages import get_buffer_string # noqa: 401
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""

View File

@@ -1,16 +1,64 @@
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Any,
)
import re
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
from urllib.parse import urlparse
import numpy as np
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.client import Redis as RedisType
logger = logging.getLogger(__name__)
def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
return np.array(array).astype(dtype).tobytes()
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, value: str) -> str:
def escape_symbol(match: re.Match) -> str:
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, value)
def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"Redis cannot be used as a vector database without RediSearch >=2.4"
"Please head to https://redis.io/docs/stack/search/quick_start/"
"to know more about installing the RediSearch module within Redis Stack."
)
logger.error(error_message)
raise ValueError(error_message)
def get_client(redis_url: str, **kwargs: Any) -> RedisType:

View File

@@ -1,664 +0,0 @@
from __future__ import annotations
import json
import logging
import uuid
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Type,
)
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import root_validator
from langchain.utilities.redis import get_client
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.client import Redis as RedisType
from redis.commands.search.query import Query
# required modules
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20400},
{"name": "searchlight", "ver": 20400},
]
# distance metrics
REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"]
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"Redis cannot be used as a vector database without RediSearch >=2.4"
"Please head to https://redis.io/docs/stack/search/quick_start/"
"to know more about installing the RediSearch module within Redis Stack."
)
logger.error(error_message)
raise ValueError(error_message)
def _check_index_exists(client: RedisType, index_name: str) -> bool:
"""Check if Redis index exists."""
try:
client.ft(index_name).info()
except: # noqa: E722
logger.info("Index does not exist")
return False
logger.info("Index already exists")
return True
def _redis_key(prefix: str) -> str:
"""Redis key schema for a given prefix."""
return f"{prefix}:{uuid.uuid4().hex}"
def _redis_prefix(index_name: str) -> str:
"""Redis key prefix for a given index."""
return f"doc:{index_name}"
def _default_relevance_score(val: float) -> float:
return 1 - val
class Redis(VectorStore):
"""`Redis` vector store.
To use, you should have the ``redis`` python package installed.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = Redis(
redis_url="redis://username:password@localhost:6379"
index_name="my-index",
embedding_function=embeddings.embed_query,
)
To use a redis replication setup with multiple redis servers and redis sentinels,
set "redis_url" to the "redis+sentinel://" scheme. With this url format a path is
needed holding the name of the redis service within the sentinels to get the
correct redis server connection. The default service name is "mymaster".
An optional username or password is used for both connections to the redis server
and the sentinel; different passwords for server and sentinel are not supported.
As a further constraint, only one sentinel instance can be given:
Example:
.. code-block:: python
vectorstore = Redis(
redis_url="redis+sentinel://username:password@sentinelhost:26379/mymaster/0"
index_name="my-index",
embedding_function=embeddings.embed_query,
)
"""
def __init__(
self,
redis_url: str,
index_name: str,
embedding_function: Callable,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
relevance_score_fn: Optional[Callable[[float], float]] = None,
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
**kwargs: Any,
):
"""Initialize with necessary components."""
self.embedding_function = embedding_function
self.index_name = index_name
try:
redis_client = get_client(redis_url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(redis_client, REDIS_REQUIRED_MODULES)
except ValueError as e:
raise ValueError(f"Redis failed to connect: {e}")
self.client = redis_client
self.content_key = content_key
self.metadata_key = metadata_key
self.vector_key = vector_key
self.distance_metric = distance_metric
self.relevance_score_fn = relevance_score_fn
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept embedding object directly
return None
def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self.relevance_score_fn:
return self.relevance_score_fn
if self.distance_metric == "COSINE":
return self._cosine_relevance_score_fn
elif self.distance_metric == "IP":
return self._max_inner_product_relevance_score_fn
elif self.distance_metric == "L2":
return self._euclidean_relevance_score_fn
else:
return _default_relevance_score
def _create_index(self, dim: int = 1536) -> None:
try:
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ImportError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Check if index exists
if not _check_index_exists(self.client, self.index_name):
# Define schema
schema = (
TextField(name=self.content_key),
TextField(name=self.metadata_key),
VectorField(
self.vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": self.distance_metric,
},
),
)
prefix = _redis_prefix(self.index_name)
# Create Redis Index
self.client.ft(self.index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
embeddings: Optional[List[List[float]]] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Add more texts to the vectorstore.
Args:
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
Defaults to None.
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
embeddings. Defaults to None.
keys (List[str]) or ids (List[str]): Identifiers of entries.
Defaults to None.
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
Returns:
List[str]: List of ids added to the vectorstore
"""
ids = []
prefix = _redis_prefix(self.index_name)
# Get keys or ids from kwargs
# Other vectorstores use ids
keys_or_ids = kwargs.get("keys", kwargs.get("ids"))
# Write data to redis
pipeline = self.client.pipeline(transaction=False)
for i, text in enumerate(texts):
# Use provided values by default or fallback
key = keys_or_ids[i] if keys_or_ids else _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
embedding = embeddings[i] if embeddings else self.embedding_function(text)
pipeline.hset(
key,
mapping={
self.content_key: text,
self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
self.metadata_key: json.dumps(metadata),
},
)
ids.append(key)
# Write batch
if i % batch_size == 0:
pipeline.execute()
# Cleanup final batch
pipeline.execute()
return ids
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""
Returns the most similar indexed documents to the query text.
Args:
query (str): The query text for which to find similar documents.
k (int): The number of documents to return. Default is 4.
Returns:
List[Document]: A list of documents that are most similar to the query text.
"""
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, _ in docs_and_scores]
def similarity_search_limit_score(
self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
) -> List[Document]:
"""
Returns the most similar indexed documents to the query text within the
score_threshold range.
Args:
query (str): The query text for which to find similar documents.
k (int): The number of documents to return. Default is 4.
score_threshold (float): The minimum matching score required for a document
to be considered a match. Defaults to 0.2.
Because the similarity is computed as a cosine distance, a smaller
score means a smaller angle and therefore a closer match.
Returns:
List[Document]: A list of documents that are most similar to the query text,
including the match score for each document.
Note:
If there are no documents that satisfy the score_threshold value,
an empty list is returned.
"""
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, score in docs_and_scores if score < score_threshold]
def _prepare_query(self, k: int) -> Query:
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Prepare the Query
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
)
return_fields = [self.metadata_key, self.content_key, "vector_score", "id"]
return (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
def similarity_search_with_score(
self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query and score for each
"""
# Creates embedding vector from user query
embedding = self.embedding_function(query)
# Creates Redis query
redis_query = self._prepare_query(k)
params_dict: Mapping[str, str] = {
"vector": np.array(embedding) # type: ignore
.astype(dtype=np.float32)
.tobytes()
}
# Perform vector search
results = self.client.ft(self.index_name).search(redis_query, params_dict)
# Prepare document results
docs_and_scores: List[Tuple[Document, float]] = []
for result in results.docs:
metadata = {**json.loads(result.metadata), "id": result.id}
doc = Document(page_content=result.content, metadata=metadata)
docs_and_scores.append((doc, float(result.vector_score)))
return docs_and_scores
@classmethod
def from_texts_return_keys(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
index_name: Optional[str] = None,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
**kwargs: Any,
) -> Tuple[Redis, List[str]]:
"""Create a Redis vectorstore from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created Redis index.
4. Returns the keys of the newly created documents.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
redisearch, keys = Redis.from_texts_return_keys(
texts,
embeddings,
redis_url="redis://username:password@localhost:6379"
)
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
if "redis_url" in kwargs:
kwargs.pop("redis_url")
# Name of the search index if not given
if not index_name:
index_name = uuid.uuid4().hex
# Create instance
instance = cls(
redis_url,
index_name,
embedding.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
distance_metric=distance_metric,
**kwargs,
)
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
# Create the search index
instance._create_index(dim=len(embeddings[0]))
# Add data to Redis
keys = instance.add_texts(texts, metadatas, embeddings)
return instance, keys
@classmethod
def from_texts(
cls: Type[Redis],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
index_name: Optional[str] = None,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Create a Redis vectorstore from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created Redis index.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
redisearch = Redis.from_texts(
texts,
embeddings,
redis_url="redis://username:password@localhost:6379"
)
"""
instance, _ = cls.from_texts_return_keys(
texts,
embedding,
metadatas=metadatas,
index_name=index_name,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
**kwargs,
)
return instance
@staticmethod
def delete(
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> bool:
"""
Delete a Redis entry.
Args:
ids: List of ids (keys) to delete.
Returns:
bool: Whether or not the deletions were successful.
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
if ids is None:
raise ValueError("'ids' (keys)() were not provided.")
try:
import redis # noqa: F401
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = get_client(redis_url=redis_url, **kwargs)
except ValueError as e:
raise ValueError(f"Your redis connected error: {e}")
# Check if index exists
try:
client.delete(*ids)
logger.info("Entries deleted")
return True
except: # noqa: E722
# ids do not exist
return False
@staticmethod
def drop_index(
index_name: str,
delete_documents: bool,
**kwargs: Any,
) -> bool:
"""
Drop a Redis search index.
Args:
index_name (str): Name of the index to drop.
delete_documents (bool): Whether to drop the associated documents.
Returns:
bool: Whether or not the drop was successful.
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis # noqa: F401
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = get_client(redis_url=redis_url, **kwargs)
except ValueError as e:
raise ValueError(f"Your redis connected error: {e}")
# Check if index exists
try:
client.ft(index_name).dropindex(delete_documents)
logger.info("Drop index")
return True
except: # noqa: E722
# Index does not exist
return False
@classmethod
def from_existing_index(
cls,
embedding: Embeddings,
index_name: str,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Connect to an existing Redis index."""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis # noqa: F401
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = get_client(redis_url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
# ensure that the index already exists
assert _check_index_exists(
client, index_name
), f"Index {index_name} does not exist"
except Exception as e:
raise ValueError(f"Redis failed to connect: {e}")
return cls(
redis_url,
index_name,
embedding.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
**kwargs,
)
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
tags = kwargs.pop("tags", None) or []
tags.extend(self._get_retriever_tags())
return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class RedisVectorStoreRetriever(VectorStoreRetriever):
"""Retriever for `Redis` vector store."""
vectorstore: Redis
"""Redis VectorStore."""
search_type: str = "similarity"
"""Type of search to perform. Can be either 'similarity' or 'similarity_limit'."""
k: int = 4
"""Number of documents to return."""
score_threshold: float = 0.4
"""Score threshold for similarity_limit search."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "similarity_limit"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
elif self.search_type == "similarity_limit":
docs = self.vectorstore.similarity_search_limit_score(
query, k=self.k, score_threshold=self.score_threshold
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
return self.vectorstore.add_documents(documents, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
"""Add documents to vectorstore."""
return await self.vectorstore.aadd_documents(documents, **kwargs)
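
For orientation, a minimal usage sketch of the vectorstore API in the file above (a sketch, assuming a local Redis Stack instance at `redis://localhost:6379`, a valid `OPENAI_API_KEY`, and an arbitrary index name):

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Redis

embeddings = OpenAIEmbeddings()

# Build an index from raw texts; the index is created if it does not exist.
rds = Redis.from_texts(
    ["foo", "bar", "baz"],
    embeddings,
    redis_url="redis://localhost:6379",
    index_name="example",
)

# Plain similarity search.
docs = rds.similarity_search("foo", k=2)

# Scores are distances here, so smaller means more similar.
docs_and_scores = rds.similarity_search_with_score("foo", k=2)

# Keep only results whose distance is below the threshold.
close_docs = rds.similarity_search_limit_score("foo", k=4, score_threshold=0.2)

# Wrap as a retriever for use in chains.
retriever = rds.as_retriever()
results = retriever.get_relevant_documents("foo")
```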

View File

@@ -0,0 +1,9 @@
from .base import Redis
from .filters import (
RedisFilter,
RedisNum,
RedisTag,
RedisText,
)
__all__ = ["Redis", "RedisFilter", "RedisTag", "RedisText", "RedisNum"]

File diff suppressed because it is too large

View File

@@ -0,0 +1,20 @@
from typing import Any, Dict, List
import numpy as np
# required modules
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20600},
{"name": "searchlight", "ver": 20600},
]
# distance metrics
REDIS_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"]
# supported vector datatypes
REDIS_VECTOR_DTYPE_MAP: Dict[str, Any] = {
"FLOAT32": np.float32,
"FLOAT64": np.float64,
}
REDIS_TAG_SEPARATOR = ","
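
As a small illustration (not part of the diff), a sketch of the dtype conversion applied when vectors are written as bytes; the old implementation above does `np.array(embedding, dtype=np.float32).tobytes()`, and the toy embedding below is purely illustrative:

```python
import numpy as np

REDIS_VECTOR_DTYPE_MAP = {"FLOAT32": np.float32, "FLOAT64": np.float64}

embedding = [0.1, 0.2, 0.3]  # toy embedding
# The bytes written to the hash must match the TYPE declared on the vector field.
vector_bytes = np.array(embedding, dtype=REDIS_VECTOR_DTYPE_MAP["FLOAT32"]).tobytes()
```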

View File

@@ -0,0 +1,420 @@
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union
from langchain.utilities.redis import TokenEscaper
# disable mypy error for dunder method overrides
# mypy: disable-error-code="override"
class RedisFilterOperator(Enum):
EQ = 1
NE = 2
LT = 3
GT = 4
LE = 5
GE = 6
OR = 7
AND = 8
LIKE = 9
IN = 10
class RedisFilter:
@staticmethod
def text(field: str) -> "RedisText":
return RedisText(field)
@staticmethod
def num(field: str) -> "RedisNum":
return RedisNum(field)
@staticmethod
def tag(field: str) -> "RedisTag":
return RedisTag(field)
class RedisFilterField:
escaper: "TokenEscaper" = TokenEscaper()
OPERATORS: Dict[RedisFilterOperator, str] = {}
def __init__(self, field: str):
self._field = field
self._value: Any = None
self._operator: RedisFilterOperator = RedisFilterOperator.EQ
def equals(self, other: "RedisFilterField") -> bool:
if not isinstance(other, type(self)):
return False
return self._field == other._field and self._value == other._value
def _set_value(
self, val: Any, val_type: type, operator: RedisFilterOperator
) -> None:
# check that the operator is supported by this class
if operator not in self.OPERATORS:
raise ValueError(
f"Operator {operator} not supported by {self.__class__.__name__}. "
+ f"Supported operators are {self.OPERATORS.values()}"
)
if not isinstance(val, val_type):
raise TypeError(
f"Right side argument passed to operator {self.OPERATORS[operator]} "
f"with left side "
f"argument {self.__class__.__name__} must be of type {val_type}"
)
self._value = val
self._operator = operator
def check_operator_misuse(func: Callable) -> Callable:
@wraps(func)
def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
# Extracting 'other' from positional arguments or keyword arguments
other = kwargs.get("other") if "other" in kwargs else None
if not other:
for arg in args:
if isinstance(arg, type(instance)):
other = arg
break
if isinstance(other, type(instance)):
raise ValueError(
"Equality operators are overridden for FilterExpression creation. Use "
".equals() for equality checks"
)
return func(instance, *args, **kwargs)
return wrapper
class RedisTag(RedisFilterField):
"""A RedisTag is a RedisFilterField representing a tag in a Redis index."""
OPERATORS: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "==",
RedisFilterOperator.NE: "!=",
RedisFilterOperator.IN: "==",
}
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "@%s:{%s}",
RedisFilterOperator.NE: "(-@%s:{%s})",
RedisFilterOperator.IN: "@%s:{%s}",
}
def __init__(self, field: str):
"""Create a RedisTag FilterField
Args:
field (str): The name of the RedisTag field in the index to be queried
against.
"""
super().__init__(field)
def _set_tag_value(
self, other: Union[List[str], str], operator: RedisFilterOperator
) -> None:
if isinstance(other, list):
if not all(isinstance(tag, str) for tag in other):
raise ValueError("All tags must be strings")
else:
other = [other]
self._set_value(other, list, operator)
@check_operator_misuse
def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
"""Create a RedisTag equality filter expression
Args:
other (Union[List[str], str]): The tag(s) to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisTag
>>> filter = RedisTag("brand") == "nike"
"""
self._set_tag_value(other, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
"""Create a RedisTag inequality filter expression
Args:
other (Union[List[str], str]): The tag(s) to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisTag
>>> filter = RedisTag("brand") != "nike"
"""
self._set_tag_value(other, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
@property
def _formatted_tag_value(self) -> str:
return "|".join([self.escaper.escape(tag) for tag in self._value])
def __str__(self) -> str:
"""Return the Redis Query syntax for a RedisTag filter expression"""
if not self._value:
raise ValueError(
f"Operator must be used before calling __str__. Operators are "
f"{self.OPERATORS.values()}"
)
return self.OPERATOR_MAP[self._operator] % (
self._field,
self._formatted_tag_value,
)
class RedisNum(RedisFilterField):
"""A RedisFilterField representing a numeric field in a Redis index."""
OPERATORS: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "==",
RedisFilterOperator.NE: "!=",
RedisFilterOperator.LT: "<",
RedisFilterOperator.GT: ">",
RedisFilterOperator.LE: "<=",
RedisFilterOperator.GE: ">=",
}
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "@%s:[%i %i]",
RedisFilterOperator.NE: "(-@%s:[%i %i])",
RedisFilterOperator.GT: "@%s:[(%i +inf]",
RedisFilterOperator.LT: "@%s:[-inf (%i]",
RedisFilterOperator.GE: "@%s:[%i +inf]",
RedisFilterOperator.LE: "@%s:[-inf %i]",
}
def __str__(self) -> str:
"""Return the Redis Query syntax for a Numeric filter expression"""
if not self._value:
raise ValueError(
f"Operator must be used before calling __str__. Operators are "
f"{self.OPERATORS.values()}"
)
if (
self._operator == RedisFilterOperator.EQ
or self._operator == RedisFilterOperator.NE
):
return self.OPERATOR_MAP[self._operator] % (
self._field,
self._value,
self._value,
)
else:
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
@check_operator_misuse
def __eq__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric equality filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") == 90210
"""
self._set_value(other, int, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric inequality filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") != 90210
"""
self._set_value(other, int, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
def __gt__(self, other: int) -> "RedisFilterExpression":
"""Create a RedisNumeric greater than filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") > 18
"""
self._set_value(other, int, RedisFilterOperator.GT)
return RedisFilterExpression(str(self))
def __lt__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric less than filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") < 18
"""
self._set_value(other, int, RedisFilterOperator.LT)
return RedisFilterExpression(str(self))
def __ge__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric greater than or equal to filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") >= 18
"""
self._set_value(other, int, RedisFilterOperator.GE)
return RedisFilterExpression(str(self))
def __le__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric less than or equal to filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") <= 18
"""
self._set_value(other, int, RedisFilterOperator.LE)
return RedisFilterExpression(str(self))
class RedisText(RedisFilterField):
"""A RedisText is a RedisFilterField representing a text field in a Redis index."""
OPERATORS = {
RedisFilterOperator.EQ: "==",
RedisFilterOperator.NE: "!=",
RedisFilterOperator.LIKE: "%",
}
OPERATOR_MAP = {
RedisFilterOperator.EQ: '@%s:"%s"',
RedisFilterOperator.NE: '(-@%s:"%s")',
RedisFilterOperator.LIKE: "@%s:%s",
}
@check_operator_misuse
def __eq__(self, other: str) -> "RedisFilterExpression":
"""Create a RedisText equality filter expression
Args:
other (str): The text value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisText
>>> filter = RedisText("job") == "engineer"
"""
self._set_value(other, str, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: str) -> "RedisFilterExpression":
"""Create a RedisText inequality filter expression
Args:
other (str): The text value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisText
>>> filter = RedisText("job") != "engineer"
"""
self._set_value(other, str, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
def __mod__(self, other: str) -> "RedisFilterExpression":
"""Create a RedisText like filter expression
Args:
other (str): The text value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisText
>>> filter = RedisText("job") % "engineer"
"""
self._set_value(other, str, RedisFilterOperator.LIKE)
return RedisFilterExpression(str(self))
def __str__(self) -> str:
if not self._value:
raise ValueError(
f"Operator must be used before calling __str__. Operators are "
f"{self.OPERATORS.values()}"
)
try:
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
except KeyError:
raise Exception("Invalid operator")
class RedisFilterExpression:
"""A RedisFilterExpression is a logical expression of RedisFilterFields.
RedisFilterExpressions can be combined using the & and | operators to create
complex logical expressions that evaluate to the Redis Query language.
This presents an interface by which users can create complex queries
without having to know the Redis Query language.
Filter expressions are not initialized directly. Instead they are built
by combining RedisFilterFields using the & and | operators.
Examples:
>>> from langchain.vectorstores.redis import RedisTag, RedisNum
>>> brand_is_nike = RedisTag("brand") == "nike"
>>> price_is_under_100 = RedisNum("price") < 100
>>> filter = brand_is_nike & price_is_under_100
>>> print(str(filter))
(@brand:{nike} @price:[-inf (100])
"""
def __init__(
self,
_filter: Optional[str] = None,
operator: Optional[RedisFilterOperator] = None,
left: Optional["RedisFilterExpression"] = None,
right: Optional["RedisFilterExpression"] = None,
):
self._filter = _filter
self._operator = operator
self._left = left
self._right = right
def __and__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
return RedisFilterExpression(
operator=RedisFilterOperator.AND, left=self, right=other
)
def __or__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
return RedisFilterExpression(
operator=RedisFilterOperator.OR, left=self, right=other
)
def __str__(self) -> str:
# top level check that allows recursive calls to __str__
if not self._filter and not self._operator:
raise ValueError("Improperly initialized RedisFilterExpression")
# allow for single filter expression without operators as last
# expression in the chain might not have an operator
if self._operator:
operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
return f"({str(self._left)}{operator_str}{str(self._right)})"
# base case: check that the filter is set
if not self._filter:
raise ValueError("Improperly initialized RedisFilterExpression")
return self._filter
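
A short sketch of how these filter classes compose (the field names `brand` and `price` are illustrative; the resulting expression is the kind of object the integration tests further down pass via the `filter` keyword):

```python
from langchain.vectorstores.redis import RedisNum, RedisTag

brand_is_nike = RedisTag("brand") == "nike"
price_is_under_100 = RedisNum("price") < 100

# & and | combine expressions into Redis query syntax.
combined = brand_is_nike & price_is_under_100
print(str(combined))  # (@brand:{nike} @price:[-inf (100])

# Passed to a search like: rds.similarity_search("shoes", k=3, filter=combined)
```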

View File

@@ -0,0 +1,276 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import yaml
# ignore type error here as it's a redis-py type problem
from redis.commands.search.field import ( # type: ignore
NumericField,
TagField,
TextField,
VectorField,
)
from typing_extensions import Literal
from langchain.pydantic_v1 import BaseModel, Field, validator
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
class RedisDistanceMetric(str, Enum):
l2 = "L2"
cosine = "COSINE"
ip = "IP"
class RedisField(BaseModel):
name: str = Field(...)
class TextFieldSchema(RedisField):
weight: float = 1
no_stem: bool = False
phonetic_matcher: Optional[str] = None
withsuffixtrie: bool = False
no_index: bool = False
sortable: Optional[bool] = False
def as_field(self) -> TextField:
return TextField(
self.name,
weight=self.weight,
no_stem=self.no_stem,
phonetic_matcher=self.phonetic_matcher,
sortable=self.sortable,
no_index=self.no_index,
)
class TagFieldSchema(RedisField):
separator: str = ","
case_sensitive: bool = False
no_index: bool = False
sortable: Optional[bool] = False
def as_field(self) -> TagField:
return TagField(
self.name,
separator=self.separator,
case_sensitive=self.case_sensitive,
sortable=self.sortable,
no_index=self.no_index,
)
class NumericFieldSchema(RedisField):
no_index: bool = False
sortable: Optional[bool] = False
def as_field(self) -> NumericField:
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
class RedisVectorField(RedisField):
dims: int = Field(...)
algorithm: object = Field(...)
datatype: str = Field(default="FLOAT32")
distance_metric: RedisDistanceMetric = Field(default="COSINE")
initial_cap: int = Field(default=20000)
@validator("distance_metric", pre=True)
def uppercase_strings(cls, v: str) -> str:
return v.upper()
@validator("datatype", pre=True)
def uppercase_and_check_dtype(cls, v: str) -> str:
if v.upper() not in REDIS_VECTOR_DTYPE_MAP:
raise ValueError(
f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
)
return v.upper()
class FlatVectorField(RedisVectorField):
algorithm: Literal["FLAT"] = "FLAT"
block_size: int = Field(default=1000)
def as_field(self) -> VectorField:
return VectorField(
self.name,
self.algorithm,
{
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
"INITIAL_CAP": self.initial_cap,
"BLOCK_SIZE": self.block_size,
},
)
class HNSWVectorField(RedisVectorField):
algorithm: Literal["HNSW"] = "HNSW"
m: int = Field(default=16)
ef_construction: int = Field(default=200)
ef_runtime: int = Field(default=10)
epsilon: float = Field(default=0.8)
def as_field(self) -> VectorField:
return VectorField(
self.name,
self.algorithm,
{
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
"INITIAL_CAP": self.initial_cap,
"M": self.m,
"EF_CONSTRUCTION": self.ef_construction,
"EF_RUNTIME": self.ef_runtime,
"EPSILON": self.epsilon,
},
)
class RedisModel(BaseModel):
# always have a content field for text
text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
tag: Optional[List[TagFieldSchema]] = None
numeric: Optional[List[NumericFieldSchema]] = None
extra: Optional[List[RedisField]] = None
# filled by default_vector_schema
vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
content_key: str = "content"
content_vector_key: str = "content_vector"
def add_content_field(self) -> None:
if self.text is None:
self.text = []
for field in self.text:
if field.name == self.content_key:
return
self.text.append(TextFieldSchema(name=self.content_key))
def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
# catch case where user inputted no vector field spec
# in the index schema
if self.vector is None:
self.vector = []
# ignore types as pydantic is handling type validation and conversion
if vector_field["algorithm"] == "FLAT":
self.vector.append(FlatVectorField(**vector_field)) # type: ignore
elif vector_field["algorithm"] == "HNSW":
self.vector.append(HNSWVectorField(**vector_field)) # type: ignore
else:
raise ValueError(
f"algorithm must be either FLAT or HNSW. Got "
f"{vector_field['algorithm']}"
)
def as_dict(self) -> Dict[str, List[Any]]:
schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
# iter over all class attributes
for attr, attr_value in self.__dict__.items():
# only non-empty lists
if isinstance(attr_value, list) and len(attr_value) > 0:
field_values: List[Dict[str, Any]] = []
# iterate over all fields in each category (tag, text, etc)
for val in attr_value:
value: Dict[str, Any] = {}
# iterate over values within each field to extract
# settings for that field (i.e. name, weight, etc)
for field, field_value in val.__dict__.items():
# make enums into strings
if isinstance(field_value, Enum):
value[field] = field_value.value
# don't write null values
elif field_value is not None:
value[field] = field_value
field_values.append(value)
schemas[attr] = field_values
schema: Dict[str, List[Any]] = {}
# only write non-empty lists from defaults
for k, v in schemas.items():
if len(v) > 0:
schema[k] = v
return schema
@property
def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
if not self.vector:
raise ValueError("No vector fields found")
for field in self.vector:
if field.name == self.content_vector_key:
return field
raise ValueError("No content_vector field found")
@property
def vector_dtype(self) -> np.dtype:
# should only ever be called after pydantic has validated the schema
return REDIS_VECTOR_DTYPE_MAP[self.content_vector.datatype]
@property
def is_empty(self) -> bool:
return all(
field is None for field in [self.tag, self.text, self.numeric, self.vector]
)
def get_fields(self) -> List["RedisField"]:
redis_fields: List["RedisField"] = []
if self.is_empty:
return redis_fields
for field_name in self.__fields__.keys():
if field_name not in ["content_key", "content_vector_key", "extra"]:
field_group = getattr(self, field_name)
if field_group is not None:
for field in field_group:
redis_fields.append(field.as_field())
return redis_fields
@property
def metadata_keys(self) -> List[str]:
keys: List[str] = []
if self.is_empty:
return keys
for field_name in self.__fields__.keys():
field_group = getattr(self, field_name)
if field_group is not None:
for field in field_group:
# check if it's a metadata field. exclude vector and content key
if not isinstance(field, str) and field.name not in [
self.content_key,
self.content_vector_key,
]:
keys.append(field.name)
return keys
def read_schema(
index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
) -> Dict[str, Any]:
# check if it's a dict and return it as-is, otherwise check if it's a path and
# read in the file assuming it's a yaml file, returning the parsed schema dict
if isinstance(index_schema, dict):
return index_schema
elif isinstance(index_schema, Path):
with open(index_schema, "rb") as f:
return yaml.safe_load(f)
elif isinstance(index_schema, str):
if Path(index_schema).resolve().is_file():
with open(index_schema, "rb") as f:
return yaml.safe_load(f)
else:
raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
else:
raise TypeError(
f"index_schema must be a dict, or path to a yaml file "
f"Got {type(index_schema)}"
)
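
A hedged sketch of what `read_schema` accepts: either a dict using the `text`/`tag`/`numeric` layout of `RedisModel` above, or a path to a YAML file with the same keys (the import path and field names are assumptions for illustration):

```python
from langchain.vectorstores.redis.schema import read_schema  # assumed module path

index_schema = {
    "text": [{"name": "job"}, {"name": "title"}],
    "numeric": [{"name": "salary"}],
}

# Dicts are returned as-is.
schema_dict = read_schema(index_schema)

# A YAML file with the same structure also works:
# schema_dict = read_schema("redis_schema.yml")
```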

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.273"
version = "0.0.274"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

View File

@@ -1,16 +1,27 @@
"""Test Redis cache functionality."""
import uuid
from typing import List
import pytest
import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.embeddings.base import Embeddings
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
REDIS_TEST_URL = "redis://localhost:6379"
def random_string() -> str:
return str(uuid.uuid4())
def test_redis_cache_ttl() -> None:
import redis
@@ -30,12 +41,10 @@ def test_redis_cache() -> None:
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
langchain.llm_cache.redis.flushall()
@@ -80,14 +89,90 @@ def test_redis_semantic_cache() -> None:
langchain.llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_chat() -> None:
import redis
def test_redis_semantic_cache_multi() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
def test_redis_semantic_cache_chat() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
with pytest.warns():
llm.predict("foo")
llm.predict("foo")
langchain.llm_cache.redis.flushall()
langchain.llm_cache.clear(llm_string=llm_string)
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
ids=[
"single_prompt_single_generation",
"single_prompt_multiple_generations",
"single_prompt_multiple_generations",
"multiple_prompts_multiple_generations",
],
)
def test_redis_semantic_cache_hit(
embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=embedding, redis_url=REDIS_TEST_URL
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
print(prompt_i)
print(llm_generations_i)
langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
llm.generate(prompts)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)

View File

@@ -1,18 +1,22 @@
"""Test ChatOpenAI wrapper."""
from typing import Any
from typing import Any, List, Optional, Union
import pytest
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.chains.openai_functions import (
create_openai_fn_chain,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -191,6 +195,108 @@ async def test_async_chat_openai_streaming() -> None:
assert generation.text == generation.message.content
@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_async_chat_openai_streaming_with_function() -> None:
"""Test ChatOpenAI wrapper with multiple completions."""
class MyCustomAsyncHandler(AsyncCallbackHandler):
def __init__(self) -> None:
super().__init__()
self._captured_tokens: List[str] = []
self._captured_chunks: List[
Optional[Union[ChatGenerationChunk, GenerationChunk]]
] = []
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[ChatGenerationChunk, GenerationChunk]] = None,
**kwargs: Any,
) -> Any:
self._captured_tokens.append(token)
self._captured_chunks.append(chunk)
json_schema = {
"title": "Person",
"description": "Identifying information about a person.",
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "The person's name",
"type": "string",
},
"age": {
"title": "Age",
"description": "The person's age",
"type": "integer",
},
"fav_food": {
"title": "Fav Food",
"description": "The person's favorite food",
"type": "string",
},
},
"required": ["name", "age"],
}
callback_handler = MyCustomAsyncHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatOpenAI(
max_tokens=10,
n=1,
callback_manager=callback_manager,
streaming=True,
)
prompt_msgs = [
SystemMessage(
content="You are a world class algorithm for "
"extracting information in structured formats."
),
HumanMessage(
content="Use the given format to extract "
"information from the following input:"
),
HumanMessagePromptTemplate.from_template("{input}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
prompt = ChatPromptTemplate(messages=prompt_msgs)
function: Any = {
"name": "output_formatter",
"description": (
"Output formatter. Should always be used to format your response to the"
" user."
),
"parameters": json_schema,
}
chain = create_openai_fn_chain(
[function],
chat,
prompt,
output_parser=None,
)
message = HumanMessage(content="Sally is 13 years old")
response = await chain.agenerate([{"input": message}])
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
assert len(callback_handler._captured_tokens) > 0
assert len(callback_handler._captured_chunks) > 0
assert all([chunk is not None for chunk in callback_handler._captured_chunks])
def test_chat_openai_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.

View File

@@ -52,6 +52,7 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
if text not in self.known_texts:
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
return [float(1.0)] * (self.dimensionality - 1) + [

View File

@@ -1,17 +1,28 @@
"""Test Redis functionality."""
from typing import List
import os
from typing import Any, Dict, List, Optional
import pytest
from langchain.docstore.document import Document
from langchain.vectorstores.redis import Redis
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from langchain.vectorstores.redis import (
Redis,
RedisFilter,
RedisNum,
RedisText,
)
from langchain.vectorstores.redis.filters import RedisFilterExpression
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
TEST_INDEX_NAME = "test"
TEST_REDIS_URL = "redis://localhost:6379"
TEST_SINGLE_RESULT = [Document(page_content="foo")]
TEST_SINGLE_WITH_METADATA_RESULT = [Document(page_content="foo", metadata={"a": "b"})]
TEST_SINGLE_WITH_METADATA = {"a": "b"}
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
RANGE_SCORE = pytest.approx(0.0513, abs=0.002)
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
IP_SCORE = -8.0
EUCLIDEAN_SCORE = 1.0
@@ -23,6 +34,27 @@ def drop(index_name: str) -> bool:
)
def convert_bytes(data: Any) -> Any:
if isinstance(data, bytes):
return data.decode("ascii")
if isinstance(data, dict):
return dict(map(convert_bytes, data.items()))
if isinstance(data, list):
return list(map(convert_bytes, data))
if isinstance(data, tuple):
return map(convert_bytes, data)
return data
def make_dict(values: List[Any]) -> dict:
i = 0
di = {}
while i < len(values) - 1:
di[values[i]] = values[i + 1]
i += 2
return di
@pytest.fixture
def texts() -> List[str]:
return ["foo", "bar", "baz"]
@@ -31,7 +63,7 @@ def texts() -> List[str]:
def test_redis(texts: List[str]) -> None:
"""Test end to end construction and search."""
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
output = docsearch.similarity_search("foo", k=1)
output = docsearch.similarity_search("foo", k=1, return_metadata=False)
assert output == TEST_SINGLE_RESULT
assert drop(docsearch.index_name)
@@ -40,30 +72,55 @@ def test_redis_new_vector(texts: List[str]) -> None:
"""Test adding a new document"""
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
output = docsearch.similarity_search("foo", k=2, return_metadata=False)
assert output == TEST_RESULT
assert drop(docsearch.index_name)
def test_redis_from_existing(texts: List[str]) -> None:
"""Test adding a new document"""
Redis.from_texts(
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
schema: Dict = docsearch.schema
# write schema for the next test
docsearch.write_schema("test_schema.yml")
# Test creating from an existing
docsearch2 = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
FakeEmbeddings(),
index_name=TEST_INDEX_NAME,
redis_url=TEST_REDIS_URL,
schema=schema,
)
output = docsearch2.similarity_search("foo", k=1)
output = docsearch2.similarity_search("foo", k=1, return_metadata=False)
assert output == TEST_SINGLE_RESULT
def test_redis_add_texts_to_existing() -> None:
"""Test adding a new document"""
# Test creating from an existing with yaml from file
docsearch = Redis.from_existing_index(
FakeEmbeddings(),
index_name=TEST_INDEX_NAME,
redis_url=TEST_REDIS_URL,
schema="test_schema.yml",
)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2, return_metadata=False)
assert output == TEST_RESULT
assert drop(TEST_INDEX_NAME)
# remove the test_schema.yml file
os.remove("test_schema.yml")
def test_redis_from_texts_return_keys(texts: List[str]) -> None:
"""Test from_texts_return_keys constructor."""
docsearch, keys = Redis.from_texts_return_keys(
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL
)
output = docsearch.similarity_search("foo", k=1)
output = docsearch.similarity_search("foo", k=1, return_metadata=False)
assert output == TEST_SINGLE_RESULT
assert len(keys) == len(texts)
assert drop(docsearch.index_name)
@@ -73,21 +130,124 @@ def test_redis_from_documents(texts: List[str]) -> None:
"""Test from_documents constructor."""
docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
docsearch = Redis.from_documents(docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
output = docsearch.similarity_search("foo", k=1)
assert output == TEST_SINGLE_WITH_METADATA_RESULT
output = docsearch.similarity_search("foo", k=1, return_metadata=True)
assert "a" in output[0].metadata.keys()
assert "b" in output[0].metadata.values()
assert drop(docsearch.index_name)
def test_redis_add_texts_to_existing() -> None:
"""Test adding a new document"""
# Test creating from an existing
docsearch = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
# -- test filters -- #
@pytest.mark.parametrize(
"filter_expr, expected_length, expected_nums",
[
(RedisText("text") == "foo", 1, None),
(RedisFilter.text("text") == "foo", 1, None),
(RedisText("text") % "ba*", 2, ["bar", "baz"]),
(RedisNum("num") > 2, 1, [3]),
(RedisNum("num") < 2, 1, [1]),
(RedisNum("num") >= 2, 2, [2, 3]),
(RedisNum("num") <= 2, 2, [1, 2]),
(RedisNum("num") != 2, 2, [1, 3]),
(RedisFilter.num("num") != 2, 2, [1, 3]),
(RedisFilter.tag("category") == "a", 3, None),
(RedisFilter.tag("category") == "b", 2, None),
(RedisFilter.tag("category") == "c", 2, None),
(RedisFilter.tag("category") == ["b", "c"], 3, None),
],
ids=[
"text-filter-equals-foo",
"alternative-text-equals-foo",
"text-filter-fuzzy-match-ba",
"number-filter-greater-than-2",
"number-filter-less-than-2",
"number-filter-greater-equals-2",
"number-filter-less-equals-2",
"number-filter-not-equals-2",
"alternative-number-not-equals-2",
"tag-filter-equals-a",
"tag-filter-equals-b",
"tag-filter-equals-c",
"tag-filter-equals-b-or-c",
],
)
def test_redis_filters_1(
filter_expr: RedisFilterExpression,
expected_length: int,
expected_nums: Optional[list],
) -> None:
metadata = [
{"name": "joe", "num": 1, "text": "foo", "category": ["a", "b"]},
{"name": "john", "num": 2, "text": "bar", "category": ["a", "c"]},
{"name": "jane", "num": 3, "text": "baz", "category": ["b", "c", "a"]},
]
documents = [Document(page_content="foo", metadata=m) for m in metadata]
docsearch = Redis.from_documents(
documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL
)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
assert output == TEST_RESULT
assert drop(TEST_INDEX_NAME)
output = docsearch.similarity_search("foo", k=3, filter=filter_expr)
assert len(output) == expected_length
if expected_nums is not None:
for out in output:
assert (
out.metadata["text"] in expected_nums
or int(out.metadata["num"]) in expected_nums
)
assert drop(docsearch.index_name)
# -- test index specification -- #
def test_index_specification_generation() -> None:
index_schema = {
"text": [{"name": "job"}, {"name": "title"}],
"numeric": [{"name": "salary"}],
}
text = ["foo"]
meta = {"job": "engineer", "title": "principal engineer", "salary": 100000}
docs = [Document(page_content=t, metadata=meta) for t in text]
r = Redis.from_documents(
docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, index_schema=index_schema
)
output = r.similarity_search("foo", k=1, return_metadata=True)
assert output[0].metadata["job"] == "engineer"
assert output[0].metadata["title"] == "principal engineer"
assert int(output[0].metadata["salary"]) == 100000
info = convert_bytes(r.client.ft(r.index_name).info())
attributes = info["attributes"]
assert len(attributes) == 5
for attr in attributes:
d = make_dict(attr)
if d["identifier"] == "job":
assert d["type"] == "TEXT"
elif d["identifier"] == "title":
assert d["type"] == "TEXT"
elif d["identifier"] == "salary":
assert d["type"] == "NUMERIC"
elif d["identifier"] == "content":
assert d["type"] == "TEXT"
elif d["identifier"] == "content_vector":
assert d["type"] == "VECTOR"
else:
raise ValueError("Unexpected attribute in index schema")
assert drop(r.index_name)
# -- test distance metrics -- #
cosine_schema: Dict = {"distance_metric": "cosine"}
ip_schema: Dict = {"distance_metric": "IP"}
l2_schema: Dict = {"distance_metric": "L2"}
def test_cosine(texts: List[str]) -> None:
@@ -96,7 +256,7 @@ def test_cosine(texts: List[str]) -> None:
texts,
FakeEmbeddings(),
redis_url=TEST_REDIS_URL,
distance_metric="COSINE",
vector_schema=cosine_schema,
)
output = docsearch.similarity_search_with_score("far", k=2)
_, score = output[1]
@@ -107,7 +267,7 @@ def test_cosine(texts: List[str]) -> None:
def test_l2(texts: List[str]) -> None:
"""Test Flat L2 distance."""
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=l2_schema
)
output = docsearch.similarity_search_with_score("far", k=2)
_, score = output[1]
@@ -118,7 +278,7 @@ def test_l2(texts: List[str]) -> None:
def test_ip(texts: List[str]) -> None:
"""Test inner product distance."""
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=ip_schema
)
output = docsearch.similarity_search_with_score("far", k=2)
_, score = output[1]
@@ -126,29 +286,34 @@ def test_ip(texts: List[str]) -> None:
assert drop(docsearch.index_name)
def test_similarity_search_limit_score(texts: List[str]) -> None:
def test_similarity_search_limit_distance(texts: List[str]) -> None:
"""Test similarity search limit score."""
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
texts,
FakeEmbeddings(),
redis_url=TEST_REDIS_URL,
)
output = docsearch.similarity_search_limit_score("far", k=2, score_threshold=0.1)
assert len(output) == 1
_, score = output[0]
assert score == COSINE_SCORE
output = docsearch.similarity_search(texts[0], k=3, distance_threshold=0.1)
# can't check score but length of output should be 2
assert len(output) == 2
assert drop(docsearch.index_name)
def test_similarity_search_with_score_with_limit_score(texts: List[str]) -> None:
def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None:
"""Test similarity search with score with limit score."""
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
)
output = docsearch.similarity_search_with_relevance_scores(
"far", k=2, score_threshold=0.1
output = docsearch.similarity_search_with_score(
texts[0], k=3, distance_threshold=0.1, return_metadata=True
)
assert len(output) == 1
_, score = output[0]
assert score == COSINE_SCORE
assert len(output) == 2
for out, score in output:
if out.page_content == texts[1]:
score == COSINE_SCORE
assert drop(docsearch.index_name)
@@ -156,6 +321,48 @@ def test_delete(texts: List[str]) -> None:
"""Test deleting a new document"""
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
ids = docsearch.add_texts(["foo"])
got = docsearch.delete(ids=ids)
got = docsearch.delete(ids=ids, redis_url=TEST_REDIS_URL)
assert got
assert drop(docsearch.index_name)
def test_redis_as_retriever() -> None:
texts = ["foo", "foo", "foo", "foo", "bar"]
docsearch = Redis.from_texts(
texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
)
retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
results = retriever.get_relevant_documents("foo")
assert len(results) == 3
assert all([d.page_content == "foo" for d in results])
assert drop(docsearch.index_name)
def test_redis_retriever_distance_threshold() -> None:
texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
retriever = docsearch.as_retriever(
search_type="similarity_distance_threshold",
search_kwargs={"k": 3, "distance_threshold": 0.1},
)
results = retriever.get_relevant_documents("foo")
assert len(results) == 2
assert drop(docsearch.index_name)
def test_redis_retriever_score_threshold() -> None:
texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
retriever = docsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": 3, "score_threshold": 0.91},
)
results = retriever.get_relevant_documents("foo")
assert len(results) == 2
assert drop(docsearch.index_name)

View File

@@ -12,6 +12,21 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
"This Agreement is governed by English law.\n",
"28-pl",
),
(
"This Agreement is governed by English law.\nSources: 28-pl",
"This Agreement is governed by English law.\n",
"28-pl",
),
(
"This Agreement is governed by English law.\nsource: 28-pl",
"This Agreement is governed by English law.\n",
"28-pl",
),
(
"This Agreement is governed by English law.\nSource: 28-pl",
"This Agreement is governed by English law.\n",
"28-pl",
),
(
"This Agreement is governed by English law.\n"
"SOURCES: 28-pl\n\n"