LazyGraphRAG in LangChain¶
Introduction¶
In LazyGraphRAG, Microsoft demonstrates significant cost and performance benefits from delaying the construction of a knowledge graph. This is largely because not all documents need to be analyzed. It also helps that, by the time documents are analyzed, the question is already known, allowing irrelevant information to be ignored.
We've noticed similar cost benefits from building a document graph that links content based on simple properties, such as extracted keywords, rather than building a complete knowledge graph. For the Wikipedia dataset used in this notebook, we estimated it would have cost $70k to build a knowledge graph using the example from LangChain, while the document graph was basically free.
In this notebook we demonstrate how to populate a document graph with Wikipedia articles linked based on mentions in the articles and extracted named entities. Entity extraction uses a local spaCy model, making it fast and cost-effective to construct these graphs. We'll then show how to build a chain that performs the steps of Lazy GraphRAG -- retrieving articles, grouping them into communities, extracting claims from each community, ranking and selecting the top claims, and generating an answer based on those claims.
Environment Setup¶
The following block will configure the environment from the Colab Secrets. To run it, you should have the following Colab Secrets defined and accessible to this notebook:
- `OPENAI_API_KEY`: The OpenAI API key.
- `ASTRA_DB_API_ENDPOINT`: The Astra DB API endpoint.
- `ASTRA_DB_APPLICATION_TOKEN`: The Astra DB application token.
- `LANGCHAIN_API_KEY`: Optional. If defined, LangSmith tracing will be enabled.
- `ASTRA_DB_KEYSPACE`: Optional. If defined, specifies the Astra DB keyspace. If not defined, the default keyspace is used.
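If you're not running in Colab, one option is to export the same values as environment variables yourself before running the initialization cell below. This is a minimal sketch; the values shown are placeholders, not real credentials.
# Not using Colab? Set the same values as environment variables directly.
# The values below are placeholders -- substitute your own credentials.
import os
os.environ["OPENAI_API_KEY"] = "sk-..."
os.environ["ASTRA_DB_API_ENDPOINT"] = "https://<db-id>-<region>.apps.astra.datastax.com"
os.environ["ASTRA_DB_APPLICATION_TOKEN"] = "AstraCS:..."
# Optional:
# os.environ["LANGCHAIN_API_KEY"] = "lsv2_..."     # enables LangSmith tracing
# os.environ["ASTRA_DB_KEYSPACE"] = "my_keyspace"  # non-default keyspace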
# Install modules.
#
# On Apple hardware, "spacy[apple]" will improve performance.
%pip install \
langchain-core \
langchain-astradb \
langchain-openai \
langchain-graph-retriever \
spacy \
graph-rag-example-helpers
The last package -- `graph-rag-example-helpers` -- includes helpers for setting up environment variables and for resuming the Wikipedia data load if it fails.
# Downloads the model used by Spacy for extracting entities.
!python -m spacy download en_core_web_sm
# Configure import paths.
import os
import sys
sys.path.append("../../")
# Initialize environment variables.
from graph_rag_example_helpers.env import Environment, initialize_environment
initialize_environment(Environment.ASTRAPY)
os.environ["LANGCHAIN_PROJECT"] = "lazy-graph-rag"
# The full dataset is ~6m documents, and takes hours to load.
# The short dataset is 1000 documents and loads quickly.
# Change this to `False` to use the full dataset.
USE_SHORT_DATASET = True
Part 1: Loading Data¶
First, we'll demonstrate how to load Wikipedia data into an AstraDBVectorStore, using the mentioned articles and extracted entities as metadata fields.
In this section, we're not actually doing anything special for the graph -- we're just populating the metadata with fields that usefully describe our content.
Create Documents from Wikipedia Articles¶
The first thing we need to do is create the LangChain `Document`s we'll import.
To do this, we write some code to convert each line of a JSON file downloaded from 2wikimultihop into a `Document`.
We populate the `id` and `metadata["mentions"]` from information in this file.
Then, we run those documents through the `SpacyNERTransformer` to populate `metadata["entities"]` with entities named in the article.
import json
from collections.abc import Iterator
from langchain_core.documents import Document
from langchain_graph_retriever.transformers.spacy import (
SpacyNERTransformer,
)
def parse_document(line: bytes) -> Document:
"""Reads one JSON line from the wikimultihop dump."""
para = json.loads(line)
id = para["id"]
title = para["title"]
# Use structured information (mentioned Wikipedia IDs) as metadata.
mentioned_ids = [ref_id for m in para["mentions"] for ref_id in m["ref_ids"] or []]
return Document(
id=id,
page_content=" ".join(para["sentences"]),
metadata={
"mentions": mentioned_ids,
"title": title,
},
)
NER_TRANSFORMER = SpacyNERTransformer(
limit=1000,
exclude_labels={"CARDINAL", "MONEY", "QUANTITY", "TIME", "PERCENT", "ORDINAL"},
)
# Load data in batches, using spaCy NER to extract entities.
def prepare_batch(lines: Iterator[str]) -> Iterator[Document]:
# Parse documents from the batch of lines.
docs = [parse_document(line) for line in lines]
docs = NER_TRANSFORMER.transform_documents(docs)
return docs
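To make the resulting `Document` shape concrete, here is a quick sanity check of `parse_document` on a hand-written line in the same shape as the dump. The ID, title, sentences, and `ref_ids` below are made up purely for illustration.
# Sanity check: parse a made-up line shaped like the 2wikimultihop dump.
# The id, title, sentences, and ref_ids are illustrative only.
sample_line = json.dumps({
    "id": "example-123",
    "title": "Example Article",
    "sentences": ["Example Article is about sloops.", "It links to another article."],
    "mentions": [{"ref_ids": ["example-456"]}],
})
sample_doc = parse_document(sample_line)
print(sample_doc.id)                    # example-123
print(sample_doc.metadata["mentions"])  # ['example-456']
print(sample_doc.metadata["title"])     # Example Article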
Create the AstraDBVectorStore¶
Next, we create the vector store we're going to load these documents into. In our case, we use DataStax Astra DB with OpenAI embeddings.
from langchain_astradb import AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings
COLLECTION = "lazy_graph_rag_short" if USE_SHORT_DATASET else "lazy_graph_rag"
store = AstraDBVectorStore(
embedding=OpenAIEmbeddings(),
collection_name=COLLECTION,
pre_delete_collection=USE_SHORT_DATASET,
)
Loading Data into the Store¶
Next, we perform the actual loading. This takes a while, so we use a helper utility to persist which batches have been written so we can resume if there are any failures.
On OS X, it is useful to run `caffeinate -dis` in a shell to prevent the machine from going to sleep; this also seems to reduce errors.
import os
import os.path
from graph_rag_example_helpers.datasets.wikimultihop import aload_2wikimultihop
# Path to the file `para_with_hyperlink.zip`.
# See instructions here to download from
# [2wikimultihop](https://github.com/Alab-NII/2wikimultihop?tab=readme-ov-file#new-update-april-7-2021).
PARA_WITH_HYPERLINK_ZIP = os.path.join(os.getcwd(), "para_with_hyperlink.zip")
await aload_2wikimultihop(
limit=100 if USE_SHORT_DATASET else None,
full_para_with_hyperlink_zip_path=PARA_WITH_HYPERLINK_ZIP,
store=store,
batch_prepare=prepare_batch,
)
At this point, we've created a `VectorStore` containing the Wikipedia articles.
Each article is associated with metadata identifying the other articles it mentions and the entities it names.
As is, this is useful for performing a vector search filtered to articles mentioning a specific term, or for performing an entity search on the documents (a filtered search is sketched below).
The library `langchain-graph-retriever` makes this even more useful by allowing articles to be traversed based on relationships, such as articles mentioned by (or mentioning) the current article, or articles providing more information on the entities it names.
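For example, a metadata-filtered similarity search against the store we just populated might look like the following minimal sketch. The filter is a simple metadata-equality filter; the article ID is just an example value to substitute with one from your own data.
# A sketch of metadata-filtered vector search over the loaded articles.
# With the Data API's Mongo-style filters, equality on a list-valued metadata
# field matches documents whose list contains the value. The ID below is only
# an example.
hits = store.similarity_search(
    "sailing ships",
    k=4,
    filter={"mentions": "17186373"},
)
for doc in hits:
    print(doc.id, "-", doc.metadata["title"])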
In the next section we'll see not just how to use the relationships in the metadata to retrieve more articles; we'll go a step further and perform Lazy GraphRAG to extract relevant claims from both the similar and related articles, then use the most relevant claims to answer the question.
Part 2: Lazy Graph RAG via Hierarchical Summarization¶
As we've noted before, eagerly building a knowledge graph is prohibitively expensive. Microsoft seems to agree, and recently introduced LazyGraphRAG, which enables GraphRAG to be performed lazily -- after the query is received.
We implement the LazyGraphRAG technique using the traversing retrievers as follows:
- Retrieve a good number of nodes using a traversing retrieval.
- Identify communities in the retrieved sub-graph.
- Extract claims from each community relevant to the query using an LLM.
- Rank each claim based on its relevance to the question and select the top claims.
- Generate an answer to the question based on the extracted claims.
LangChain for Extracting Claims¶
The first thing we do is create a chain that produces the claims. Given an input containing the question and the retrieved communities, it applies an LLM in parallel to extract claims from each community.
A claim is just a string representing the statement plus the `source_id` of the document it came from. We request structured output so we get back a list of claims.
from collections.abc import Iterable
from operator import itemgetter
from typing import TypedDict
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, chain
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
class Claim(BaseModel):
"""Representation of an individual claim from a source document(s)."""
claim: str = Field(description="The claim from the original document(s).")
source_id: str = Field(description="Document ID containing the claim.")
class Claims(BaseModel):
"""Claims extracted from a set of source document(s)."""
claims: list[Claim] = Field(description="The extracted claims.")
MODEL = ChatOpenAI(model="gpt-4o", temperature=0)
CLAIMS_MODEL = MODEL.with_structured_output(Claims)
CLAIMS_PROMPT = ChatPromptTemplate.from_template("""
Extract claims from the following related documents.
Only return claims appearing within the specified documents.
If no documents are provided, do not make up claims or documents.
Claims (and scores) should be relevant to the question.
Don't include claims from the documents if they are not directly or indirectly
relevant to the question.
If none of the documents make any claims relevant to the question, return an
empty list of claims.
If multiple documents make similar claims, include the original text of each as
separate claims. Score the most useful and authoritative claim higher than
similar, lower-quality claims.
Question: {question}
{formatted_documents}
""")
# TODO: Few-shot examples? Possibly with a selector?
def format_documents_with_ids(documents: Iterable[Document]) -> str:
formatted_docs = "\n\n".join(
f"Document ID: {doc.id}\nContent: {doc.page_content}" for doc in documents
)
return formatted_docs
CLAIM_CHAIN = (
RunnableParallel(
{
"question": itemgetter("question"),
"formatted_documents": itemgetter("documents")
| RunnableLambda(format_documents_with_ids),
}
)
| CLAIMS_PROMPT
| CLAIMS_MODEL
)
class ClaimsChainInput(TypedDict):
question: str
communities: Iterable[Iterable[Document]]
@chain
async def claims_chain(input: ClaimsChainInput) -> Iterable[Claim]:
question = input["question"]
communities = input["communities"]
# TODO: Use openai directly so this can use the batch API for performance/cost?
community_claims = await CLAIM_CHAIN.abatch(
[{"question": question, "documents": community} for community in communities]
)
return [claim for community in community_claims for claim in community.claims]
LangChain for Ranking Claims¶
The next chain ranks the claims so we can select those most relevant to the question.
This is based on ideas from RankRAG.
Specifically, the prompt is constructed so that the next token should be `True` if the claim is relevant and `False` if not.
The probability of that token is used as the relevance score -- `True` with a higher probability is more relevant than `True` with a lower probability.
import math
from langchain_core.runnables import chain
RANK_PROMPT = ChatPromptTemplate.from_template("""
Rank the relevance of the following claim to the question.
Output "True" if the claim is relevant and "False" if it is not.
Only output True or False.
Question: Where is Seattle?
Claim: Seattle is in Washington State.
Relevant: True
Question: Where is LA?
Claim: New York City is in New York State.
Relevant: False
Question: {question}
Claim: {claim}
Relevant:
""")
def compute_rank(msg):
logprob = msg.response_metadata["logprobs"]["content"][0]
prob = math.exp(logprob["logprob"])
token = logprob["token"]
if token == "True":
return prob
elif token == "False":
return 1.0 - prob
else:
raise ValueError(f"Unexpected logprob: {logprob}")
RANK_CHAIN = RANK_PROMPT | MODEL.bind(logprobs=True) | RunnableLambda(compute_rank)
class RankChainInput(TypedDict):
question: str
claims: Iterable[Claim]
@chain
async def rank_chain(input: RankChainInput) -> Iterable[Claim]:
# TODO: Use openai directly so this can use the batch API for performance/cost?
claims = input["claims"]
ranks = await RANK_CHAIN.abatch(
[{"question": input["question"], "claim": claim} for claim in claims]
)
rank_claims = sorted(
    zip(ranks, claims, strict=True),
    key=lambda rank_claim: rank_claim[0],
    reverse=True,  # Highest-ranked (most relevant) claims first.
)
return [claim for _, claim in rank_claims]
We could extend this by using an MMR-like strategy for selecting claims. Specifically, we could combine the relevance of the claim to the question and the diversity compared to already selected claims to select the best variety of claims.
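Here is a minimal sketch of what that selection could look like. It is not part of the chain above; the helper name, the `lambda_mult` parameter, and the use of `numpy` with `OpenAIEmbeddings` are assumptions made for illustration. It expects the relevance score already produced for each claim (e.g. the value returned by `compute_rank`).
import numpy as np
def select_claims_mmr(
    claims: list[Claim],
    relevance: list[float],
    *,
    limit: int = 20,
    lambda_mult: float = 0.7,
) -> list[Claim]:
    """Greedily select claims, trading relevance against redundancy (MMR-style)."""
    if not claims:
        return []
    # Embed the claim text so we can measure similarity between claims.
    vectors = np.array(OpenAIEmbeddings().embed_documents([c.claim for c in claims]))
    vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
    selected: list[int] = []
    while len(selected) < min(limit, len(claims)):
        best_index, best_score = None, float("-inf")
        for i in range(len(claims)):
            if i in selected:
                continue
            # Redundancy is the highest cosine similarity to any selected claim.
            redundancy = max(
                (float(vectors[i] @ vectors[j]) for j in selected), default=0.0
            )
            score = lambda_mult * relevance[i] - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_index, best_score = i, score
        selected.append(best_index)
    return [claims[i] for i in selected]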
LazyGraphRAG in LangChain¶
Finally, we produce a chain that puts everything together.
Given a `GraphRetriever`, it retrieves documents, creates communities using edges amongst the retrieved documents, extracts claims from those communities, ranks and selects the best claims, and then answers the question using those claims.
from typing import Any
from graph_retriever.edges import EdgeSpec, MetadataEdgeFunction
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import chain
from langchain_graph_retriever import GraphRetriever
from langchain_graph_retriever.document_graph import create_graph, group_by_community
@chain
async def lazy_graph_rag(
question: str,
*,
retriever: GraphRetriever,
model: BaseLanguageModel,
edges: Iterable[EdgeSpec] | MetadataEdgeFunction | None = None,
max_tokens: int = 1000,
**kwargs: Any,
) -> str:
"""Retrieve claims relating to the question using LazyGraphRAG.
Returns the top claims up to the given `max_tokens` as a markdown list.
"""
edges = edges or retriever.edges
if edges is None:
raise ValueError("Must specify 'edges' in invocation or retriever")
# 1. Retrieve documents using the (traversing) retriever.
documents = await retriever.ainvoke(question, edges=edges, **kwargs)
# 2. Create a graph and extract communities.
document_graph = create_graph(documents, edges=edges)
communities = group_by_community(document_graph)
# 3. Extract claims from the communities.
claims = await claims_chain.ainvoke(
{"question": question, "communities": communities}
)
# 4. Rank the claims and select claims up to the given token limit.
result_claims = []
tokens = 0
for claim in await rank_chain.ainvoke({"question": question, "claims": claims}):
claim_str = f"- {claim.claim} (Source: {claim.source_id})"
tokens += model.get_num_tokens(claim_str)
if tokens > max_tokens:
break
result_claims.append(claim_str)
return "\n".join(result_claims)
Using Lazy GraphRAG in LangChain¶
Finally, we use the Lazy GraphRAG chain we created on the store we populated earlier.
from graph_retriever.edges import Id
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_graph_retriever import GraphRetriever
RETRIEVER = GraphRetriever(
store=store,
edges=[("mentions", Id()), ("entities", "entities")],
k=100,
start_k=30,
adjacent_k=20,
max_depth=3,
)
ANSWER_PROMPT = PromptTemplate.from_template("""
Answer the question based on the supporting claims.
Only use information from the claims. Do not guess or make up any information.
Where possible, reference and quote the supporting claims.
Question: {question}
Claims:
{claims}
""")
LAZY_GRAPH_RAG_CHAIN = (
{
"question": RunnablePassthrough(),
"claims": RunnablePassthrough()
| lazy_graph_rag.bind(
retriever=RETRIEVER,
model=MODEL,
max_tokens=1000,
),
}
| ANSWER_PROMPT
| MODEL
)
QUESTION = "Why are Bermudan sloop ships widely prized compared to other ships?"
result = await LAZY_GRAPH_RAG_CHAIN.ainvoke(QUESTION)
result.content
'Bermudan sloop ships are widely prized for several reasons. Firstly, they feature the Bermuda rig, which is popular because it is easier to sail with a smaller crew or even single-handed, is cheaper due to having less hardware, and performs well when sailing into the wind (Source: 48520). Additionally, Bermuda sloops were constructed using Bermuda cedar, a material valued for its durability and resistance to rot, contributing to the ships' longevity and performance (Source: 17186373). These factors combined make Bermudan sloops highly valued compared to other ships.'
For comparison, below are the results for the same question using a basic RAG pattern with vector similarity only.
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
VECTOR_ANSWER_PROMPT = PromptTemplate.from_template("""
Answer the question based on the provided documents.
Only use information from the documents. Do not guess or make up any information.
Question: {question}
Documents:
{documents}
""")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
VECTOR_CHAIN = (
{
"question": RunnablePassthrough(),
"documents": (store.as_retriever() | format_docs),
}
| VECTOR_ANSWER_PROMPT
| MODEL
)
result = VECTOR_CHAIN.invoke(QUESTION)
result.content
'The documents do not provide specific reasons why Bermudan sloop ships are widely prized compared to other ships. They describe the development and characteristics of the Bermuda sloop, such as its fore-and-aft rigged single-masted design and the use of the Bermuda rig with triangular sails, but do not explicitly state why these ships are particularly valued over others.'
The LazyGraphRAG chain shines when answering a question requires considering a large amount of relevant information in order to produce a thorough answer.
Conclusion¶
This post demonstrated how easy it is to implement Lazy GraphRAG on top of a document graph.
It used `langchain-graph-retriever` from the graph-rag project to implement the document graph and graph-based retrieval on top of an existing LangChain `VectorStore`.
This means you can focus on populating and using your `VectorStore` with useful metadata, and add graph-based retrieval and even Lazy GraphRAG when you need it.
Any LangChain `VectorStore` can be used with Lazy GraphRAG without needing to change or re-ingest the stored documents.
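As a sketch of that flexibility, the same retriever and chain can sit on top of a different store, assuming `langchain-graph-retriever` can adapt the store you choose. The example below uses LangChain's `InMemoryVectorStore` purely for illustration; `documents_with_metadata` is a placeholder for documents carrying the same "mentions" and "entities" metadata produced in Part 1.
# A sketch: the same Lazy GraphRAG chain over a different VectorStore.
# `documents_with_metadata` is a placeholder for Documents with the same
# metadata produced in Part 1.
from langchain_core.vectorstores import InMemoryVectorStore
other_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
other_store.add_documents(documents_with_metadata)
OTHER_RETRIEVER = GraphRetriever(
    store=other_store,
    edges=[("mentions", Id()), ("entities", "entities")],
)
other_lazy_graph_rag = lazy_graph_rag.bind(
    retriever=OTHER_RETRIEVER,
    model=MODEL,
    max_tokens=1000,
)
claims_markdown = await other_lazy_graph_rag.ainvoke(QUESTION)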
Knowledge Graphs and GraphRAG shouldn't be hard or scary.
Start simple and easily overlay edges when you need them.
Graph retrievers and Lazy GraphRAG work well with agents.
You can allow the agent to retrieve differently depending on the question -- doing a vector-only search for simple questions, traversing to mentioned articles for deeper questions, or traversing to citing articles to see if newer information is available.
We'll show how to combine these techniques with agents in a future post.
Until then, give `langchain-graph-retriever` a try and let us know how it goes!