Skip to content

graph_retriever.testing

Helpers for testing Graph Retriever implementations.

adapter_tests

AdapterComplianceCase dataclass

AdapterComplianceCase(
    *,
    id: str,
    expected: list[str],
    requires_nested: bool = False,
    requires_dict_in_list: bool = False,
)

Bases: ABC

Base dataclass for test cases.

ATTRIBUTE DESCRIPTION
id

The ID of the test case.

TYPE: str

expected

The expected results of the case.

TYPE: list[str]

AdapterComplianceSuite

Bases: ABC

Test suite for adapter compliance.

To use this, create a sub-class containing a @pytest.fixture named adapter which returns an Adapter with the documents from animals.jsonl loaded.

adjacent_case

adjacent_case(request) -> AdjacentCase

Fixture providing the adjacent and aadjacent test cases.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
@pytest.fixture(params=ADJACENT_CASES, ids=lambda case: case.id)
def adjacent_case(self, request) -> AdjacentCase:
    """
    Parametrized fixture yielding one entry of `ADJACENT_CASES` per test.

    Each parameter is labelled with the case's `id`. The value handed to
    the test is the current parameter, read from `request.param`. It is
    consumed by `test_adjacent` and `test_aadjacent`.
    """
    current_case = request.param
    return current_case

expected

expected(
    method: str, case: AdapterComplianceCase
) -> list[str]

Override to change the expected behavior of a case.

If the test is expected to fail, call pytest.xfail(reason), or pytest.skip(reason) if it can't be executed.

Generally, this should not change the expected results, unless the adapter being tested uses wildly different distance metrics or a different embedding. AnimalEmbeddings is deterministic, and the results across vector stores should generally be deterministic and consistent.

PARAMETER DESCRIPTION
method

The method being tested. For instance, get, aget, or search_with_embedding, etc.

TYPE: str

case

The case being tested.

TYPE: AdapterComplianceCase

RETURNS DESCRIPTION
list[str]

The expected animals.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
    """
    Return the IDs a case should produce; override to adjust per adapter.

    An override may call `pytest.xfail(reason)` for a case that is expected
    to fail, or `pytest.skip(reason)` for one that cannot be executed at all.

    Overrides should normally leave the expected results themselves alone.
    Changing them is only warranted when the adapter under test uses very
    different distance metrics or a different embedding: `AnimalEmbeddings`
    is deterministic, so results should be consistent across vector stores.

    Parameters
    ----------
    method :
        The method being tested. For instance, `get`, `aget`, or
        `search_with_embedding`, etc. The default implementation does not
        inspect it; it is there for overrides.
    case :
        The case being tested.

    Returns
    -------
    :
        The expected animals, i.e. `case.expected`.
    """
    # Capability checks run in this order; the first unmet requirement
    # marks the test as an expected failure.
    nested_supported = self.supports_nested_metadata()
    if not nested_supported and case.requires_nested:
        pytest.xfail("nested metadata not supported")

    dict_in_list_supported = self.supports_dict_in_list()
    if not dict_in_list_supported and case.requires_dict_in_list:
        pytest.xfail("dict-in-list fields is not supported")

    return case.expected

get_case

get_case(request) -> GetCase

Fixture providing the get and aget test cases.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
@pytest.fixture(params=GET_CASES, ids=lambda case: case.id)
def get_case(self, request) -> GetCase:
    """
    Parametrized fixture yielding one entry of `GET_CASES` per test.

    Each parameter is labelled with the case's `id`. It is consumed by
    `test_get` and `test_aget`.
    """
    current_case = request.param
    return current_case

search_case

search_case(request) -> SearchCase

Fixture providing the (a)?search and (a)?search_with_embedding test cases.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
@pytest.fixture(params=SEARCH_CASES, ids=lambda case: case.id)
def search_case(self, request) -> SearchCase:
    """
    Parametrized fixture yielding one entry of `SEARCH_CASES` per test.

    Each parameter is labelled with the case's `id`. It is consumed by the
    `(a)?search` and `(a)?search_with_embedding` tests.
    """
    current_case = request.param
    return current_case

supports_dict_in_list

supports_dict_in_list() -> bool

Return whether dicts can appear in list fields in metadata.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def supports_dict_in_list(self) -> bool:
    """
    Report whether dicts may appear inside list-valued metadata fields.

    Defaults to `True`. When a subclass returns `False`, `expected` marks
    every case with `requires_dict_in_list` set as an expected failure.
    """
    supported = True
    return supported

supports_nested_metadata

supports_nested_metadata() -> bool

Return whether nested metadata is expected to work.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def supports_nested_metadata(self) -> bool:
    """
    Report whether nested metadata is expected to work for this adapter.

    Defaults to `True`. When a subclass returns `False`, `expected` marks
    every case with `requires_nested` set as an expected failure.
    """
    supported = True
    return supported

test_aadjacent async

test_aadjacent(
    adapter: Adapter, adjacent_case: AdjacentCase
) -> None

Run tests for aadjacent.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
async def test_aadjacent(
    self, adapter: Adapter, adjacent_case: AdjacentCase
) -> None:
    """Run the `aadjacent` compliance check for one adjacent case."""
    # Resolve expectations first so an xfail/skip triggers before any
    # adapter call is made.
    want = self.expected("aadjacent", adjacent_case)

    # Only the embedding half of the pair is used; the search is issued
    # with k=0 and its results are discarded.
    query_embedding, _ = await adapter.asearch_with_embedding(
        adjacent_case.query, k=0
    )

    found = await adapter.aadjacent(
        edges=adjacent_case.edges,
        query_embedding=query_embedding,
        k=adjacent_case.k,
        filter=adjacent_case.filter,
    )
    assert_ids_any_order(found, want)

test_adjacent

test_adjacent(
    adapter: Adapter, adjacent_case: AdjacentCase
) -> None

Run tests for adjacent.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def test_adjacent(self, adapter: Adapter, adjacent_case: AdjacentCase) -> None:
    """Run the `adjacent` compliance check for one adjacent case."""
    # Resolve expectations first so an xfail/skip triggers before any
    # adapter call is made.
    want = self.expected("adjacent", adjacent_case)

    # Only the embedding half of the pair is used; the search is issued
    # with k=0 and its results are discarded.
    query_embedding, _ = adapter.search_with_embedding(adjacent_case.query, k=0)

    found = adapter.adjacent(
        edges=adjacent_case.edges,
        query_embedding=query_embedding,
        k=adjacent_case.k,
        filter=adjacent_case.filter,
    )
    assert_ids_any_order(found, want)

test_aget async

test_aget(adapter: Adapter, get_case: GetCase) -> None

Run tests for aget.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
async def test_aget(self, adapter: Adapter, get_case: GetCase) -> None:
    """Run the `aget` compliance check for one get case."""
    want = self.expected("aget", get_case)
    found = await adapter.aget(
        get_case.request,
        filter=get_case.filter,
    )
    assert_ids_any_order(found, want)

test_asearch async

test_asearch(
    adapter: Adapter, search_case: SearchCase
) -> None

Run tests for asearch.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
async def test_asearch(self, adapter: Adapter, search_case: SearchCase) -> None:
    """Run the `asearch` compliance check for one search case."""
    want = self.expected("asearch", search_case)

    # `asearch` takes an embedding rather than text, so embed the query
    # first; the k=0 search results are discarded.
    query_embedding, _ = await adapter.asearch_with_embedding(
        search_case.query, k=0
    )

    found = await adapter.asearch(query_embedding, **search_case.kwargs)
    assert_ids_any_order(found, want)

test_asearch_with_embedding async

test_asearch_with_embedding(
    adapter: Adapter, search_case: SearchCase
) -> None

Run tests for asearch_with_embedding.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
async def test_asearch_with_embedding(
    self, adapter: Adapter, search_case: SearchCase
) -> None:
    """Run the `asearch_with_embedding` compliance check for one search case."""
    want = self.expected("asearch_with_embedding", search_case)

    query_embedding, found = await adapter.asearch_with_embedding(
        search_case.query,
        **search_case.kwargs,
    )

    # Both halves of the returned pair are validated.
    assert_is_embedding(query_embedding)
    assert_ids_any_order(found, want)

test_get

test_get(adapter: Adapter, get_case: GetCase) -> None

Run tests for get.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def test_get(self, adapter: Adapter, get_case: GetCase) -> None:
    """Run the `get` compliance check for one get case."""
    want = self.expected("get", get_case)
    found = adapter.get(
        get_case.request,
        filter=get_case.filter,
    )
    assert_ids_any_order(found, want)

test_search

test_search(
    adapter: Adapter, search_case: SearchCase
) -> None

Run tests for search.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def test_search(self, adapter: Adapter, search_case: SearchCase) -> None:
    """Run the `search` compliance check for one search case."""
    want = self.expected("search", search_case)

    # `search` takes an embedding rather than text, so embed the query
    # first; the k=0 search results are discarded.
    query_embedding, _ = adapter.search_with_embedding(search_case.query, k=0)

    found = adapter.search(query_embedding, **search_case.kwargs)
    assert_ids_any_order(found, want)

test_search_with_embedding

test_search_with_embedding(
    adapter: Adapter, search_case: SearchCase
) -> None

Run tests for search_with_embedding.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def test_search_with_embedding(
    self, adapter: Adapter, search_case: SearchCase
) -> None:
    """Run the `search_with_embedding` compliance check for one search case."""
    want = self.expected("search_with_embedding", search_case)

    query_embedding, found = adapter.search_with_embedding(
        search_case.query,
        **search_case.kwargs,
    )

    # Both halves of the returned pair are validated.
    assert_is_embedding(query_embedding)
    assert_ids_any_order(found, want)

AdjacentCase dataclass

AdjacentCase(
    query: str,
    edges: set[Edge],
    k: int = 4,
    filter: dict[str, Any] | None = None,
    *,
    id: str,
    expected: list[str],
    requires_nested: bool = False,
    requires_dict_in_list: bool = False,
)

Bases: AdapterComplianceCase

A test case for adjacent and aadjacent.

GetCase dataclass

GetCase(
    request: list[str],
    filter: dict[str, Any] | None = None,
    *,
    id: str,
    expected: list[str],
    requires_nested: bool = False,
    requires_dict_in_list: bool = False,
)

Bases: AdapterComplianceCase

A test case for get and aget.

SearchCase dataclass

SearchCase(
    query: str,
    k: int | None = None,
    filter: dict[str, str] | None = None,
    *,
    id: str,
    expected: list[str],
    requires_nested: bool = False,
    requires_dict_in_list: bool = False,
)

Bases: AdapterComplianceCase

A test case for the search, search_with_embedding, asearch and asearch_with_embedding methods.

kwargs property

kwargs

Return keyword arguments for the test invocation.

assert_ids_any_order

assert_ids_any_order(
    results: Iterable[Content], expected: list[str]
) -> None

Assert the results are valid and match the IDs.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def assert_ids_any_order(
    results: Iterable[Content],
    expected: list[str],
) -> None:
    """
    Assert the results are valid and match the IDs, ignoring order.

    Parameters
    ----------
    results :
        The contents returned by the adapter. Any iterable is accepted,
        including one-shot iterators and generators.
    expected :
        The IDs that should be present.

    Raises
    ------
    AssertionError
        If any result is invalid, or if the set of result IDs differs from
        the set of expected IDs.

    Notes
    -----
    The comparison is set-based, so neither ordering nor duplicate IDs
    affect the outcome.
    """
    # `results` is only promised to be an Iterable and is traversed twice
    # below. Materialize it once so a one-shot iterator is not already
    # exhausted by the validation pass when the IDs are collected.
    contents = list(results)
    assert_valid_results(contents)

    result_ids = {content.id for content in contents}
    assert result_ids == set(expected), "should contain exactly expected IDs"

assert_is_embedding

assert_is_embedding(value: Any)

Assert the value is an embedding.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def assert_is_embedding(value: Any):
    """
    Assert that `value` looks like an embedding.

    An embedding here is a `list` whose items are all `float` instances.
    An empty list passes. Raises `AssertionError` otherwise.
    """
    assert isinstance(value, list)
    assert all(isinstance(component, float) for component in value)

assert_valid_result

assert_valid_result(content: Content)

Assert the content is valid.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def assert_valid_result(content: Content):
    """
    Assert that a single content item is well-formed.

    The `id` must be a `str` and the `embedding` must pass
    `assert_is_embedding`. The ID is checked first.
    """
    content_id = content.id
    assert isinstance(content_id, str)

    content_embedding = content.embedding
    assert_is_embedding(content_embedding)

assert_valid_results

assert_valid_results(docs: Iterable[Content])

Assert all of the contents are valid results.

Source code in packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py
def assert_valid_results(docs: Iterable[Content]):
    """
    Assert that every content item in `docs` is a valid result.

    Items are checked one at a time with `assert_valid_result`, so the
    first invalid item raises and later items are not examined.
    """
    for content in docs:
        assert_valid_result(content)

embeddings

AnimalEmbeddings

AnimalEmbeddings()

Bases: WordEmbeddings

Embeddings for animal test-case.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def __init__(self):
    """
    Initialize the word-based embedding with the animal vocabulary.

    The vocabulary is written as one whitespace-separated block and split
    into tokens, which are passed to the parent initializer as `words`.
    Repeated tokens are kept as they are, since `split()` does not
    de-duplicate.
    """
    vocabulary = """
        alli alpa amer amph ante ante antl appe aqua arct arma aust babo
        badg barr bask bear beav beet beha bird biso bite blac blue boar
        bobc brig buff bugl burr bush butt came cani capy cari carn cass
        cate cham chee chic chim chin chir clim coas coat cobr cock colo
        colo comm comp cour coyo crab cran croa croc crow crus wing wool
        cult cunn curi curl damb danc deer defe defe deme dese digg ding
        dise dist dive dolp dome dome donk dove drag drag duck ecos effo
        eigh elab eleg elev elon euca extr eyes falc famo famo fast fast
        feat feet ferr fier figh finc fish flam flig flig food fore foun
        fres frie frog gaze geck gees gent gill gira goat gori grac gras
        gras graz grou grou grou guin hams hard hawk hedg herb herd hero
        high hipp honk horn hors hove howl huma humm hump hunt hyen iden
        igua inde inse inte jack jagu jell jump jung kang koal komo lark
        larv lemu leop life lion liza lobs long loud loya mada magp mamm
        mana mari mars mass mati meat medi melo meta migr milk mimi moos
        mosq moth narw nati natu neck newt noct nort ocea octo ostr pack
        pain patt peac pest pinc pink play plum poll post pouc powe prec
        pred prey prid prim prob prot prow quac quil rais rapi reac rega
        rege regi rego rein rept resi rive roam rode sava scav seab seaf
        seas semi shar shed shel skil smal snak soci soft soli song song
        soun sout spec spee spik spor spot stag stic stin stin stor stre
        stre stre stro surv surv sust symb tail tall talo team teet tent
        term terr thou tiny tong toug tree agil tuft tund tusk umbr unic
        uniq vast vege veno vibr vita vora wadi wasp wate webb wetl wild
        ant bat bee cat cow dog eel elk emu fox pet pig""".split()
    super().__init__(words=vocabulary)

__call__

__call__(text: str) -> list[float]

Return the embedding.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def __call__(self, text: str) -> list[float]:
    """
    Return the embedding of `text`, one component per vocabulary word.

    A word that occurs in `text` as a substring contributes
    `1.0 + 100 / offset`, using that word's offset. A word that does not
    occur contributes `0.2 / (index + 1)`.
    """
    vector: list[float] = []
    for index, word in enumerate(self._words):
        if word in text:
            component = 1.0 + (100 / self._offsets[index])
        else:
            component = 0.2 / (index + 1)
        vector.append(component)
    return vector

ParserEmbeddings

ParserEmbeddings(dimension: int = 10)

Parse the text as a list of floats, otherwise return zeros.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def __init__(self, dimension: int = 10) -> None:
    """
    Store the embedding dimension.

    Parameters
    ----------
    dimension :
        Length of the vectors this embedding returns. It is used both to
        check parsed input and to size the all-zero fallback vector.
    """
    self.dimension = dimension

__call__

__call__(text: str) -> list[float]

Return the embedding.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def __call__(self, text: str) -> list[float]:
    """
    Return the embedding parsed from `text`, or zeros if it is not a list.

    Parameters
    ----------
    text :
        Expected to be a JSON array with `self.dimension` entries.

    Returns
    -------
    :
        The parsed list when `text` is a JSON array. Otherwise a vector of
        `self.dimension` zeros, both when `text` is not valid JSON and when
        it is valid JSON of another type (number, string, object, `null`).

    Raises
    ------
    AssertionError
        If `text` is a JSON array whose length differs from
        `self.dimension`.
    """
    zeros = [0.0] * self.dimension
    try:
        vals = json.loads(text)
    except json.JSONDecodeError:
        return zeros

    # Valid JSON is not necessarily a list. A bare number or `null` has no
    # `len()`, and a string or object of the right length would otherwise
    # be returned as if it were a vector.
    if not isinstance(vals, list):
        return zeros

    assert len(vals) == self.dimension, (
        f"expected {self.dimension} values, got {len(vals)}"
    )
    return vals

WordEmbeddings

WordEmbeddings(words: list[str])

Embeddings based on a word list.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def __init__(self, words: list[str]):
    """
    Store the vocabulary and precompute one signed offset per word.

    Parameters
    ----------
    words :
        The vocabulary. Each entry yields one embedding component.

    Notes
    -----
    The offset for the word at position `i` is `_string_to_number(word)`
    with alternating sign: positive at even positions, negative at odd
    ones.
    """
    self._words = words

    offsets = []
    for position, word in enumerate(words):
        sign = (-1) ** position  # +1 for even positions, -1 for odd
        offsets.append(_string_to_number(word) * sign)
    self._offsets = offsets

__call__

__call__(text: str) -> list[float]

Return the embedding.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def __call__(self, text: str) -> list[float]:
    """
    Return the embedding of `text`, one component per vocabulary word.

    For the word at position `i`: if it occurs in `text` as a substring,
    the component is `1.0 + 100 / offsets[i]`. Otherwise it is
    `0.2 / (i + 1)`.
    """

    def component(position: int, word: str) -> float:
        # Present words depend on the word's offset; absent words get a
        # value that depends only on the position.
        if word in text:
            return 1.0 + (100 / self._offsets[position])
        return 0.2 / (position + 1)

    return [component(position, word) for position, word in enumerate(self._words)]

angular_2d_embedding

angular_2d_embedding(text: str) -> list[float]

Convert input text to a 'vector' (list of floats).

PARAMETER DESCRIPTION
text

The text to embed.

TYPE: str

RETURNS DESCRIPTION
list[float]

If the text is a number, use it as the angle for the unit vector in units of pi.

Any other input text becomes the singular result [0, 0].

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def angular_2d_embedding(text: str) -> list[float]:
    """
    Map input text to a 2-D 'vector' (list of floats).

    Parameters
    ----------
    text: str
        The text to embed.

    Returns
    -------
    :
        When the text parses as a number, the unit vector at that angle,
        with the angle given in units of pi.

        Any text that raises `ValueError` during parsing or the
        trigonometric evaluation becomes the singular result `[0, 0]`.
    """
    try:
        # The trig calls stay inside the `try`: an infinite angle makes
        # `math.cos` raise ValueError, which also maps to the fallback.
        radians = float(text) * math.pi
        x_component = math.cos(radians)
        y_component = math.sin(radians)
    except ValueError:
        # Arbitrary test strings carry no angle; their values are ignored.
        return [0.0, 0.0]
    return [x_component, y_component]

earth_embeddings

earth_embeddings(text: str) -> list[float]

Split words and return a vector based on that.

Source code in packages/graph-retriever/src/graph_retriever/testing/embeddings.py
def earth_embeddings(text: str) -> list[float]:
    """
    Return a 2-D vector chosen by which keywords appear in `text`.

    The text is lower-cased and split on whitespace. The first coordinate
    is anchored at 0.9 when the word "earth" is present, at 0.8 when any of
    "planet", "world", "globe" or "sphere" is present, and at 0.1
    otherwise. The second coordinate is `sqrt(1 - anchor**2)`. A random
    fluctuation below 0.01 is added to the first coordinate and subtracted
    from the second, so repeated calls return slightly different vectors.
    """
    tokens = set(text.lower().split())

    if "earth" in tokens:
        anchor = 0.9
    elif tokens & {"planet", "world", "globe", "sphere"}:
        anchor = 0.8
    else:
        anchor = 0.1

    x_base = anchor
    y_base = (1 - anchor**2) ** 0.5
    fluctuation = random.random() / 100.0
    return [x_base + fluctuation, y_base - fluctuation]