pymc.testing.mock_sample
- pymc.testing.mock_sample(draws=10, sample_stats=None, **kwargs)
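Mock pymc.sample() with pymc.sample_prior_predictive(), so that draws come from the prior instead of from an MCMC run. This is useful for testing models and code that call pm.sample without paying the cost of actually running MCMC sampling.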
Examples
Using mock_sample with pytest
Note
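Use pymc.testing.mock_sample_setup_and_teardown() directly for pytest fixtures.

```python
import pytest

import pymc as pm
from pymc.testing import mock_sample


@pytest.fixture(scope="module")
def mock_pymc_sample():
    # Swap the real sampler for the mock for every test in this module.
    original_sample = pm.sample
    pm.sample = mock_sample

    yield

    # Restore the real sampler once the module's tests have finished.
    pm.sample = original_sample
```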
By default, the sample_stats group is not created. To create it, pass a dictionary of functions that generate sample statistics as the sample_stats argument: each key is the name of a statistic, and each value is a function that takes a size tuple and returns an array of that size.
```python
from functools import partial

import numpy as np
from numpy.typing import NDArray

from pymc.testing import mock_sample


def mock_diverging(size: tuple[int, int]) -> NDArray:
    # No divergences: an all-zero array of the requested size.
    return np.zeros(size)


def mock_tree_depth(size: tuple[int, int]) -> NDArray:
    # Random tree depths drawn uniformly from 2 to 9 inclusive.
    return np.random.choice(range(2, 10), size=size)


mock_sample_with_stats = partial(
    mock_sample,
    sample_stats={
        "diverging": mock_diverging,
        "tree_depth": mock_tree_depth,
    },
)
```