
graph_rag_example_helpers

datasets

animals

fetch_documents

fetch_documents() -> list[Document]

Download and parse a list of Documents for use with Graph Retriever.

This is a small example dataset with useful links.

This method downloads the dataset each time -- generally it is preferable to invoke this only once and store the documents in memory or a vector store.

RETURNS DESCRIPTION
list[Document]

The fetched animal documents.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/datasets/animals/fetch.py
def fetch_documents() -> list[Document]:
    """
    Download and parse a list of Documents for use with Graph Retriever.

    This is a small example dataset with useful links.

    This method downloads the dataset each time -- generally it is preferable
    to invoke this only once and store the documents in memory or a vector
    store.

    Returns
    -------
    :
        The fetched animal documents.
    """
    response = requests.get(ANIMALS_JSONL_URL)
    response.raise_for_status()  # Ensure we got a valid response

    return [
        Document(id=data["id"], page_content=data["text"], metadata=data["metadata"])
        for line in response.text.splitlines()
        if (data := json.loads(line))
    ]
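
A minimal usage sketch (an assumption-based example, not part of the package): it uses langchain-core's InMemoryVectorStore and an OpenAI embedding model as illustrative stand-ins; any VectorStore and embedding model would do.

from graph_rag_example_helpers.datasets.animals import fetch_documents
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings

# Fetch once and keep the result; the helper re-downloads on every call.
documents = fetch_documents()

# Store the documents so they can be retrieved later without refetching.
store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
store.add_documents(documents)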


wikimultihop

BatchPreparer module-attribute

BatchPreparer = Callable[
    [Iterator[bytes]], Iterator[Document]
]

Function to apply to batches of lines to produce the document.
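
A minimal sketch of a BatchPreparer. It assumes each input line is a JSON record with "id" and "text" fields; those field names are illustrative only and are not the actual 2wikimultihop schema.

import json
from collections.abc import Iterator

from langchain_core.documents import Document


def simple_batch_prepare(lines: Iterator[bytes]) -> Iterator[Document]:
    # Parse each JSONL line into a Document. The "id" and "text" keys are
    # assumptions for illustration only.
    for line in lines:
        data = json.loads(line)
        yield Document(id=data["id"], page_content=data["text"])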

aload_2wikimultihop async

aload_2wikimultihop(
    limit: int | None,
    *,
    full_para_with_hyperlink_zip_path: str,
    store: VectorStore,
    batch_prepare: BatchPreparer,
) -> None

Load 2wikimultihop data into the given VectorStore.

PARAMETER DESCRIPTION
limit

Maximum number of lines to load. If 1,000 or fewer, only that many lines are loaded (from the short dataset). If None, loads all content.

TYPE: int | None

full_para_with_hyperlink_zip_path

Path to para_with_hyperlink.zip downloaded following the instructions in 2wikimultihop.

TYPE: str

store

The VectorStore to populate.

TYPE: VectorStore

batch_prepare

Function to apply to batches of lines to produce the document.

TYPE: BatchPreparer

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/datasets/wikimultihop/load.py
async def aload_2wikimultihop(
    limit: int | None,
    *,
    full_para_with_hyperlink_zip_path: str,
    store: VectorStore,
    batch_prepare: BatchPreparer,
) -> None:
    """
    Load 2wikimultihop data into the given `VectorStore`.

    Parameters
    ----------
    limit :
        Maximum number of lines to load.
        If 1,000 or fewer, only that many lines are loaded (from the short dataset).
        If `None`, loads all content.
    full_para_with_hyperlink_zip_path :
        Path to `para_with_hyperlink.zip` downloaded following the instructions
        in
        [2wikimultihop](https://github.com/Alab-NII/2wikimultihop?tab=readme-ov-file#new-update-april-7-2021).
    store :
        The VectorStore to populate.
    batch_prepare :
        Function to apply to batches of lines to produce the document.
    """
    if limit is None or limit > LINES_IN_FILE:
        limit = LINES_IN_FILE

    if limit <= 1000:
        local_path = "../../data/para_with_hyperlink_short.jsonl"
        if os.path.isfile(local_path):
            for batch in batched(
                itertools.islice(open(local_path, "rb").readlines(), limit), BATCH_SIZE
            ):
                docs = batch_prepare(iter(batch))
                store.add_documents(list(docs))
            print(f"Loaded from {local_path}")  # noqa: T201
        else:
            print(f"{local_path} not found, fetching short dataset")  # noqa: T201
            response = requests.get(SHORT_URL)
            response.raise_for_status()  # Ensure we get a valid response

            for batch in batched(
                itertools.islice(response.content.splitlines(), limit), BATCH_SIZE
            ):
                docs = batch_prepare(iter(batch))
                store.add_documents(list(docs))
            print(f"Loaded from {SHORT_URL}")  # noqa: T201
        return

    assert os.path.isfile(full_para_with_hyperlink_zip_path)
    persistence = PersistentIteration(
        journal_name="load_2wikimultihop.jrnl",
        iterator=batched(
            itertools.islice(wikipedia_lines(full_para_with_hyperlink_zip_path), limit),
            BATCH_SIZE,
        ),
    )
    total_batches = ceil(limit / BATCH_SIZE) - persistence.completed_count()
    if persistence.completed_count() > 0:
        print(  # noqa: T201
            f"Resuming loading with {persistence.completed_count()}"
            f" completed, {total_batches} remaining"
        )

    @backoff.on_exception(
        backoff.expo,
        EXCEPTIONS_TO_RETRY,
        max_tries=MAX_RETRIES,
    )
    async def add_docs(batch_docs, offset) -> None:
        from astrapy.exceptions import InsertManyException

        try:
            await store.aadd_documents(batch_docs)
            persistence.ack(offset)
        except InsertManyException as err:
            for err_desc in err.error_descriptors:
                if err_desc.error_code != "DOCUMENT_ALREADY_EXISTS":
                    print(err_desc)  # noqa: T201
            raise

    # We can't use asyncio.TaskGroup in 3.10. This would be simpler with that.
    tasks: list[asyncio.Task] = []

    for offset, batch_lines in tqdm(persistence, total=total_batches):
        batch_docs = batch_prepare(batch_lines)
        if batch_docs:
            task = asyncio.create_task(add_docs(batch_docs, offset))

            # It is OK if tasks are lost upon failure since that means we're
            # aborting the loading.
            tasks.append(task)

            while len(tasks) >= MAX_IN_FLIGHT:
                completed, pending = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for complete in completed:
                    if (e := complete.exception()) is not None:
                        print(f"Exception in task: {e}")  # noqa: T201
                tasks = list(pending)
        else:
            persistence.ack(offset)

    # Make sure all the tasks are done.
    # This wouldn't be necessary if we used a taskgroup, but that is Python 3.11+.
    while len(tasks) > 0:
        completed, pending = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED
        )
        for complete in completed:
            if (e := complete.exception()) is not None:
                print(f"Exception in task: {e}")  # noqa: T201
        tasks = list(pending)

    assert len(tasks) == 0
    assert persistence.pending_count() == 0
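
A usage sketch (import path follows this page's module hierarchy): store and prepare are placeholders for your VectorStore and a BatchPreparer such as the one sketched earlier, and the zip path is wherever you downloaded the archive.

import asyncio

from graph_rag_example_helpers.datasets.wikimultihop import aload_2wikimultihop


async def load(store, prepare):
    # Load only the first 1,000 lines (the short dataset) for a quick trial.
    await aload_2wikimultihop(
        1000,
        full_para_with_hyperlink_zip_path="para_with_hyperlink.zip",
        store=store,
        batch_prepare=prepare,
    )


# asyncio.run(load(store, prepare))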

load

wikipedia_lines
wikipedia_lines(
    para_with_hyperlink_zip_path: str,
) -> Iterable[bytes]

Return iterable of lines from the wikipedia file.

PARAMETER DESCRIPTION
para_with_hyperlink_zip_path

Path to para_with_hyperlink.zip downloaded following the instructions in 2wikimultihop.

TYPE: str

YIELDS DESCRIPTION
bytes

Lines from the Wikipedia file.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/datasets/wikimultihop/load.py
def wikipedia_lines(para_with_hyperlink_zip_path: str) -> Iterable[bytes]:
    """
    Return iterable of lines from the wikipedia file.

    Parameters
    ----------
    para_with_hyperlink_zip_path :
        Path to `para_with_hyperlink.zip` downloaded following the instructions
        in
        [2wikimultihop](https://github.com/Alab-NII/2wikimultihop?tab=readme-ov-file#new-update-april-7-2021).

    Yields
    ------
    bytes
        Lines from the Wikipedia file.
    """
    with zipfile.ZipFile(para_with_hyperlink_zip_path, "r") as archive:
        with archive.open("para_with_hyperlink.jsonl", "r") as para_with_hyperlink:
            yield from para_with_hyperlink
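
A small sketch that peeks at a few lines without reading the whole archive; the zip path is an assumption, so point it at your downloaded copy.

import itertools

from graph_rag_example_helpers.datasets.wikimultihop.load import wikipedia_lines

# Print the start of the first three JSONL records (raw bytes).
for line in itertools.islice(wikipedia_lines("para_with_hyperlink.zip"), 3):
    print(line[:80])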

env

NON_SECRETS module-attribute

NON_SECRETS = {
    "ASTRA_DB_API_ENDPOINT",
    "ASTRA_DB_DATABASE_ID",
}

Environment variables that can use input instead of getpass.

Environment

Bases: Enum

Enumeration of supported environments for examples.

ASTRAPY class-attribute instance-attribute

ASTRAPY = auto()

Environment variables for connecting to AstraDB via AstraPy

CASSIO class-attribute instance-attribute

CASSIO = auto()

Environment variables for connecting to AstraDB via CassIO

required_envvars

required_envvars() -> list[str]

Return the required environment variables for this environment.

RETURNS DESCRIPTION
list[str]

The environment variables required in this environment.

RAISES DESCRIPTION
ValueError

If the environment isn't recognized.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/env.py
def required_envvars(self) -> list[str]:
    """
    Return the required environment variables for this environment.

    Returns
    -------
    :
        The environment variables required in this environment.

    Raises
    ------
    ValueError
        If the environment isn't recognized.
    """
    required = ["OPENAI_API_KEY", "ASTRA_DB_APPLICATION_TOKEN"]
    if self == Environment.CASSIO:
        required.append("ASTRA_DB_DATABASE_ID")
    elif self == Environment.ASTRAPY:
        required.append("ASTRA_DB_API_ENDPOINT")
    else:
        raise ValueError(f"Unrecognized environment '{self}'")
    return required
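
For example, following the source above:

from graph_rag_example_helpers.env import Environment

# Both environments share OPENAI_API_KEY and ASTRA_DB_APPLICATION_TOKEN;
# each adds its own connection variable.
print(Environment.CASSIO.required_envvars())
# ['OPENAI_API_KEY', 'ASTRA_DB_APPLICATION_TOKEN', 'ASTRA_DB_DATABASE_ID']
print(Environment.ASTRAPY.required_envvars())
# ['OPENAI_API_KEY', 'ASTRA_DB_APPLICATION_TOKEN', 'ASTRA_DB_API_ENDPOINT']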

initialize_environment

initialize_environment(env: Environment = CASSIO)

Initialize the environment variables.

PARAMETER DESCRIPTION
env

The environment to initialize

TYPE: Environment DEFAULT: CASSIO

Notes
This uses the following:

1. If a `.env` file is found, load environment variables from that.
2. If not, and running in colab, set necessary environment variables from
    secrets.
3. If necessary variables aren't set by the above, then prompts the user.
Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/env.py
def initialize_environment(env: Environment = Environment.CASSIO):
    """
    Initialize the environment variables.

    Parameters
    ----------
    env :
        The environment to initialize

    Notes
    -----
        This uses the following:

        1. If a `.env` file is found, load environment variables from that.
        2. If not, and running in colab, set necessary environment variables from
            secrets.
        3. If necessary variables aren't set by the above, then prompts the user.
    """
    # 1. If a `.env` file is found, load environment variables from that.
    if dotenv_path := find_dotenv():
        load_dotenv(dotenv_path)
        verify_environment(env)
        return

    # 2. If not, and running in colab, set necessary environment variables from secrets.
    try:
        initialize_from_colab_userdata(env)
        verify_environment(env)
        return
    except (ImportError, ModuleNotFoundError):
        pass

    # 3. Initialize from prompts.
    initialize_from_prompts(env)
    verify_environment(env)
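
A typical call at the top of an example notebook (a sketch; choose the environment matching how you connect to Astra):

from graph_rag_example_helpers.env import Environment, initialize_environment

# Loads variables from a .env file, Colab secrets, or interactive prompts,
# in that order, then verifies the required variables are set.
initialize_environment(Environment.ASTRAPY)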

initialize_from_colab_userdata

initialize_from_colab_userdata(env: Environment = CASSIO)

Try to initialize environment from colab userdata.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/env.py
def initialize_from_colab_userdata(env: Environment = Environment.CASSIO):
    """Try to initialize environment from colab `userdata`."""
    from google.colab import userdata  # type: ignore[import-untyped]

    for required in env.required_envvars():
        os.environ[required] = userdata.get(required)

    try:
        os.environ["ASTRA_DB_KEYSPACE"] = userdata.get("ASTRA_DB_KEYSPACE")
    except userdata.SecretNotFoundError as _:
        # User doesn't have a keyspace set, so use the default.
        os.environ.pop("ASTRA_DB_KEYSPACE", None)

    try:
        os.environ["LANGCHAIN_API_KEY"] = userdata.get("LANGCHAIN_API_KEY")
        os.environ["LANGCHAIN_TRACING_V2"] = "True"
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
        print("Colab Secret not set / accessible. Not configuring tracing")  # noqa: T201
        os.environ.pop("LANGCHAIN_API_KEY")
        os.environ.pop("LANGCHAIN_TRACING_V2")

initialize_from_prompts

initialize_from_prompts(env: Environment = CASSIO)

Initialize the environment by prompting the user.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/env.py
def initialize_from_prompts(env: Environment = Environment.CASSIO):
    """Initialize the environment by prompting the user."""
    import getpass

    for required in env.required_envvars():
        if required in os.environ:
            continue
        elif required in NON_SECRETS:
            os.environ[required] = input(required)
        else:
            os.environ[required] = getpass.getpass(required)

verify_environment

verify_environment(env: Environment = CASSIO)

Verify the necessary environment variables are set.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/env.py
def verify_environment(env: Environment = Environment.CASSIO):
    """Verify the necessary environment variables are set."""
    for required in env.required_envvars():
        assert required in os.environ, f'"{required}" not defined in environment'

persistent_iteration

Offset dataclass

Offset(index: int)

Class for tracking a position in the iteration.

PersistentIteration

PersistentIteration(
    journal_name: str, iterator: Iterator[T]
)

Bases: Generic[T]

Create a persistent iteration.

This creates a journal file with the name journal_name containing the indices of completed items. When resuming iteration, the already processed indices will be skipped.

PARAMETER DESCRIPTION
journal_name

Name of the journal file to use. If it doesn't exist it will be created. The indices of completed items will be written to the journal.

TYPE: str

iterator

The iterator to process persistently. It must be deterministic -- elements should always be returned in the same order on restarts.

TYPE: Iterator[T]

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/persistent_iteration.py
def __init__(self, journal_name: str, iterator: Iterator[T]) -> None:
    self.iterator = enumerate(iterator)
    self.pending: dict[Offset, T] = {}

    self._completed = set()
    try:
        read_journal = open(journal_name)
        for line in read_journal:
            self._completed.add(Offset(index=int(line)))
    except FileNotFoundError:
        pass

    self._write_journal = open(journal_name, "a")
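
A minimal sketch of resumable processing; handle_item is a hypothetical stand-in for whatever durably persists each element.

from graph_rag_example_helpers.persistent_iteration import PersistentIteration


def handle_item(item: int) -> None:
    # Hypothetical work that persists the item somewhere durable.
    print(item)


# The iterator must be deterministic so restarts replay the same order.
it = PersistentIteration(journal_name="demo.jrnl", iterator=iter(range(10)))
for offset, item in it:
    handle_item(item)
    it.ack(offset)  # Acknowledge only after the item has been persisted.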

__iter__

__iter__() -> Iterator[tuple[Offset, T]]

Iterate over pairs of offsets and elements.

RETURNS DESCRIPTION
Iterator[tuple[Offset, T]]

This iterator, yielding pairs of offsets and elements.
Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/persistent_iteration.py
def __iter__(self) -> Iterator[tuple[Offset, T]]:
    """
    Iterate over pairs of offsets and elements.

    Returns
    -------
    :
        This iterator, yielding pairs of offsets and elements.
    """
    return self

__next__

__next__() -> tuple[Offset, T]

Return the next offset and item.

RETURNS DESCRIPTION
offset

The offset of the next item. Should be acknowledged after the item is finished processing.

TYPE: Offset

item

The next item.

TYPE: T

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/persistent_iteration.py
def __next__(self) -> tuple[Offset, T]:
    """
    Return the next offset and item.

    Returns
    -------
    offset :
        The offset of the next item. Should be acknowledged after the item
        is finished processing.
    item :
        The next item.
    """
    index, item = next(self.iterator)
    offset = Offset(index)

    while offset in self._completed:
        index, item = next(self.iterator)
        offset = Offset(index)

    self.pending[offset] = item
    return (offset, item)

ack

ack(offset: Offset) -> int

Acknowledge the given offset.

This should only be called after the elements in that offset have been persisted.

PARAMETER DESCRIPTION
offset

The offset to acknowledge.

TYPE: Offset

RETURNS DESCRIPTION
int

The number of pending elements.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/persistent_iteration.py
def ack(self, offset: Offset) -> int:
    """
    Acknowledge the given offset.

    This should only be called after the elements in that offset have been
    persisted.

    Parameters
    ----------
    offset :
        The offset to acknowledge.

    Returns
    -------
    :
        The number of pending elements.
    """
    self._write_journal.write(f"{offset.index}\n")
    self._write_journal.flush()
    self._completed.add(offset)

    self.pending.pop(offset)
    return len(self.pending)

completed_count

completed_count() -> int

Return the number of completed elements.

RETURNS DESCRIPTION
int

The number of completed elements.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/persistent_iteration.py
def completed_count(self) -> int:
    """
    Return the number of completed elements.

    Returns
    -------
    :
        The number of completed elements.
    """
    return len(self._completed)

pending_count

pending_count() -> int

Return the number of pending (not processed) elements.

RETURNS DESCRIPTION
int

The number of pending elements.

Source code in packages/graph-rag-example-helpers/src/graph_rag_example_helpers/persistent_iteration.py
def pending_count(self) -> int:
    """
    Return the number of pending (not processed) elements.

    Returns
    -------
    :
        The number of pending elements.
    """
    return len(self.pending)