Building a robust GraphRAG system for a specific use case

Part One: Preparing a Custom Dataset for Your Graph Database

Introduction to Retrieval Augmented Generation (RAG)


Retrieval Augmented Generation (RAG) has emerged as a powerful technique for enhancing the capabilities of Large Language Models (LLMs). By integrating external knowledge sources into the LLM's reasoning process, RAG enables LLMs to generate more accurate, comprehensive, and contextually relevant responses. This blog post series explores a specific application of RAG: building a robust GraphRAG system for a specific use case.

We'll break down this process into three parts:

  1. Preparing a Custom Dataset for Your Database: This part focuses on creating the training data needed to fine-tune an LLM for translating natural language questions into Cypher queries.
  2. The Fine-Tuning Process: Here, we'll delve into the technical details of fine-tuning an LLM specifically for the text-to-Cypher task.
  3. Building a Q&A System: In the final part, we'll demonstrate how to integrate the fine-tuned LLM with a knowledge extraction component to build a complete question-answering system that can effectively extract insights from your graph database.

Why Graph RAG?


While traditional RAG often relies on simple vector similarity search over unstructured text, this approach falls short when dealing with complex, multi-hop questions that require connecting information from multiple sources. This is where Graph RAG shines. By leveraging the power of graph databases, Graph RAG can efficiently traverse relationships between entities and retrieve the interconnected information needed to answer complex queries.

The problem Graph RAG solves: Traditional RAG struggles with questions that require understanding the relationships between different pieces of information. For example, a question like "Who are the founders of companies that were acquired by Google?" requires identifying companies acquired by Google and then finding their founders. This type of question involves multiple steps (or hops) through the data. Graph databases, with their explicit representation of relationships, are ideally suited for navigating these multi-hop queries.
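To make the idea of hops concrete, here is a toy sketch in plain Python rather than a real graph database. The acquisition and founder edges are made up for illustration, the dictionaries stand in for the graph, and the docstring shows roughly how the same two-hop pattern could be written in Cypher, assuming `Company` and `Person` labels and `ACQUIRED` and `FOUNDED` relationship types.

```python
from collections import defaultdict

# Hypothetical toy graph: edges stored as (source, relationship, target) triples
edges = [
    ("Google", "ACQUIRED", "CompanyA"),
    ("Google", "ACQUIRED", "CompanyB"),
    ("OtherCorp", "ACQUIRED", "CompanyC"),
    ("Alice", "FOUNDED", "CompanyA"),
    ("Bob", "FOUNDED", "CompanyA"),
    ("Carol", "FOUNDED", "CompanyC"),
]

# Index edges in both directions so each hop is a dictionary lookup
outgoing = defaultdict(list)
incoming = defaultdict(list)
for src, rel, dst in edges:
    outgoing[(src, rel)].append(dst)
    incoming[(dst, rel)].append(src)


def founders_of_acquisitions(acquirer: str) -> list[str]:
    """Two hops: acquirer -[:ACQUIRED]-> company <-[:FOUNDED]- founder.

    Roughly equivalent Cypher (assumed schema):
        MATCH (:Company {name: $acquirer})-[:ACQUIRED]->(c:Company)<-[:FOUNDED]-(p:Person)
        RETURN DISTINCT p.name
    """
    founders = []
    for company in outgoing[(acquirer, "ACQUIRED")]:  # hop 1: acquired companies
        for person in incoming[(company, "FOUNDED")]:  # hop 2: their founders
            if person not in founders:
                founders.append(person)
    return founders


print(founders_of_acquisitions("Google"))  # ['Alice', 'Bob']
```

A vector index over text chunks has no equivalent of the second lookup: it can retrieve passages that look similar to the question, but it does not follow the `FOUNDED` edge from each acquired company.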

What is Cypher and Why Use LLMs to Generate It?

Cypher is a declarative graph query language specifically designed for querying graph databases like Neo4j. It allows you to express complex patterns and relationships in a concise and intuitive way.

Why LLMs for Cypher generation? Writing Cypher queries by hand can be challenging, especially for non-technical users. By fine-tuning an LLM to translate natural language questions into Cypher, we can let users query graph databases in plain language, making the data more accessible and the insights easier to uncover.

Generating a Dataset for Text-to-Cypher Fine-Tuning

In this first part of the series, we will focus on generating a high-quality dataset that will be used to fine-tune our LLM for the text-to-Cypher task. This dataset will consist of pairs of natural language questions and their corresponding Cypher queries.

Why Fine-Tune?

While LLMs have shown impressive abilities in various tasks, they are not inherently proficient in generating Cypher queries, especially for complex graph structures and domain-specific datasets. Fine-tuning allows us to adapt a pre-trained LLM to our specific task and dataset, resulting in a model that is significantly better at translating natural language questions into accurate and efficient Cypher queries.

By fine-tuning on a custom dataset that reflects the structure and content of your graph database, you can achieve higher accuracy, better performance, and more relevant results in your Graph RAG application.


Code Example

Installing Dependencies:

pip install pandas langchain langchain-core langchain-community langchain-openai neo4j  # langchain-community provides the Neo4jGraph wrapper assumed for `graph`

Generating Questions Based on the Custom Neo4j Database Schema

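Here we import the required Python modules and connect to the Neo4j database. The rest of the code relies on a `graph` object; here we assume it is LangChain's Neo4jGraph wrapper, which exposes the graph.query() method and the graph.schema property used throughout this post.

import time
import random
import os
import json
import pandas as pd
from typing import List
from langchain_openai import ChatOpenAI
from langchain_core.prompts import (
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    ChatPromptTemplate,
)
# On langchain-core >= 0.3, import BaseModel and Field from pydantic directly instead
from langchain_core.pydantic_v1 import BaseModel, Field
# Assumption: `graph` is LangChain's Neo4jGraph (newer releases also ship it as langchain_neo4j.Neo4jGraph)
from langchain_community.graphs import Neo4jGraph

graph = Neo4jGraph(
    url="bolt://localhost:7687",  # placeholder connection details
    username="neo4j",
    password="xxxxxxxxxxxx",
)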

Defining Query Types:

Here we define a dictionary named query_types that categorizes the different types of Cypher queries we aim to generate. Each key in the dictionary represents a query type (e.g., "Simple Retrieval Queries", "Complex Aggregation Queries"), and the corresponding value provides a detailed description of that query type and specific instructions for generating relevant questions.

query_types = {
    "Simple Retrieval Queries": "These queries focus on basic data extraction, retrieving nodes or relationships based on straightforward criteria such as labels, properties, or direct relationships. Examples include fetching all nodes labeled as 'Person' or retrieving relationships of a specific type like 'EMPLOYED_BY'. Simple retrieval is essential for initial data inspections and basic reporting tasks. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements",
    "Complex Retrieval Queries": "These advanced queries use the rich pattern-matching capabilities of Cypher to handle multiple node types and relationship patterns. They involve sophisticated filtering conditions and logical operations to extract nuanced insights from interconnected data points. An example could be finding all 'Person' nodes who work in a 'Department' with over 50 employees and have at least one 'REPORTS_TO' relationship. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements",
    "Simple Aggregation Queries": "Simple aggregation involves calculating basic statistical metrics over properties of nodes or relationships, such as counting the number of nodes, averaging property values, or determining maximum and minimum values. These queries summarize data characteristics and support quick analytical conclusions. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements",
    "Pathfinding Queries": "Specialized in exploring connections between nodes, these queries are used to find the shortest path, identify all paths up to a certain length, or explore possible routes within a network. They are essential for applications in network analysis, routing, logistics, and social network exploration. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements",
    "Complex Aggregation Queries": "The most sophisticated category, these queries involve multiple aggregation functions and often group results over complex subgraphs. They calculate metrics like average number of reports per manager or total sales volume through a network, supporting strategic decision making and advanced reporting. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements",
    "Verbose query": "These queries are characterized by their explicit and detailed specifications about the data retrieval process and the exact information needed. They involve elaborate instructions for navigating through complex data structures, specifying precise criteria for inclusion, exclusion, and sorting of data points. Verbose queries typically require the breakdown of each step in the querying process, from the initial identification of relevant data nodes and relationships to the intricate filtering and sorting mechanisms that must be applied. Always limit the number of results if more than one row is expected from the questions by saying 'first 3' or 'top 5' elements",
}

This categorization helps guide the LLM in generating diverse and targeted questions that cover various aspects of the graph database schema.
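Every description above ends with the same sentence about limiting results. As a possible refinement (an assumption, not part of the original code), the shared rule can be stored once and appended when each instruction string is built; the `query_type: description` format matches what the generation loop later passes to the chain. `LIMIT_RULE`, `build_instructions`, and the shortened descriptions are hypothetical.

```python
# Hypothetical refactor: keep the shared limit rule in one place
LIMIT_RULE = (
    "Always limit the number of results if more than one row is expected "
    "from the questions by saying 'first 3' or 'top 5' elements"
)


def build_instructions(query_type: str, description: str, rule: str = LIMIT_RULE) -> str:
    """Combine a category name, its description, and the shared rule into one instruction string."""
    return f"{query_type}: {description.rstrip()} {rule}"


# Shortened, illustrative descriptions (the full ones live in query_types)
descriptions = {
    "Simple Retrieval Queries": "These queries focus on basic data extraction.",
    "Pathfinding Queries": "These queries explore connections between nodes.",
}

for query_type, description in descriptions.items():
    print(build_instructions(query_type, description))
```

Keeping the rule separate makes it easy to change the result limit for every category at once without editing six long strings.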

Defining the LLM

Here we initialize the ChatOpenAI class from the langchain_openai library, configuring it with the necessary credentials and model selection. We provide the API endpoint (the base URL of OpenAI or of any OpenAI-compatible server), the API key, and the name of the model we want to use.

OPENAI_ENDPOINT = "xxxxxxxxxxx"
OPENAI_API_KEY="xxxxxxxxxxxxx"
MODEL_NAME = "xxxxxxxxxxxx"

llm = ChatOpenAI(base_url=OPENAI_ENDPOINT,
             api_key=OPENAI_API_KEY,
             model=MODEL_NAME,)

This sets up the language model that will be used for various tasks, including generating questions and translating natural language to Cypher queries.

LLM with Structured Output

Here we define a Pydantic model named Question that specifies the structure of the output we expect from the LLM when generating questions. It has a single field, questions, which is a list of strings representing the generated questions.

class Question(BaseModel):
    questions: List[str] = Field(
        description="List of relevant questions for the particular graph schema. Make sure that the questions can be answered with information from the schema and that they are as diverse as possible. Make sure that the schema and the example values contain the information needed to answer the questions! Do not ask questions that cannot be answered based on the provided schema. For example, if no information about subtitles can be found in the graph, don't ask anything about subtitles. Make sure to always limit the results to fewer than 10 by saying 3 users, or top 5 movies, or similar."
    )


structured_llm = llm.with_structured_output(Question)

We then use the with_structured_output method of the llm object to create a new LLM object (structured_llm) that is specifically configured to return output adhering to the Question model's structure. This ensures that the LLM's output is parsed and returned in a structured format that is easy to work with.
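Conceptually, structured output means the model is asked to return data that matches the schema, and that data is then parsed and validated before you see it. The sketch below imitates that parse-and-validate step with the standard library only; it illustrates the idea and is not how LangChain implements with_structured_output. `QuestionList` and `parse_questions` are hypothetical names.

```python
import json
from dataclasses import dataclass


@dataclass
class QuestionList:
    """Stand-in for the Question model: one field holding a list of strings."""

    questions: list[str]


def parse_questions(raw: str) -> QuestionList:
    """Parse a JSON payload such as {"questions": [...]} and check its shape."""
    payload = json.loads(raw)
    questions = payload.get("questions") if isinstance(payload, dict) else None
    if not isinstance(questions, list) or not all(isinstance(q, str) for q in questions):
        raise ValueError("Expected an object with a 'questions' list of strings")
    return QuestionList(questions=questions)


parsed = parse_questions('{"questions": ["Which 3 movies have the highest rating?"]}')
print(parsed.questions)
```

The payoff is the same as with the real Pydantic model: malformed output fails loudly at parse time instead of leaking free-form text into the dataset.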

Defining the Prompts and Creating the Chain

Here we define the prompts and create a LangChain chain for generating questions.

We start by defining a system_prompt that provides high-level instructions to the LLM about its task: generating 100 diverse and relevant questions related to a Neo4j graph schema.


system_prompt = """Your task is to generate 100 questions that are directly related to a specific graph schema in Neo4j. Each question should target distinct aspects of the schema, such as relationships between nodes, properties of nodes, or characteristics of node types. Ensure that the questions vary in complexity, covering basic, intermediate, and advanced queries.

Avoid ambiguous questions. For clarity, an ambiguous question is one that can be interpreted in multiple ways or does not have a straightforward answer based on the schema. For example, avoid asking, "What is related to this?" without specifying the node type or relationship.
Please design each question to yield a limited number of results (for example, by asking for the 'first 3' or 'top 5' items), never more than 10. This will ensure that the queries are precise and suitable for detailed analysis and training.
The goal of these questions is to create a dataset for training AI models to convert natural language queries into Cypher queries effectively.
It is vital that the database contains information that can answer the question!
Make sure to generate 100 questions!"""

After that, we create a default_prompt using ChatPromptTemplate, which combines a system message and a human message. The system message incorporates the system_prompt and adds specific instructions for creating the questions.

The human message provides the graph schema and example values as input to the LLM.

default_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            f"{system_prompt} Follow these instructions create minimum 100 possible questions: {{instructions}}"
        ),
        HumanMessagePromptTemplate.from_template(
            "Make sure to create questions for the following graph schema:{input}\n Here are some example nodes and relationship values: {values}. Don't use any values that aren't found in the schema or in provided values. Also, do not ask questions that there is no way to answer based on the schema or provided example values. Don't include question index or the sequence of the question in the list. Make sure your question is complete and clear"
        ),
    ]
)

chain = default_prompt | structured_llm

Finally, we create a LangChain chain by piping the default_prompt to the structured_llm. This chain will be used to invoke the LLM with the prompt and retrieve the generated questions in a structured format.
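The `|` operator composes runnables: the output of the prompt becomes the input of the model. The toy classes below are a hypothetical, stdlib-only illustration of that composition pattern; they are not LangChain's actual classes, and the "model" step simply fabricates a reply so the flow can be traced.

```python
from typing import Any, Callable


class ToyRunnable:
    """Minimal stand-in for a runnable: wraps a function and supports `|` chaining."""

    def __init__(self, func: Callable[[Any], Any]):
        self.func = func

    def invoke(self, value: Any) -> Any:
        return self.func(value)

    def __or__(self, other: "ToyRunnable") -> "ToyRunnable":
        # The composed runnable feeds this step's output into the next step
        return ToyRunnable(lambda value: other.invoke(self.invoke(value)))


# Hypothetical steps: a "prompt" that formats the input and a fake "model"
toy_prompt = ToyRunnable(lambda inputs: f"Schema: {inputs['input']}")
toy_model = ToyRunnable(lambda text: {"questions": [f"Question about {text}"]})

toy_chain = toy_prompt | toy_model
print(toy_chain.invoke({"input": "(:Person)-[:ACTED_IN]->(:Movie)"}))
```

In the real chain, the dictionary passed to invoke fills the template variables (`input`, `instructions`, `values`), and the formatted messages flow into structured_llm in the same left-to-right way.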

Testing the Chain

To test the chain, we sample a few example paths from the graph and call chain.invoke with the schema, the instructions, and those example values.

instructions = ".."
values = graph.query(
            """
            MATCH (n) WHERE rand() > 0.6 WITH n LIMIT 2
            CALL { WITH n MATCH p=(n)-[*3..3]-() RETURN p LIMIT 1}
            RETURN p
            """
            )
chain.invoke( {"input": graph.schema, "instructions": instructions, "values": values})

The output will be a Question object containing the generated questions.

Starting the Question Generation Process

We iterate through the query_types dictionary, constructing specific instructions for each query type based on its description. For each query type, we retrieve example values from the graph and invoke the LangChain chain with the schema, instructions, and values.

schema = graph.schema
all_questions = []
for query_type, description in query_types.items():
    instructions = f"{query_type}: {description}"
    # Draw fresh example paths for each query type
    values = graph.query(
        """
        MATCH (n) WHERE rand() > 0.6 WITH n LIMIT 2
        CALL { WITH n MATCH p=(n)-[*3..3]-() RETURN p LIMIT 1 }
        RETURN p
        """
    )
    questions = chain.invoke({"input": schema, "instructions": instructions, "values": values})
    all_questions.extend({"question": q, "type": query_type} for q in questions.questions)

The generated questions are then appended to the all_questions list, along with their corresponding query type.
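Because each category is generated in a separate call, the same question can show up more than once. A light post-processing pass, sketched below under the assumption that the records have the `{"question": ..., "type": ...}` shape built in the loop, drops duplicates that differ only in case or whitespace and counts questions per type. `dedupe_questions` is a hypothetical helper, and the sample records are illustrative.

```python
from collections import Counter


def dedupe_questions(records: list[dict]) -> list[dict]:
    """Keep the first occurrence of each question, ignoring case and extra whitespace."""
    seen = set()
    unique = []
    for record in records:
        key = " ".join(record["question"].lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(record)
    return unique


# Illustrative records in the same shape as all_questions
sample = [
    {"question": "Who are the top 5 actors?", "type": "Simple Retrieval Queries"},
    {"question": "who are the top 5  actors?", "type": "Simple Retrieval Queries"},
    {"question": "What is the average rating?", "type": "Simple Aggregation Queries"},
]
unique = dedupe_questions(sample)
print(Counter(r["type"] for r in unique))
```

Applied as `all_questions = dedupe_questions(all_questions)` before saving, this avoids paying for Cypher generation twice on the same question and keeps repeated pairs out of the fine-tuning set.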

Saving the Generated Questions

Here we convert the all_questions list into a Pandas DataFrame and save it as a CSV file named 'text2cypher_questions.csv'.

all_questions_df = pd.DataFrame.from_records(all_questions)
all_questions_df.to_csv('text2cypher_questions.csv', index=False)

Having successfully generated a diverse set of questions tailored to our Neo4j database schema, the next critical step is to translate these natural language questions into their corresponding Cypher queries. This translation process will also be powered by a large language model, leveraging its ability to understand both natural language and the intricacies of the Cypher query language. By automating this translation, we can efficiently create a comprehensive dataset of question-Cypher pairs, which will serve as the foundation for fine-tuning a specialized LLM in the subsequent stages.


Generating Cypher Queries for the Generated Questions

Defining the Cypher Query Model

Here we define a Pydantic model named CypherQuery that specifies the structure of the output we expect from the LLM when generating Cypher queries. It has a single field, cypherquery, which represents the generated Cypher query as a string.

We then use the with_structured_output method of the llm object to create a new LLM object (structured_llm_cypher_query) that is specifically configured to return output adhering to the CypherQuery model's structure.

class CypherQuery(BaseModel):
    cypherquery: str = Field(
        description="A syntactically correct Neo4j Cypher query without any preamble. Make sure that it follows the schema."
    )


structured_llm_cypher_query = llm.with_structured_output(CypherQuery)

Defining the Prompts and Creating the Chain for Cypher Query Generation

We start by defining a system_prompt_cypher_query that provides specific instructions to the LLM on how to convert natural language questions into Cypher queries. This includes instructions on using specific Cypher syntax and avoiding potential pitfalls.

system_prompt_cypher_query = """Given an input question, convert it to a Cypher query. No pre-amble.
Additional instructions:
- Ensure that queries checking for non-null properties use `IS NOT NULL` in a straightforward manner.
- Don't use `size((n)--(m))` for counting relationships. Instead use the new `count{{(n)--(m)}}` syntax.
- Incorporate the new existential subqueries in examples where the query needs to check for the existence of a pattern.
  Example: MATCH (p:Person)-[r:IS_FRIENDS_WITH]->(friend:Person)
            WHERE exists{{ (p)-[:WORKS_FOR]->(:Company {{name: 'Neo4j'}})}}
            RETURN p, r, friend"""
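The doubled braces in this prompt are escapes. ChatPromptTemplate.from_template uses f-string-style templates by default, so `{{` and `}}` render as literal braces, while single-brace names such as `{question}` are treated as input variables. Python's own `str.format` follows the same escaping rules, which makes it easy to check what the model will actually see:

```python
# str.format applies the same brace-escaping rules as f-string-style prompt templates
template = "Instead use the new `count{{(n)--(m)}}` syntax. Question: {question}"
rendered = template.format(question="How many movies has each actor appeared in?")
print(rendered)
# Instead use the new `count{(n)--(m)}` syntax. Question: How many movies has each actor appeared in?
```

A check like this quickly exposes unbalanced parentheses or braces in the rendered Cypher examples before they end up in thousands of prompts.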

After that, we create a default_prompt_cypher_query using ChatPromptTemplate, which combines a system message and a human message. The system message uses the system_prompt_cypher_query. As before, the human message provides the graph schema and the question as input to the LLM, prompting it to generate the corresponding Cypher query.

default_prompt_cypher_query = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(system_prompt_cypher_query),
        HumanMessagePromptTemplate.from_template(
            """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
{schema}
Question: {question}
Cypher query:"""
        ),
    ]
)

chain_cypher_query = default_prompt_cypher_query | structured_llm_cypher_query

Finally, we create a LangChain chain by piping the default_prompt_cypher_query to the structured_llm_cypher_query. This chain will be used to invoke the LLM with the prompt and retrieve the generated Cypher query in a structured format.

Testing the Chain for Cypher Query Generation

We provide the graph schema and a placeholder question as input to the invoke method.

chain_cypher_query.invoke({"schema": graph.schema, "question":"Write your question here .. "}).cypherquery

Because the call accesses the cypherquery attribute, the output is the generated Cypher query as a plain string rather than the CypherQuery object itself.
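Even with structured output and a "no preamble" instruction, some models still wrap the query in Markdown code fences or prefix it with a label. A small normalization step, sketched here as a hypothetical `clean_cypher` helper that is not part of the original pipeline, can be applied before a query is stored or executed:

```python
FENCE = "`" * 3  # a Markdown code-fence marker


def clean_cypher(text: str) -> str:
    """Strip surrounding Markdown code fences and a leading 'Cypher query:' label, if present."""
    query = text.strip()
    if len(query) >= 6 and query.startswith(FENCE) and query.endswith(FENCE):
        inner = query[3:-3]
        first_line, _, rest = inner.partition("\n")
        # An opening fence may carry a language tag on its own line, e.g. "cypher"
        if first_line.strip().lower() in ("", "cypher"):
            inner = rest
        query = inner.strip()
    label = "cypher query:"
    if query.lower().startswith(label):
        query = query[len(label):].strip()
    return query


print(clean_cypher("Cypher query: MATCH (m:Movie) RETURN m.title LIMIT 3"))
```

One way to use it is to return `clean_cypher(cypher_query)` from get_cypher_query, so every stored query is already normalized before evaluation.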

Defining Functions for Cypher Query Retrieval with Retry Logic

Here we define two functions, get_cypher_query and safe_get_cypher_query, to retrieve Cypher queries for given questions while incorporating retry logic to handle potential errors during the LLM interaction.

get_cypher_query simply invokes the chain_cypher_query with the schema and question, introducing a small delay to avoid rate limiting issues.

def get_cypher_query(question):
    cypher_query = chain_cypher_query.invoke({"schema": schema, "question": question}).cypherquery
    time.sleep(1)
    return cypher_query

safe_get_cypher_query wraps get_cypher_query with a retry mechanism. It attempts to retrieve the Cypher query multiple times, handling exceptions that might occur during the process.

def safe_get_cypher_query(question):
    retry_count = 3
    for attempt in range(retry_count):
        try:
            result = get_cypher_query(question)
            print(question, ':', result)
            return result
        except Exception as e:
            print(f"Error: {e}. Retrying ({attempt + 1}/{retry_count})...")
            time.sleep(random.uniform(1, 3))
    return None
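The retry loop above waits a random one to three seconds between attempts. A common variant, sketched below as a hypothetical generic helper (`retry_with_backoff` is not part of the original code), doubles the wait after each failure and adds random jitter, which tends to cope better with sustained rate limiting. The `sleep` parameter is injectable so the helper can be exercised without actually waiting.

```python
import random
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    fn: Callable[[], T],
    retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Call fn up to `retries` times, waiting base_delay * 2**attempt plus jitter after each failure."""
    for attempt in range(retries):
        try:
            return fn()
        except Exception as e:
            print(f"Error: {e}. Retrying ({attempt + 1}/{retries})...")
            if attempt < retries - 1:  # no need to wait after the final attempt
                sleep(base_delay * 2 ** attempt + random.uniform(0, base_delay))
    return None
```

It could stand in for the loop in safe_get_cypher_query as `retry_with_backoff(lambda: get_cypher_query(question))`, keeping the same "return None on failure" contract.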

Starting the Cypher Query Generation Process

df = pd.read_csv('text2cypher_questions.csv')  # questions saved in the previous step
df['cypher'] = df['question'].apply(safe_get_cypher_query)

Here we load the questions saved to 'text2cypher_questions.csv' into a DataFrame and apply the safe_get_cypher_query function to its 'question' column. This generates a Cypher query for each question and stores it in a new column named 'cypher'; questions whose generation failed after all retries get None.

Saving the Generated Cypher Queries

Here we save the DataFrame containing the questions and their corresponding Cypher queries to a CSV file named 'raw_text2cypher.csv'.

df.to_csv('raw_text2cypher.csv', index=False)

Evaluating the Generated Cypher Queries by Running Them Against the Database

Testing the Generated Queries

Here we evaluate the generated Cypher queries by executing them against the graph database. We initialize three lists to store information about syntax errors, whether the query returned any results, and whether the query timed out.

syntax_error = []
returns_results = []
timeouts = []

We iterate through the same DataFrame (the one saved to 'raw_text2cypher.csv') and execute each Cypher query using graph.query. If the query executes successfully and returns data, we mark it as returning results. If a ValueError occurs indicating invalid Cypher syntax, we mark it as a syntax error. If the database reports a transaction timeout, we mark it as a timeout.

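for _, row in df.iterrows():
    cypher = row["cypher"]
    # Rows whose generation failed hold None (or NaN after a CSV round trip); skip execution
    if not isinstance(cypher, str) or not cypher.strip():
        returns_results.append(False)
        syntax_error.append(False)
        timeouts.append(False)
        continue
    try:
        data = graph.query(cypher)
        returns_results.append(bool(data))
        syntax_error.append(False)
        timeouts.append(False)
    except ValueError as e:
        # LangChain's Neo4jGraph.query raises Cypher syntax errors as a ValueError with this message
        if "Generated Cypher Statement is not valid" in str(e):
            syntax_error.append(True)
            print(f"Syntax error in Cypher query: {e}")
        else:
            syntax_error.append(False)
            print(f"Other ValueError: {e}")
        returns_results.append(False)
        timeouts.append(False)
    except Exception as e:
        # Not every exception carries a Neo4j status code, so read it defensively;
        # every branch appends to all three lists so they stay aligned with df
        is_timeout = (
            getattr(e, "code", None)
            == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
        )
        if not is_timeout:
            print(f"Other error: {e}")
        returns_results.append(False)
        syntax_error.append(False)
        timeouts.append(is_timeout)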

Here we add the evaluation results (syntax error, timeout, returns results) as new columns to the DataFrame. We then create a new DataFrame final_df containing the question, Cypher query, type, and evaluation results. This DataFrame is saved to a CSV file named 'detailed_text2cypher.csv'.

df["syntax_error"] = syntax_error
df["timeout"] = timeouts
df["returns_results"] = returns_results

final_df = df[["question","cypher","type","syntax_error","timeout","returns_results",]]
final_df.to_csv('detailed_text2cypher.csv', index = False)
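Before cleaning, it can help to see how each query category fared. The sketch below assumes the rows are available as dictionaries with the same columns as final_df (for example via `final_df.to_dict("records")`); `summarize_by_type` is a hypothetical helper that reports, per type, the number of rows and the share of queries that ran without errors and returned results. The sample rows are illustrative.

```python
from collections import defaultdict


def summarize_by_type(rows: list[dict]) -> dict[str, dict]:
    """Per query type: row count and fraction of queries that are valid and return results."""
    totals = defaultdict(int)
    good = defaultdict(int)
    for row in rows:
        totals[row["type"]] += 1
        if not row["syntax_error"] and not row["timeout"] and row["returns_results"]:
            good[row["type"]] += 1
    return {t: {"count": totals[t], "valid_share": good[t] / totals[t]} for t in totals}


rows = [
    {"type": "Pathfinding Queries", "syntax_error": False, "timeout": True, "returns_results": False},
    {"type": "Pathfinding Queries", "syntax_error": False, "timeout": False, "returns_results": True},
    {"type": "Simple Retrieval Queries", "syntax_error": False, "timeout": False, "returns_results": True},
]
print(summarize_by_type(rows))
```

A category with a noticeably lower valid share is a hint to revisit its description in query_types or to give the Cypher prompt more guidance for that kind of query.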

Additional Data Cleaning

Depending on the generated queries and the evaluation results, additional data cleaning may be necessary. This could involve filtering out queries with syntax errors or timeouts, or refining the generated Cypher queries to improve their accuracy or efficiency. The quality of the base language model used for generation also significantly affects the quality of the resulting dataset. In my experience, a powerful model like Llama 3.1 405B yielded excellent results and minimized the need for extensive data cleaning.
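As a starting point for that cleaning, the sketch below keeps only rows whose query was generated, ran without a syntax error or timeout, and returned at least one row, then drops repeated questions. It works on plain dictionaries with the final_df columns; the function name and the choice to require non-empty results are assumptions, since a correct query can legitimately return nothing.

```python
def filter_valid_pairs(rows: list[dict]) -> list[dict]:
    """Keep question/Cypher pairs that executed cleanly and returned data; drop repeated questions."""
    seen_questions = set()
    kept = []
    for row in rows:
        cypher = row.get("cypher")
        if not isinstance(cypher, str) or not cypher.strip():
            continue  # generation failed (None, or NaN after the CSV round trip)
        if row["syntax_error"] or row["timeout"] or not row["returns_results"]:
            continue
        if row["question"] in seen_questions:
            continue
        seen_questions.add(row["question"])
        kept.append({"question": row["question"], "cypher": cypher.strip(), "type": row["type"]})
    return kept
```

With pandas, this could be applied as `pd.DataFrame.from_records(filter_valid_pairs(final_df.to_dict("records")))` and saved alongside the detailed file.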


With a robust dataset of question-Cypher pairs in hand, we're now equipped to embark on the next phase: fine-tuning a large language model specifically for the task of translating natural language into Cypher queries. The quality of the base language model used for fine-tuning plays a crucial role in the final performance. In the second part of this series, we'll dive into the technical intricacies of this fine-tuning process, exploring techniques and strategies for optimizing a model's performance, particularly when using a powerful model like Llama 3.1 405B, on our curated dataset. Join us as we unlock the potential of LLMs to seamlessly bridge the gap between human language and graph database interaction.