feat: add GraphRAG framework and integrate TuGraph (#1506)

Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Florian
2024-05-16 15:39:50 +08:00
committed by GitHub
parent 593e974405
commit a9087c3853
133 changed files with 10139 additions and 6631 deletions

View File

@@ -0,0 +1 @@
"""Module for transformer."""

View File

@@ -0,0 +1,26 @@
"""Transformer base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
logger = logging.getLogger(__name__)


class TransformerBase:
    """Root class of the transformer hierarchy.

    Declares no attributes or methods; it only gives embedders, extractors
    and translators a shared ancestor type.
    """
class EmbedderBase(TransformerBase, ABC):
    """Base type for embedders.

    Adds no members of its own on top of ``TransformerBase``.
    """
class ExtractorBase(TransformerBase, ABC):
    """Abstract interface for transformers that extract results from text."""

    @abstractmethod
    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract results from text.

        Args:
            text: The input text to extract from.
            limit: Optional bound on the number of results; how it is
                applied is up to the concrete extractor.

        Returns:
            A list of extracted results; the element type is defined by
            the concrete extractor.
        """
class TranslatorBase(TransformerBase, ABC):
    """Base type for translators.

    Adds no members of its own on top of ``TransformerBase``.
    """

View File

@@ -0,0 +1,50 @@
"""KeywordExtractor class."""
import logging
from typing import List, Optional
from dbgpt.core import LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
# Prompt template for keyword extraction. The ``{text}`` placeholder receives
# the input question. The reply format requested from the model is
# comma-separated keywords and synonyms, with a semicolon between the two
# groups.
KEYWORD_EXTRACT_PT = (
    "A question is provided below. Given the question, extract up to "
    "keywords from the text. Focus on extracting the keywords that we can use "
    "to best lookup answers to the question.\n"
    "Generate as more as possible synonyms or alias of the keywords "
    "considering possible cases of capitalization, pluralization, "
    "common expressions, etc.\n"
    "Avoid stopwords.\n"
    # Each instruction ends with a separator; without the newline here the
    # adjacent literals fuse into "...format.Formatted keywords...".
    "Provide the keywords and synonyms in comma-separated format.\n"
    "Formatted keywords and synonyms text should be separated by a semicolon.\n"
    "---------------------\n"
    "Example:\n"
    "Text: Alice is Bob's mother.\n"
    "Keywords:\nAlice,mother,Bob;mummy\n"
    "Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
    "Keywords:\nPhilz,coffee shop,Berkeley,1982;coffee bar,coffee house\n"
    "---------------------\n"
    "Text: {text}\n"
    "Keywords:\n"
)
logger = logging.getLogger(__name__)
class KeywordExtractor(LLMExtractor):
    """Extractor that asks an LLM for keywords and synonyms of a text."""

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the KeywordExtractor.

        Args:
            llm_client: LLM client forwarded to ``LLMExtractor``.
            model_name: Model name forwarded to ``LLMExtractor``.
        """
        super().__init__(llm_client, model_name, KEYWORD_EXTRACT_PT)

    def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
        """Parse the LLM reply into a list of unique keywords.

        The reply is split on ``;`` and then on ``,``; each piece is
        whitespace-stripped and empty pieces are dropped.

        Args:
            text: Raw completion text returned by the model.
            limit: Maximum number of unique keywords to return. ``None`` or
                ``0`` means no cap.

        Returns:
            Unique keywords in the order they first appear in ``text``.
        """
        # A dict is used as an ordered set: it de-duplicates like a set but
        # keeps first-seen order, so the result is deterministic across runs
        # instead of depending on string hash randomization.
        keywords: dict = {}
        for part in text.split(";"):
            for s in part.strip().split(","):
                keyword = s.strip()
                if not keyword:
                    continue
                keywords[keyword] = None
                if limit and len(keywords) >= limit:
                    return list(keywords)
        return list(keywords)

View File

@@ -0,0 +1,50 @@
"""TripletExtractor class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.base import ExtractorBase
logger = logging.getLogger(__name__)


class LLMExtractor(ExtractorBase, ABC):
    """Base class for extractors that prompt an LLM and parse its reply.

    Subclasses pass a prompt template to ``__init__`` and implement
    ``_parse_response`` to turn the raw completion text into results.
    """

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMExtractor.

        Args:
            llm_client: Client used to list models and generate completions.
            model_name: Model to query. When falsy, the first model reported
                by the client is picked on the first ``extract`` call.
            prompt_template: Prompt text; it is formatted with the input
                passed as the ``text`` variable.
        """
        self._llm_client = llm_client
        self._model_name = model_name
        self._prompt_template = prompt_template

    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract by LLM.

        Args:
            text: Input text inserted into the prompt template.
            limit: Optional bound forwarded to ``_parse_response``. ``None``
                and ``0`` skip validation.

        Returns:
            Whatever ``_parse_response`` produces for the model reply, or an
            empty list when the LLM request is reported as unsuccessful.

        Raises:
            ValueError: If ``limit`` is truthy and less than 1.
            Exception: If no model name is set and the client lists no models.
        """
        # Reject an invalid limit before spending an LLM round trip on it.
        if limit and limit < 1:
            raise ValueError("optional argument limit >= 1")

        template = HumanPromptTemplate.from_template(self._prompt_template)
        messages = template.format_messages(text=text)

        # use default model if needed
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            # Remember the choice so later calls skip the model lookup.
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

        if not response.success:
            # Best effort: a failed request is logged and yields no results
            # rather than raising.
            code = str(response.error_code)
            reason = response.text
            logger.error(f"request llm failed ({code}) {reason}")
            return []

        return self._parse_response(response.text, limit)

    @abstractmethod
    def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
        """Parse llm response.

        Args:
            text: Raw completion text returned by the model.
            limit: Optional bound on the number of parsed results.

        Returns:
            The parsed results as a list.
        """

View File

@@ -0,0 +1,10 @@
"""Text2Cypher class."""
import logging
from dbgpt.rag.transformer.base import TranslatorBase
logger = logging.getLogger(__name__)


class Text2Cypher(TranslatorBase):
    """Text2Cypher translator.

    Currently an empty subclass of ``TranslatorBase``: it defines no
    members of its own.
    """

View File

@@ -0,0 +1,10 @@
"""Text2GQL class."""
import logging
from dbgpt.rag.transformer.base import TranslatorBase
logger = logging.getLogger(__name__)


class Text2GQL(TranslatorBase):
    """Text2GQL translator.

    Currently an empty subclass of ``TranslatorBase``: it defines no
    members of its own.
    """

View File

@@ -0,0 +1,10 @@
"""Text2Vector class."""
import logging
from dbgpt.rag.transformer.base import EmbedderBase
logger = logging.getLogger(__name__)


class Text2Vector(EmbedderBase):
    """Text2Vector embedder.

    Currently an empty subclass of ``EmbedderBase``: it defines no members
    of its own.
    """

View File

@@ -0,0 +1,71 @@
"""TripletExtractor class."""
import logging
import re
from typing import Any, List, Optional, Tuple
from dbgpt.core import LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
logger = logging.getLogger(__name__)

# Few-shot prompt template for knowledge-triplet extraction. The ``{text}``
# placeholder receives the input text; the examples show the model the
# "(subject, predicate, object)" reply format, one triplet per line.
TRIPLET_EXTRACT_PT = (
    "Some text is provided below. Given the text, "
    "extract up to knowledge triplets as more as possible "
    "in the form of (subject, predicate, object).\n"
    "Avoid stopwords.\n"
    "---------------------\n"
    "Example:\n"
    "Text: Alice is Bob's mother.\n"
    "Triplets:\n(Alice, is mother of, Bob)\n"
    "Text: Alice has 2 apples.\n"
    "Triplets:\n(Alice, has 2, apple)\n"
    "Text: Alice was given 1 apple by Bob.\n"
    # Passive voice: Bob is the giver and Alice the receiver, so the object
    # of the triplet is Alice (the example previously had Bob on both ends).
    "Triplets:(Bob, gives 1 apple, Alice)\n"
    "Text: Alice was pushed by Bob.\n"
    "Triplets:(Bob, pushes, Alice)\n"
    "Text: Bob's mother Alice has 2 apples.\n"
    "Triplets:\n(Alice, is mother of, Bob)\n(Alice, has 2, apple)\n"
    "Text: A Big monkey climbed up the tall fruit tree and picked 3 peaches.\n"
    "Triplets:\n(monkey, climbed up, fruit tree)\n(monkey, picked 3, peach)\n"
    "Text: Alice has 2 apples, she gives 1 to Bob.\n"
    "Triplets:\n"
    "(Alice, has 2, apple)\n(Alice, gives 1 apple, Bob)\n"
    "Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
    "Triplets:\n"
    "(Philz, is, coffee shop)\n(Philz, founded in, Berkeley)\n"
    "(Philz, founded in, 1982)\n"
    "---------------------\n"
    "Text: {text}\n"
    "Triplets:\n"
)
class TripletExtractor(LLMExtractor):
    """Extractor that asks an LLM for (subject, predicate, object) triplets."""

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Set up the extractor with the triplet prompt template.

        Args:
            llm_client: LLM client forwarded to ``LLMExtractor``.
            model_name: Model name forwarded to ``LLMExtractor``.
        """
        super().__init__(llm_client, model_name, TRIPLET_EXTRACT_PT)

    def _parse_response(
        self, text: str, limit: Optional[int] = None
    ) -> List[Tuple[Any, ...]]:
        """Parse the LLM reply into a list of 3-tuples.

        Every parenthesised group on every line is split on commas. Groups
        that yield exactly three non-empty fields are kept, after stripping
        surrounding punctuation from each field.

        Args:
            text: Raw completion text returned by the model.
            limit: Maximum number of triplets to return. ``None`` or ``0``
                means no cap.

        Returns:
            Triplets in the order they appear in ``text``.
        """
        # Characters trimmed from both ends of each field.
        trim_chars = (
            "`~!@#$%^&*()-=+[]\\{}|;':\",./<>?"
            "·!¥&*()—【】、「」;‘’:“”,。、《》?"
        )
        found = []
        for row in text.split("\n"):
            for group in re.findall(r"\((.*?)\)", row):
                fields = [
                    piece
                    for piece in (raw.strip() for raw in group.split(","))
                    if piece
                ]
                # Anything other than three fields is not a triplet.
                if len(fields) != 3:
                    continue
                found.append(tuple(field.strip(trim_chars) for field in fields))
                if limit and len(found) >= limit:
                    return found
        return found