pymc.testing.mock_sample

pymc.testing.mock_sample(draws=10, sample_stats=None, **kwargs)

Mock pymc.sample() by drawing from the prior with pymc.sample_prior_predictive().

Useful for testing code that calls pm.sample without running MCMC sampling.

Examples

Using mock_sample with pytest

Note

For pytest fixtures, you can use pymc.testing.mock_sample_setup_and_teardown() directly instead of writing the patch-and-restore logic by hand as in the example below.

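import pytest

import pymc as pm
from pymc.testing import mock_sample

@pytest.fixture(scope="module")
def mock_pymc_sample():
    # Setup: swap the real sampler for the mock for every test in this module.
    original_sample = pm.sample
    pm.sample = mock_sample

    yield

    # Teardown: restore the real sampler once the module's tests have finished.
    pm.sample = original_sample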
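The fixture works by rebinding the sample attribute of the pymc module, so it only affects code that looks up pm.sample at call time. Code that bound the function earlier, for example with `from pymc import sample` in a module imported before the fixture ran, keeps calling the real sampler. This stdlib-only sketch illustrates the effect with a hypothetical stand-in namespace, `fake_pm`, rather than PyMC itself.

```python
import types

# Hypothetical stand-in for the pymc module; only attribute lookup matters here.
fake_pm = types.SimpleNamespace(sample=lambda **kwargs: "real MCMC")


def fake_mock_sample(**kwargs):
    return "mocked"


def code_under_test():
    # Looks up fake_pm.sample at call time, like calling pm.sample(...).
    return fake_pm.sample()


# Bound before patching, like `from pymc import sample` at import time.
early_bound_sample = fake_pm.sample

original_sample = fake_pm.sample
fake_pm.sample = fake_mock_sample
try:
    patched_result = code_under_test()         # sees the mock
    early_bound_result = early_bound_sample()  # still the original function
finally:
    # Same teardown step as the fixture: restore the original attribute.
    fake_pm.sample = original_sample

restored_result = code_under_test()
print(patched_result, early_bound_result, restored_result)
```

If the code under test imports sample by name, patch that name in the importing module (for example with pytest's monkeypatch.setattr) rather than on pymc.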

By default, mock_sample does not create a sample_stats group. To add one, pass sample_stats: a dictionary whose keys are the names of the statistics and whose values are functions that take a size tuple and return an array of that shape.

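from functools import partial

import numpy as np
from numpy.typing import NDArray

from pymc.testing import mock_sample

rng = np.random.default_rng()

def mock_diverging(size: tuple[int, int]) -> NDArray:
    # Report no divergent transitions; NUTS stores this statistic as booleans.
    return np.zeros(size, dtype=bool)

def mock_tree_depth(size: tuple[int, int]) -> NDArray:
    # Random integer tree depths from 2 to 9 inclusive.
    return rng.integers(2, 10, size=size)

# A drop-in replacement for pm.sample that also creates a sample_stats group.
mock_sample_with_stats = partial(
    mock_sample,
    sample_stats={
        "diverging": mock_diverging,
        "tree_depth": mock_tree_depth,
    },
)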