
graph_retriever.strategies

Strategies determine which nodes are selected during traversal.

Eager dataclass

Eager(
    *,
    select_k: int = DEFAULT_SELECT_K,
    start_k: int = 4,
    adjacent_k: int = 10,
    max_traverse: int | None = None,
    max_depth: int | None = None,
    k: int = DEFAULT_SELECT_K,
    _query_embedding: list[float] = list(),
)

Bases: Strategy

Eager traversal strategy (breadth-first).

This strategy selects all discovered nodes at each traversal step. It ensures breadth-first traversal by processing nodes layer by layer, which is useful for scenarios where all nodes at the current depth should be explored before proceeding to the next depth.
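The layer-by-layer behavior can be sketched in plain Python. This is a standalone illustration, not the library's traversal code. The adjacency dict GRAPH and the function eager_traversal are hypothetical. The depth rule assumes, as NodeTracker.traverse does further down this page, that only nodes shallower than max_depth are expanded.

```python
from __future__ import annotations

# Hypothetical toy graph: node id -> ids reachable through outgoing edges.
GRAPH: dict[str, list[str]] = {
    "a": ["b", "c"],
    "b": ["d"],
    "c": ["d", "e"],
    "d": [],
    "e": ["f"],
    "f": [],
}


def eager_traversal(
    roots: list[str], select_k: int, max_depth: int | None = None
) -> list[str]:
    """Select every discovered node, one depth layer at a time (sketch)."""
    selected: list[str] = []
    visited: set[str] = set(roots)
    layer = list(roots)
    depth = 0
    while layer and len(selected) < select_k:
        # Eager: every node discovered at this depth is selected.
        selected.extend(layer)
        # Nodes at max_depth are still selected, but not expanded further.
        if max_depth is not None and depth >= max_depth:
            break
        next_layer: list[str] = []
        for node_id in layer:
            for neighbor in GRAPH.get(node_id, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    next_layer.append(neighbor)
        layer = next_layer
        depth += 1
    # Like the default finalize_nodes, keep only the first select_k nodes.
    return selected[:select_k]


print(eager_traversal(["a"], select_k=10, max_depth=1))  # ['a', 'b', 'c']
```

All of depth 1 ("b" and "c") is selected before anything at depth 2 is considered, which is the breadth-first property described above.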

PARAMETER DESCRIPTION
select_k

Maximum number of nodes to retrieve during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

start_k

Number of documents to fetch via similarity for starting the traversal. Added to any initial roots provided to the traversal.

TYPE: int DEFAULT: 4

adjacent_k

Number of documents to fetch for each outgoing edge.

TYPE: int DEFAULT: 10

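max_traverse

Maximum number of nodes to traverse outgoing edges from before returning. If None, there is no limit.

TYPE: int | None DEFAULT: None

max_depth

Maximum traversal depth. If None, there is no limit.

TYPE: int | None DEFAULT: None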

k

Deprecated: Use select_k instead. Maximum number of nodes to select and return during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

__post_init__

__post_init__()

Allow passing the deprecated 'k' value instead of 'select_k'.

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def __post_init__(self):
    """Allow passing the deprecated 'k' value instead of 'select_k'."""
    if self.select_k == DEFAULT_SELECT_K and self.k != DEFAULT_SELECT_K:
        self.select_k = self.k
    else:
        self.k = self.select_k
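The reconciliation rule can be checked in isolation with a minimal dataclass. ToyStrategy and the value 5 for DEFAULT_SELECT_K are placeholders for this sketch; the library's actual constant is not shown on this page. When both names are passed, select_k wins unless it still equals the default.

```python
from __future__ import annotations

from dataclasses import dataclass

DEFAULT_SELECT_K = 5  # placeholder; the library's actual value is not shown here


@dataclass
class ToyStrategy:
    """Hypothetical stand-in that keeps only the select_k and k fields."""

    select_k: int = DEFAULT_SELECT_K
    k: int = DEFAULT_SELECT_K

    def __post_init__(self) -> None:
        # Same rule as Strategy.__post_init__: a non-default k only wins
        # when select_k was left at its default.
        if self.select_k == DEFAULT_SELECT_K and self.k != DEFAULT_SELECT_K:
            self.select_k = self.k
        else:
            self.k = self.select_k


print(ToyStrategy(k=3))              # ToyStrategy(select_k=3, k=3)
print(ToyStrategy(select_k=7, k=3))  # ToyStrategy(select_k=7, k=7)
```

One consequence of comparing against the default: passing select_k explicitly equal to DEFAULT_SELECT_K together with a different k lets k win, because that call is indistinguishable from omitting select_k.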

build staticmethod

build(base_strategy: Strategy, **kwargs: Any) -> Strategy

Build a strategy for a retrieval operation.

Combines a base strategy with any provided keyword arguments to create a customized traversal strategy.

PARAMETER DESCRIPTION
base_strategy

The base strategy to start with.

TYPE: Strategy

kwargs

Additional configuration options for the strategy.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Strategy

A configured strategy instance.

RAISES DESCRIPTION
ValueError

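If 'strategy' is not the first keyword argument, is not a Strategy instance, or no strategy is available from either base_strategy or the keyword arguments. The remaining keyword arguments are passed to dataclasses.replace, which raises ValueError for fields declared with init=False and TypeError for unknown names.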

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
@staticmethod
def build(
    base_strategy: Strategy,
    **kwargs: Any,
) -> Strategy:
    """
    Build a strategy for a retrieval operation.

    Combines a base strategy with any provided keyword arguments to
    create a customized traversal strategy.

    Parameters
    ----------
    base_strategy :
        The base strategy to start with.
    kwargs :
        Additional configuration options for the strategy.

    Returns
    -------
    :
        A configured strategy instance.

    Raises
    ------
    ValueError
        If 'strategy' is set incorrectly or extra arguments are invalid.
    """
    # Check if there is a new strategy to use. Otherwise, use the base.
    strategy: Strategy
    if "strategy" in kwargs:
        if next(iter(kwargs.keys())) != "strategy":
            raise ValueError("Error: 'strategy' must be set before other args.")
        strategy = kwargs.pop("strategy")
        if not isinstance(strategy, Strategy):
            raise ValueError(
                f"Unsupported 'strategy' type {type(strategy).__name__}."
                " Must be a sub-class of Strategy"
            )
    elif base_strategy is not None:
        strategy = base_strategy
    else:
        raise ValueError("'strategy' must be set in `__init__` or invocation")

    # Apply the kwargs to update the strategy.
    assert strategy is not None
    if "k" in kwargs:
        kwargs["select_k"] = kwargs.pop("k")
    strategy = dataclasses.replace(strategy, **kwargs)

    return strategy
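The ordering rule and the k translation can be exercised with a condensed, standalone copy of the logic above. ToyStrategy is hypothetical and has only two fields; the real Strategy subclasses carry more. The last call shows that an unknown keyword surfaces as a TypeError from dataclasses.replace, not as a ValueError.

```python
from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyStrategy:
    """Hypothetical stand-in for Strategy with two of its fields."""

    select_k: int = 5
    max_depth: int | None = None


def build(base_strategy: ToyStrategy | None, **kwargs: Any) -> ToyStrategy:
    """Condensed copy of the Strategy.build logic, for experimentation."""
    if "strategy" in kwargs:
        # Keyword arguments keep call order, so "strategy" must come first.
        if next(iter(kwargs)) != "strategy":
            raise ValueError("'strategy' must be set before other args.")
        strategy = kwargs.pop("strategy")
        if not isinstance(strategy, ToyStrategy):
            raise ValueError(f"Unsupported 'strategy' type {type(strategy).__name__}.")
    elif base_strategy is not None:
        strategy = base_strategy
    else:
        raise ValueError("'strategy' must be set in `__init__` or invocation")
    if "k" in kwargs:
        kwargs["select_k"] = kwargs.pop("k")  # translate the deprecated name
    # replace() returns a new instance; the base strategy is left unchanged.
    return dataclasses.replace(strategy, **kwargs)


base = ToyStrategy(select_k=4)
print(build(base, k=2, max_depth=1))  # ToyStrategy(select_k=2, max_depth=1)
print(base)                           # ToyStrategy(select_k=4, max_depth=None)
try:
    build(base, max_depth=1, strategy=ToyStrategy())
except ValueError as err:
    print(err)                        # 'strategy' must be set before other args.
try:
    build(base, depth=1)
except TypeError as err:
    print(type(err).__name__)         # TypeError
```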

finalize_nodes

finalize_nodes(selected: Iterable[Node]) -> Iterable[Node]

Finalize the selected nodes.

This method is called before returning the final set of nodes. It allows the strategy to perform any final processing or re-ranking of the selected nodes.

PARAMETER DESCRIPTION
selected

The selected nodes to be finalized.

TYPE: Iterable[Node]

RETURNS DESCRIPTION
Iterable[Node]

Finalized nodes.

Notes
  • The default implementation returns the first self.select_k selected nodes without any additional processing.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
    """
    Finalize the selected nodes.

    This method is called before returning the final set of nodes. It allows
    the strategy to perform any final processing or re-ranking of the selected
    nodes.

    Parameters
    ----------
    selected :
        The selected nodes to be finalized

    Returns
    -------
    :
        Finalized nodes.

    Notes
    -----
    - The default implementation returns the first `self.select_k` selected nodes
    without any additional processing.
    """
    return list(selected)[: self.select_k]
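Because the default keeps selection order and only truncates, an override is the natural place for re-ranking. The sketch below compares the default behavior with one possible re-ranking by similarity. ToyNode and both functions are hypothetical stand-ins; they rely only on the similarity_score attribute that NodeTracker.select also reads. In a real subclass this would be a finalize_nodes method using self.select_k.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable


@dataclass
class ToyNode:
    """Hypothetical stand-in for Node with only the fields used here."""

    id: str
    similarity_score: float


def default_finalize(selected: Iterable[ToyNode], select_k: int) -> list[ToyNode]:
    # Mirrors the default: keep selection order, truncate to select_k.
    return list(selected)[:select_k]


def finalize_by_similarity(selected: Iterable[ToyNode], select_k: int) -> list[ToyNode]:
    # A possible override: re-rank by similarity before truncating.
    ranked = sorted(selected, key=lambda n: n.similarity_score, reverse=True)
    return ranked[:select_k]


nodes = [ToyNode("a", 0.2), ToyNode("b", 0.9), ToyNode("c", 0.5)]
print([n.id for n in default_finalize(nodes, 2)])        # ['a', 'b']
print([n.id for n in finalize_by_similarity(nodes, 2)])  # ['b', 'c']
```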

iteration

iteration(
    nodes: Iterable[Node], tracker: NodeTracker
) -> None

Process the newly discovered nodes on each iteration.

This method should call tracker.traverse() and/or tracker.select() as appropriate to update the nodes that need to be traversed in this iteration or selected at the end of the retrieval, respectively.

PARAMETER DESCRIPTION
nodes

The newly discovered nodes, found from either:
  • the initial vector store retrieval, or
  • incoming edges from nodes chosen for traversal in the previous iteration.

TYPE: Iterable[Node]

tracker

The tracker object to manage the traversal and selection of nodes.

TYPE: NodeTracker

Notes
  • This method is called once for each iteration of the traversal.
  • To stop iterating, either traverse no additional nodes or select no additional nodes for output.
Source code in packages/graph-retriever/src/graph_retriever/strategies/eager.py
@override
def iteration(self, nodes: Iterable[Node], tracker: NodeTracker) -> None:
    tracker.select_and_traverse(nodes)

Mmr dataclass

Mmr(
    lambda_mult: float = 0.5,
    min_mmr_score: float = NEG_INF,
    _selected_ids: list[str] = list(),
    _candidate_id_to_index: dict[str, int] = dict(),
    _candidates: list[_MmrCandidate] = list(),
    _best_score: float = NEG_INF,
    _best_id: str | None = None,
    *,
    select_k: int = DEFAULT_SELECT_K,
    start_k: int = 4,
    adjacent_k: int = 10,
    max_traverse: int | None = None,
    max_depth: int | None = None,
    k: int = DEFAULT_SELECT_K,
    _query_embedding: list[float] = list(),
)

Bases: Strategy

Maximal Marginal Relevance (MMR) traversal strategy.

This strategy selects nodes by balancing relevance to the query against diversity among the results. The lambda_mult parameter controls the trade-off between relevance and redundancy. Each candidate is scored by its similarity to the query, penalized by its highest similarity to any already selected node.
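Written as a formula, the per-candidate score is sketched below. This assumes that a candidate's score is weighted_similarity minus weighted_redundancy and that _lambda_mult_complement equals 1 - lambda_mult; neither the _MmrCandidate.score property nor that attribute is shown on this page. Here λ is lambda_mult, sim is cosine similarity, q is the query embedding, and S is the set of already selected nodes. The max term is taken as 0 when S is empty, matching the max_redundancy = 0.0 default in iteration().

```latex
\mathrm{score}(d) = \lambda \, \mathrm{sim}(q, d) - (1 - \lambda) \max_{s \in S} \mathrm{sim}(d, s)

% Worked example: \lambda = 0.5, \mathrm{sim}(q, d) = 0.8, \max_{s \in S} \mathrm{sim}(d, s) = 0.6
\mathrm{score}(d) = 0.5 \cdot 0.8 - 0.5 \cdot 0.6 = 0.4 - 0.3 = 0.1
```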

PARAMETER DESCRIPTION
select_k

Maximum number of nodes to retrieve during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

start_k

Number of documents to fetch via similarity for starting the traversal. Added to any initial roots provided to the traversal.

TYPE: int DEFAULT: 4

adjacent_k

Number of documents to fetch for each outgoing edge.

TYPE: int DEFAULT: 10

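max_traverse

Maximum number of nodes to traverse outgoing edges from before returning. If None, there is no limit.

TYPE: int | None DEFAULT: None

max_depth

Maximum traversal depth. If None, there is no limit.

TYPE: int | None DEFAULT: None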

lambda_mult

Controls the trade-off between relevance and diversity. A value closer to 1 prioritizes relevance, while a value closer to 0 prioritizes diversity. Must be between 0 and 1 (inclusive).

TYPE: float DEFAULT: 0.5

min_mmr_score

Only nodes with a score greater than or equal to this value will be selected.

TYPE: float DEFAULT: NEG_INF

k

Deprecated: Use select_k instead. Maximum number of nodes to select and return during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

__post_init__

__post_init__()

Allow passing the deprecated 'k' value instead of 'select_k'.

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def __post_init__(self):
    """Allow passing the deprecated 'k' value instead of 'select_k'."""
    if self.select_k == DEFAULT_SELECT_K and self.k != DEFAULT_SELECT_K:
        self.select_k = self.k
    else:
        self.k = self.select_k

build staticmethod

build(base_strategy: Strategy, **kwargs: Any) -> Strategy

Build a strategy for a retrieval operation.

Combines a base strategy with any provided keyword arguments to create a customized traversal strategy.

PARAMETER DESCRIPTION
base_strategy

The base strategy to start with.

TYPE: Strategy

kwargs

Additional configuration options for the strategy.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Strategy

A configured strategy instance.

RAISES DESCRIPTION
ValueError

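If 'strategy' is not the first keyword argument, is not a Strategy instance, or no strategy is available from either base_strategy or the keyword arguments. The remaining keyword arguments are passed to dataclasses.replace, which raises ValueError for fields declared with init=False and TypeError for unknown names.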

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
@staticmethod
def build(
    base_strategy: Strategy,
    **kwargs: Any,
) -> Strategy:
    """
    Build a strategy for a retrieval operation.

    Combines a base strategy with any provided keyword arguments to
    create a customized traversal strategy.

    Parameters
    ----------
    base_strategy :
        The base strategy to start with.
    kwargs :
        Additional configuration options for the strategy.

    Returns
    -------
    :
        A configured strategy instance.

    Raises
    ------
    ValueError
        If 'strategy' is set incorrectly or extra arguments are invalid.
    """
    # Check if there is a new strategy to use. Otherwise, use the base.
    strategy: Strategy
    if "strategy" in kwargs:
        if next(iter(kwargs.keys())) != "strategy":
            raise ValueError("Error: 'strategy' must be set before other args.")
        strategy = kwargs.pop("strategy")
        if not isinstance(strategy, Strategy):
            raise ValueError(
                f"Unsupported 'strategy' type {type(strategy).__name__}."
                " Must be a sub-class of Strategy"
            )
    elif base_strategy is not None:
        strategy = base_strategy
    else:
        raise ValueError("'strategy' must be set in `__init__` or invocation")

    # Apply the kwargs to update the strategy.
    assert strategy is not None
    if "k" in kwargs:
        kwargs["select_k"] = kwargs.pop("k")
    strategy = dataclasses.replace(strategy, **kwargs)

    return strategy

candidate_ids

candidate_ids() -> Iterable[str]

Return the IDs of the candidates.

RETURNS DESCRIPTION
Iterable[str]

The IDs of the candidates.

Source code in packages/graph-retriever/src/graph_retriever/strategies/mmr.py
def candidate_ids(self) -> Iterable[str]:
    """
    Return the IDs of the candidates.

    Returns
    -------
    Iterable[str]
        The IDs of the candidates.
    """
    return self._candidate_id_to_index.keys()

finalize_nodes

finalize_nodes(selected: Iterable[Node]) -> Iterable[Node]

Finalize the selected nodes.

This method is called before returning the final set of nodes. It allows the strategy to perform any final processing or re-ranking of the selected nodes.

PARAMETER DESCRIPTION
selected

The selected nodes to be finalized.

TYPE: Iterable[Node]

RETURNS DESCRIPTION
Iterable[Node]

Finalized nodes.

Notes
  • The default implementation returns the first self.select_k selected nodes without any additional processing.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
    """
    Finalize the selected nodes.

    This method is called before returning the final set of nodes. It allows
    the strategy to perform any final processing or re-ranking of the selected
    nodes.

    Parameters
    ----------
    selected :
        The selected nodes to be finalized

    Returns
    -------
    :
        Finalized nodes.

    Notes
    -----
    - The default implementation returns the first `self.select_k` selected nodes
    without any additional processing.
    """
    return list(selected)[: self.select_k]

iteration

iteration(
    nodes: Iterable[Node], tracker: NodeTracker
) -> None

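Add the newly discovered nodes to the candidate set, then select the best-scoring candidates until one of them is added for traversal, no suitable candidate remains, or select_k is reached.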

Source code in packages/graph-retriever/src/graph_retriever/strategies/mmr.py
@override
def iteration(self, nodes: Iterable[Node], tracker: NodeTracker) -> None:
    """Add candidates to the consideration set."""
    nodes = list(nodes)
    node_count = len(nodes)
    if node_count > 0:
        # Build a matrix of the new candidate embeddings and record each
        # candidate's row index within the overall candidate set.
        new_embeddings: NDArray[np.float32] = np.empty(
            (node_count, self._dimensions), dtype=np.float32
        )
        offset = self._candidate_embeddings.shape[0]
        for index, candidate_node in enumerate(nodes):
            self._candidate_id_to_index[candidate_node.id] = offset + index
            new_embeddings[index] = candidate_node.embedding

        # Compute the similarity to the query.
        similarity = cosine_similarity(new_embeddings, self._nd_query_embedding)

        # Compute the similarity of each new candidate to every already-selected
        # node; the maximum per row is that candidate's redundancy.
        redundancy = cosine_similarity(
            new_embeddings, self._already_selected_embeddings()
        )
        for index, candidate_node in enumerate(nodes):
            max_redundancy = 0.0
            if redundancy.shape[0] > 0:
                max_redundancy = redundancy[index].max()
            candidate = _MmrCandidate(
                node=candidate_node,
                similarity=similarity[index][0],
                weighted_similarity=self.lambda_mult * similarity[index][0],
                weighted_redundancy=self._lambda_mult_complement * max_redundancy,
            )
            self._candidates.append(candidate)

            if candidate.score >= self._best_score:
                self._best_score = candidate.score
                self._best_id = candidate.node.id

        # Add the new embeddings to the candidate set.
        self._candidate_embeddings = np.vstack(
            (
                self._candidate_embeddings,
                new_embeddings,
            )
        )

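    # Keep selecting the best candidate until one is added for traversal,
    # no candidate remains, or select_k is reached.
    while tracker.num_remaining > 0:
        next_node = self._next()  # avoid shadowing the built-in next()

        if next_node is None:
            break

        num_traversing = tracker.select_and_traverse([next_node])
        if num_traversing == 1:
            break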
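The greedy selection that this loop drives can be sketched in pure Python. This is a simplified illustration, not the body of _next(), which is not shown on this page. It assumes all candidates are known up front, ignores traversal, computes cosine similarity directly, and keeps the first candidate on ties. The score follows the assumed form weighted_similarity minus weighted_redundancy. The functions cosine, greedy_mmr, and unit are hypothetical helpers.

```python
from __future__ import annotations

import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norms


def greedy_mmr(
    query: list[float],
    candidates: dict[str, list[float]],
    select_k: int,
    lambda_mult: float = 0.5,
    min_mmr_score: float = -math.inf,
) -> list[str]:
    """Greedily pick the candidate with the best MMR score (sketch)."""
    remaining = dict(candidates)
    selected: list[str] = []
    while remaining and len(selected) < select_k:
        best_id: str | None = None
        best_score = -math.inf
        for node_id, emb in remaining.items():
            # Redundancy: highest similarity to anything already selected
            # (0.0 while nothing is selected, as in iteration() above).
            redundancy = max(
                (cosine(emb, candidates[s]) for s in selected), default=0.0
            )
            score = lambda_mult * cosine(query, emb) - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_id, best_score = node_id, score
        if best_id is None or best_score < min_mmr_score:
            break  # no remaining candidate meets min_mmr_score
        selected.append(best_id)
        del remaining[best_id]
    return selected


def unit(degrees: float) -> list[float]:
    """2-D unit vector at the given angle, for a readable toy example."""
    r = math.radians(degrees)
    return [math.cos(r), math.sin(r)]


query = unit(0)
# "b" is almost a duplicate of "a"; "c" is less relevant but more diverse.
cands = {"a": unit(10), "b": unit(12), "c": unit(-40)}
print(greedy_mmr(query, cands, select_k=2))                   # ['a', 'c']
print(greedy_mmr(query, cands, select_k=2, lambda_mult=1.0))  # ['a', 'b']
```

With lambda_mult=1.0 the redundancy term vanishes and the result is plain relevance order; at 0.5 the near-duplicate "b" is pushed below the more diverse "c".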

NodeTracker

NodeTracker(select_k: int, max_depth: int | None)

Helper class for tracking node selection and traversal.

Call .select(nodes) to add nodes to the result set, .traverse(nodes) to add them to the next traversal, or .select_and_traverse(nodes) to do both.

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def __init__(self, select_k: int, max_depth: int | None) -> None:
    self._select_k: int = select_k
    self._max_depth: int | None = max_depth
    self._visited_node_ids: set[str] = set()
    # use a dict to preserve order
    self.to_traverse: dict[str, Node] = dict()
    self.selected: list[Node] = []

num_remaining property

num_remaining

The remaining number of nodes to be selected.

select

select(nodes: Iterable[Node]) -> None

Select nodes to be included in the result set.

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
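def select(self, nodes: Iterable[Node]) -> None:
    """Select nodes to be included in the result set."""
    # Materialize first: otherwise the loop below would exhaust a one-shot
    # iterator and extend() would add nothing.
    nodes = list(nodes)
    for node in nodes:
        node.extra_metadata["_depth"] = node.depth
        node.extra_metadata["_similarity_score"] = node.similarity_score
    self.selected.extend(nodes)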

select_and_traverse

select_and_traverse(nodes: Iterable[Node]) -> int

Select nodes to be included in the result set and the next traversal.

RETURNS DESCRIPTION
int

Number of nodes added for traversal.

Notes
  • Nodes are only added for traversal if they have not been visited before.
  • Nodes are only added for traversal if their depth is less than max_depth (when max_depth is set).
  • If no new nodes are chosen for traversal, or selected for output, then the traversal will stop.
  • Traversal will also stop if the number of selected nodes reaches the select_k limit.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def select_and_traverse(self, nodes: Iterable[Node]) -> int:
    """
    Select nodes to be included in the result set and the next traversal.

    Returns
    -------
    Number of nodes added for traversal.

    Notes
    -----
    - Nodes are only added for traversal if they have not been visited before.
    - Nodes are only added for traversal if they do not exceed the maximum depth.
    - If no new nodes are chosen for traversal, or selected for output, then
        the traversal will stop.
    - Traversal will also stop if the number of selected nodes reaches the select_k
        limit.
    """
    # Materialize once so a one-shot iterator feeds both calls below.
    nodes = list(nodes)
    self.select(nodes)
    return self.traverse(nodes)

traverse

traverse(nodes: Iterable[Node]) -> int

Select nodes to be included in the next traversal.

RETURNS DESCRIPTION
int

Number of nodes added for traversal.

Notes
  • Nodes are only added if they have not been visited before.
  • Nodes are only added if their depth is less than max_depth (when max_depth is set).
  • If no new nodes are chosen for traversal, or selected for output, then the traversal will stop.
  • Traversal will also stop if the number of selected nodes reaches the select_k limit.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def traverse(self, nodes: Iterable[Node]) -> int:
    """
    Select nodes to be included in the next traversal.

    Returns
    -------
    Number of nodes added for traversal.

    Notes
    -----
    - Nodes are only added if they have not been visited before.
    - Nodes are only added if they do not exceed the maximum depth.
    - If no new nodes are chosen for traversal, or selected for output, then
        the traversal will stop.
    - Traversal will also stop if the number of selected nodes reaches the select_k
        limit.
    """
    new_nodes = {
        n.id: n
        for n in nodes
        if self._not_visited(n.id)
        if self._max_depth is None or n.depth < self._max_depth
    }
    self.to_traverse.update(new_nodes)
    self._visited_node_ids.update(new_nodes.keys())
    return len(new_nodes)
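The visited and depth filters can be observed with a simplified, standalone re-implementation. ToyNode and ToyTracker are hypothetical. The num_remaining body is an assumption (select_k minus the number selected) because the real property body is not shown, and the extra_metadata bookkeeping is omitted. select_and_traverse materializes its argument so that a generator is not consumed by select() before traverse() sees it.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable


@dataclass
class ToyNode:
    """Hypothetical stand-in for Node; only id and depth matter here."""

    id: str
    depth: int


class ToyTracker:
    """Simplified NodeTracker sketch; skips the extra_metadata bookkeeping."""

    def __init__(self, select_k: int, max_depth: int | None) -> None:
        self._select_k = select_k
        self._max_depth = max_depth
        self._visited_node_ids: set[str] = set()
        self.to_traverse: dict[str, ToyNode] = {}  # dict keeps insertion order
        self.selected: list[ToyNode] = []

    @property
    def num_remaining(self) -> int:
        # Assumed definition; the real property body is not shown on this page.
        return self._select_k - len(self.selected)

    def select(self, nodes: Iterable[ToyNode]) -> None:
        self.selected.extend(nodes)

    def traverse(self, nodes: Iterable[ToyNode]) -> int:
        # Same filters as NodeTracker.traverse: unvisited and shallower than max_depth.
        new_nodes = {
            n.id: n
            for n in nodes
            if n.id not in self._visited_node_ids
            and (self._max_depth is None or n.depth < self._max_depth)
        }
        self.to_traverse.update(new_nodes)
        self._visited_node_ids.update(new_nodes)
        return len(new_nodes)

    def select_and_traverse(self, nodes: Iterable[ToyNode]) -> int:
        nodes = list(nodes)  # a generator would otherwise be consumed by select()
        self.select(nodes)
        return self.traverse(nodes)


tracker = ToyTracker(select_k=3, max_depth=1)
added = tracker.select_and_traverse(ToyNode(i, d) for i, d in [("a", 0), ("b", 1)])
print(added, [n.id for n in tracker.selected], list(tracker.to_traverse))
# 1 ['a', 'b'] ['a']
print(tracker.num_remaining)  # 1
```

Node "b" is selected but not queued for traversal because its depth equals max_depth.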

Scored dataclass

Scored(
    scorer: Callable[[Node], float],
    _nodes: list[_ScoredNode] = list(),
    per_iteration_limit: int | None = None,
    *,
    select_k: int = DEFAULT_SELECT_K,
    start_k: int = 4,
    adjacent_k: int = 10,
    max_traverse: int | None = None,
    max_depth: int | None = None,
    k: int = DEFAULT_SELECT_K,
    _query_embedding: list[float] = list(),
)

Bases: Strategy

Scored traversal strategy.

This strategy uses a scoring function to select nodes using a local maximum approach. In each iteration, it chooses the top scoring nodes available and then traverses the connected nodes.
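One Scored-style iteration can be sketched with heapq directly. This is a standalone illustration using string ids instead of Node objects, and the rating-based scorer is hypothetical. It assumes that _ScoredNode (not shown) orders entries so the highest score pops first, as the variable name highest in iteration() suggests; the sketch gets that behavior by negating scores in Python's min-heap. It also decrements the limit once per popped node, whereas the real loop subtracts the number of nodes actually added for traversal.

```python
from __future__ import annotations

import heapq
from typing import Callable


def scored_iteration(
    heap: list[tuple[float, str]],
    new_nodes: list[str],
    scorer: Callable[[str], float],
    num_remaining: int,
    per_iteration_limit: int | None = None,
) -> list[str]:
    """Push newly discovered nodes, then pop the best ones (sketch)."""
    for node in new_nodes:
        # heapq is a min-heap, so store negated scores to pop the highest first.
        heapq.heappush(heap, (-scorer(node), node))
    limit = num_remaining
    if per_iteration_limit:
        limit = min(limit, per_iteration_limit)
    chosen: list[str] = []
    while limit > 0 and heap:
        _, node = heapq.heappop(heap)
        chosen.append(node)  # the real strategy calls select_and_traverse here
        limit -= 1
    return chosen


# Hypothetical scorer: look up a precomputed rating per node id.
ratings = {"a": 0.2, "b": 0.9, "c": 0.5, "d": 0.7}
heap: list[tuple[float, str]] = []
print(scored_iteration(heap, ["a", "b", "c"], ratings.__getitem__, 10, 2))  # ['b', 'c']
# "a" stays in the heap and competes with nodes found in the next iteration.
print(scored_iteration(heap, ["d"], ratings.__getitem__, 10, 2))  # ['d', 'a']
```

Keeping unchosen nodes in the heap across iterations is what makes the approach a local maximum: each round picks the best of everything discovered so far, not just the latest batch.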

PARAMETER DESCRIPTION
scorer

A function that returns the score of a node. Higher-scoring nodes are selected first.

TYPE: Callable[[Node], float]

select_k

Maximum number of nodes to retrieve during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

start_k

Number of documents to fetch via similarity for starting the traversal. Added to any initial roots provided to the traversal.

TYPE: int DEFAULT: 4

adjacent_k

Number of documents to fetch for each outgoing edge.

TYPE: int DEFAULT: 10

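max_traverse

Maximum number of nodes to traverse outgoing edges from before returning. If None, there is no limit.

TYPE: int | None DEFAULT: None

max_depth

Maximum traversal depth. If None, there is no limit.

TYPE: int | None DEFAULT: None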

per_iteration_limit

Maximum number of nodes to select and traverse during a single iteration.

TYPE: int | None DEFAULT: None

k

Deprecated: Use select_k instead. Maximum number of nodes to select and return during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

__post_init__

__post_init__()

Allow passing the deprecated 'k' value instead of 'select_k'.

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def __post_init__(self):
    """Allow passing the deprecated 'k' value instead of 'select_k'."""
    if self.select_k == DEFAULT_SELECT_K and self.k != DEFAULT_SELECT_K:
        self.select_k = self.k
    else:
        self.k = self.select_k

build staticmethod

build(base_strategy: Strategy, **kwargs: Any) -> Strategy

Build a strategy for a retrieval operation.

Combines a base strategy with any provided keyword arguments to create a customized traversal strategy.

PARAMETER DESCRIPTION
base_strategy

The base strategy to start with.

TYPE: Strategy

kwargs

Additional configuration options for the strategy.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Strategy

A configured strategy instance.

RAISES DESCRIPTION
ValueError

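If 'strategy' is not the first keyword argument, is not a Strategy instance, or no strategy is available from either base_strategy or the keyword arguments. The remaining keyword arguments are passed to dataclasses.replace, which raises ValueError for fields declared with init=False and TypeError for unknown names.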

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
@staticmethod
def build(
    base_strategy: Strategy,
    **kwargs: Any,
) -> Strategy:
    """
    Build a strategy for a retrieval operation.

    Combines a base strategy with any provided keyword arguments to
    create a customized traversal strategy.

    Parameters
    ----------
    base_strategy :
        The base strategy to start with.
    kwargs :
        Additional configuration options for the strategy.

    Returns
    -------
    :
        A configured strategy instance.

    Raises
    ------
    ValueError
        If 'strategy' is set incorrectly or extra arguments are invalid.
    """
    # Check if there is a new strategy to use. Otherwise, use the base.
    strategy: Strategy
    if "strategy" in kwargs:
        if next(iter(kwargs.keys())) != "strategy":
            raise ValueError("Error: 'strategy' must be set before other args.")
        strategy = kwargs.pop("strategy")
        if not isinstance(strategy, Strategy):
            raise ValueError(
                f"Unsupported 'strategy' type {type(strategy).__name__}."
                " Must be a sub-class of Strategy"
            )
    elif base_strategy is not None:
        strategy = base_strategy
    else:
        raise ValueError("'strategy' must be set in `__init__` or invocation")

    # Apply the kwargs to update the strategy.
    assert strategy is not None
    if "k" in kwargs:
        kwargs["select_k"] = kwargs.pop("k")
    strategy = dataclasses.replace(strategy, **kwargs)

    return strategy

finalize_nodes

finalize_nodes(selected: Iterable[Node]) -> Iterable[Node]

Finalize the selected nodes.

This method is called before returning the final set of nodes. It allows the strategy to perform any final processing or re-ranking of the selected nodes.

PARAMETER DESCRIPTION
selected

The selected nodes to be finalized.

TYPE: Iterable[Node]

RETURNS DESCRIPTION
Iterable[Node]

Finalized nodes.

Notes
  • The default implementation returns the first self.select_k selected nodes without any additional processing.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
    """
    Finalize the selected nodes.

    This method is called before returning the final set of nodes. It allows
    the strategy to perform any final processing or re-ranking of the selected
    nodes.

    Parameters
    ----------
    selected :
        The selected nodes to be finalized

    Returns
    -------
    :
        Finalized nodes.

    Notes
    -----
    - The default implementation returns the first `self.select_k` selected nodes
    without any additional processing.
    """
    return list(selected)[: self.select_k]

iteration

iteration(
    nodes: Iterable[Node], tracker: NodeTracker
) -> None

Process the newly discovered nodes on each iteration.

This method should call tracker.traverse() and/or tracker.select() as appropriate to update the nodes that need to be traversed in this iteration or selected at the end of the retrieval, respectively.

PARAMETER DESCRIPTION
nodes

The newly discovered nodes, found from either:
  • the initial vector store retrieval, or
  • incoming edges from nodes chosen for traversal in the previous iteration.

TYPE: Iterable[Node]

tracker

The tracker object to manage the traversal and selection of nodes.

TYPE: NodeTracker

Notes
  • This method is called once for each iteration of the traversal.
  • To stop iterating, either traverse no additional nodes or select no additional nodes for output.
Source code in packages/graph-retriever/src/graph_retriever/strategies/scored.py
@override
def iteration(self, nodes: Iterable[Node], tracker: NodeTracker) -> None:
    for node in nodes:
        heapq.heappush(self._nodes, _ScoredNode(self.scorer(node), node))

    limit = tracker.num_remaining
    if self.per_iteration_limit:
        limit = min(limit, self.per_iteration_limit)

    while limit > 0 and self._nodes:
        highest = heapq.heappop(self._nodes)
        node = highest.node
        node.extra_metadata["_score"] = highest.score
        limit -= tracker.select_and_traverse([node])

Strategy dataclass

Strategy(
    *,
    select_k: int = DEFAULT_SELECT_K,
    start_k: int = 4,
    adjacent_k: int = 10,
    max_traverse: int | None = None,
    max_depth: int | None = None,
    k: int = DEFAULT_SELECT_K,
    _query_embedding: list[float] = list(),
)

Bases: ABC

Interface for configuring node selection and traversal strategies.

This base class defines how nodes are selected, traversed, and finalized during a graph traversal. Implementations can customize behaviors like limiting the depth of traversal, scoring nodes, or selecting the next set of nodes for exploration.

PARAMETER DESCRIPTION
select_k

Maximum number of nodes to select and return during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

start_k

Number of nodes to fetch via similarity for starting the traversal. Added to any initial roots provided to the traversal.

TYPE: int DEFAULT: 4

adjacent_k

Number of nodes to fetch for each outgoing edge.

TYPE: int DEFAULT: 10

max_traverse

Maximum number of nodes to traverse outgoing edges from before returning. If None, there is no limit.

TYPE: int | None DEFAULT: None

max_depth

Maximum traversal depth. If None, there is no limit.

TYPE: int | None DEFAULT: None

k

Deprecated: Use select_k instead. Maximum number of nodes to select and return during traversal.

TYPE: int DEFAULT: DEFAULT_SELECT_K

__post_init__

__post_init__()

Allow passing the deprecated 'k' value instead of 'select_k'.

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def __post_init__(self):
    """Allow passing the deprecated 'k' value instead of 'select_k'."""
    if self.select_k == DEFAULT_SELECT_K and self.k != DEFAULT_SELECT_K:
        self.select_k = self.k
    else:
        self.k = self.select_k

build staticmethod

build(base_strategy: Strategy, **kwargs: Any) -> Strategy

Build a strategy for a retrieval operation.

Combines a base strategy with any provided keyword arguments to create a customized traversal strategy.

PARAMETER DESCRIPTION
base_strategy

The base strategy to start with.

TYPE: Strategy

kwargs

Additional configuration options for the strategy.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
Strategy

A configured strategy instance.

RAISES DESCRIPTION
ValueError

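If 'strategy' is not the first keyword argument, is not a Strategy instance, or no strategy is available from either base_strategy or the keyword arguments. The remaining keyword arguments are passed to dataclasses.replace, which raises ValueError for fields declared with init=False and TypeError for unknown names.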

Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
@staticmethod
def build(
    base_strategy: Strategy,
    **kwargs: Any,
) -> Strategy:
    """
    Build a strategy for a retrieval operation.

    Combines a base strategy with any provided keyword arguments to
    create a customized traversal strategy.

    Parameters
    ----------
    base_strategy :
        The base strategy to start with.
    kwargs :
        Additional configuration options for the strategy.

    Returns
    -------
    :
        A configured strategy instance.

    Raises
    ------
    ValueError
        If 'strategy' is set incorrectly or extra arguments are invalid.
    """
    # Check if there is a new strategy to use. Otherwise, use the base.
    strategy: Strategy
    if "strategy" in kwargs:
        if next(iter(kwargs.keys())) != "strategy":
            raise ValueError("Error: 'strategy' must be set before other args.")
        strategy = kwargs.pop("strategy")
        if not isinstance(strategy, Strategy):
            raise ValueError(
                f"Unsupported 'strategy' type {type(strategy).__name__}."
                " Must be a sub-class of Strategy"
            )
    elif base_strategy is not None:
        strategy = base_strategy
    else:
        raise ValueError("'strategy' must be set in `__init__` or invocation")

    # Apply the kwargs to update the strategy.
    assert strategy is not None
    if "k" in kwargs:
        kwargs["select_k"] = kwargs.pop("k")
    strategy = dataclasses.replace(strategy, **kwargs)

    return strategy

finalize_nodes

finalize_nodes(selected: Iterable[Node]) -> Iterable[Node]

Finalize the selected nodes.

This method is called before returning the final set of nodes. It allows the strategy to perform any final processing or re-ranking of the selected nodes.

PARAMETER DESCRIPTION
selected

The selected nodes to be finalized.

TYPE: Iterable[Node]

RETURNS DESCRIPTION
Iterable[Node]

Finalized nodes.

Notes
  • The default implementation returns the first self.select_k selected nodes without any additional processing.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
def finalize_nodes(self, selected: Iterable[Node]) -> Iterable[Node]:
    """
    Finalize the selected nodes.

    This method is called before returning the final set of nodes. It allows
    the strategy to perform any final processing or re-ranking of the selected
    nodes.

    Parameters
    ----------
    selected :
        The selected nodes to be finalized

    Returns
    -------
    :
        Finalized nodes.

    Notes
    -----
    - The default implementation returns the first `self.select_k` selected nodes
    without any additional processing.
    """
    return list(selected)[: self.select_k]

iteration abstractmethod

iteration(
    *, nodes: Iterable[Node], tracker: NodeTracker
) -> None

Process the newly discovered nodes on each iteration.

This method should call tracker.traverse() and/or tracker.select() as appropriate to update the nodes that need to be traversed in this iteration or selected at the end of the retrieval, respectively.

PARAMETER DESCRIPTION
nodes

The newly discovered nodes, found from either:
  • the initial vector store retrieval, or
  • incoming edges from nodes chosen for traversal in the previous iteration.

TYPE: Iterable[Node]

tracker

The tracker object to manage the traversal and selection of nodes.

TYPE: NodeTracker

Notes
  • This method is called once for each iteration of the traversal.
  • To stop iterating, either traverse no additional nodes or select no additional nodes for output.
Source code in packages/graph-retriever/src/graph_retriever/strategies/base.py
@abc.abstractmethod
def iteration(self, *, nodes: Iterable[Node], tracker: NodeTracker) -> None:
    """
    Process the newly discovered nodes on each iteration.

    This method should call `tracker.traverse()` and/or `tracker.select()`
    as appropriate to update the nodes that need to be traversed in this iteration
    or selected at the end of the retrieval, respectively.

    Parameters
    ----------
    nodes :
        The newly discovered nodes found from either:
        - the initial vector store retrieval
        - incoming edges from nodes chosen for traversal in the previous iteration
    tracker :
        The tracker object to manage the traversal and selection of nodes.

    Notes
    -----
    - This method is called once for each iteration of the traversal.
    - In order to stop iterating either choose to not traverse any additional nodes
    or don't select any additional nodes for output.
    """
    ...
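Implementing the hook comes down to deciding, for each batch, what to select and what to traverse. The sketch below is standalone and does not import graph_retriever. ToyNode, TrackerLike, ToyStrategyBase, ThresholdStrategy, and RecordingTracker are hypothetical names that mirror the shape of Strategy and NodeTracker. A real subclass would presumably be a dataclass deriving from Strategy, like Eager, Mmr, and Scored above.

```python
from __future__ import annotations

import abc
from dataclasses import dataclass
from typing import Iterable, Protocol


@dataclass
class ToyNode:
    """Hypothetical stand-in for Node."""

    id: str
    similarity_score: float


class TrackerLike(Protocol):
    """The two NodeTracker methods this sketch relies on."""

    def select(self, nodes: Iterable[ToyNode]) -> None: ...

    def traverse(self, nodes: Iterable[ToyNode]) -> int: ...


class ToyStrategyBase(abc.ABC):
    """Hypothetical mirror of Strategy's abstract iteration() hook."""

    @abc.abstractmethod
    def iteration(self, *, nodes: Iterable[ToyNode], tracker: TrackerLike) -> None:
        ...


class ThresholdStrategy(ToyStrategyBase):
    """Select every node, but only expand nodes that are similar enough."""

    def __init__(self, min_similarity: float) -> None:
        self.min_similarity = min_similarity

    def iteration(self, *, nodes: Iterable[ToyNode], tracker: TrackerLike) -> None:
        nodes = list(nodes)  # iterated twice below
        tracker.select(nodes)
        # If no node passes the threshold, nothing new is traversed,
        # which is one of the ways a traversal stops.
        tracker.traverse([n for n in nodes if n.similarity_score >= self.min_similarity])


class RecordingTracker:
    """Minimal tracker that just records what the strategy asked for."""

    def __init__(self) -> None:
        self.selected: list[ToyNode] = []
        self.traversed: list[ToyNode] = []

    def select(self, nodes: Iterable[ToyNode]) -> None:
        self.selected.extend(nodes)

    def traverse(self, nodes: Iterable[ToyNode]) -> int:
        batch = list(nodes)
        self.traversed.extend(batch)
        return len(batch)


tracker = RecordingTracker()
ThresholdStrategy(0.5).iteration(
    nodes=[ToyNode("a", 0.9), ToyNode("b", 0.1)], tracker=tracker
)
print([n.id for n in tracker.selected], [n.id for n in tracker.traversed])
# ['a', 'b'] ['a']
```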