Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/docs/tutorials/advancedfitting/bayesian_bumps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"- Building a reflectometry model with realistic bounds (critical for MCMC).\n",
"- Classical optimisation first (good starting point for Bayesian sampling).\n",
"- High-level DREAM MCMC sampling via\n",
" ``MultiFitter.sample()`` and ``PosteriorResults``.\n",
" ``MultiFitter.mcmc_sample()`` and ``PosteriorResults``.\n",
"- Posterior inspection: summary table, corner plot, trace plot, credible\n",
" intervals, Gelman-Rubin R-hat.\n",
"- Posterior-predictive checks: reflectivity and SLD profile with 95 %\n",
Expand Down Expand Up @@ -190,7 +190,7 @@
"outputs": [],
"source": [
"# ---- Bayesian MCMC sampling -------------------------------------------------\n",
"# ``MultiFitter.sample()`` delegates to the BUMPS DREAM sampler.\n",
"# ``MultiFitter.mcmc_sample()`` delegates to the BUMPS DREAM sampler.\n",
"# All keyword arguments are forwarded with user-friendly names:\n",
"# ``samples`` ← total retained samples\n",
"# ``burn`` ← burn‑in steps\n",
Expand All @@ -199,7 +199,7 @@
"# ``population``← BUMPS‑native ``pop`` for advanced users\n",
"# ``seed`` ← random seed for reproducibility\n",
"\n",
"posterior_dict = fitter.sample(\n",
"posterior_dict = fitter.mcmc_sample(\n",
" data,\n",
" samples=2000, # Short for demo; use 20 k+ in production\n",
" burn=500,\n",
Expand Down Expand Up @@ -450,7 +450,7 @@
"print('―' * 60)\n",
"print()\n",
"print('API surface demonstrated:')\n",
"print(' MultiFitter.sample(data, samples=, burn=, thin=, seed=)')\n",
"print(' MultiFitter.mcmc_sample(data, samples=, burn=, thin=, seed=)')\n",
"print(' PosteriorResults(draws, param_names, logp=, sampler_state=)')\n",
"print(' .summary() — formatted parameter table')\n",
"print(' .corner() — pairwise correlation plot')\n",
Expand Down
20,602 changes: 10,473 additions & 10,129 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
]
requires-python = '>=3.11'
dependencies = [
'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian',
'easyscience @ git+https://github.com/easyscience/corelib.git@bayesian_mp',
# 'easyscience',
'scipp',
'refnx',
Expand Down Expand Up @@ -69,10 +69,10 @@ dev = [
'mkdocstrings-python', # MkDocs: Python docstring support
'pyyaml', # YAML parser
'spdx-headers', # SPDX license header validation
'corner', # Bayesian analysis and plotting
'arviz', # Bayesian analysis and plotting
]

bayesian = ["corner>=2.2", "arviz>=0.18"]

[project.urls]
Documentation = 'https://easyscience.github.io/reflectometry-lib'
'Release Notes' = 'https://github.com/easyscience/reflectometry-lib/releases'
Expand Down
29 changes: 29 additions & 0 deletions src/easyreflectometry/calculators/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@ def __init__(self):
"""Init function."""
super().__init__(interface_list=CalculatorBase._calculators)

def __reduce__(self):
"""Serialize the active calculator state for worker processes."""
wrapper = getattr(self(), '_wrapper', None)
if wrapper is None and self.current_interface_name is not None:
raise RuntimeError(
f'Cannot pickle CalculatorFactory: active interface '
f"{self.current_interface_name!r} exposes no '_wrapper' attribute. "
'The InterfaceFactoryTemplate API may have changed.'
)
return (
self._state_restore,
(
self.__class__,
self.current_interface_name,
wrapper.__getstate__() if wrapper is not None else None,
),
)

@staticmethod
def _state_restore(cls, interface_str, wrapper_state):
"""Restore a calculator factory with its active wrapper state."""
obj = cls()
if interface_str is not None and interface_str in obj.available_interfaces:
obj.switch(interface_str)
wrapper = getattr(obj(), '_wrapper', None)
if wrapper is not None and wrapper_state is not None:
wrapper.__setstate__(wrapper_state)
return obj

def reset_storage(self) -> None:
"""Reset storage."""
return self().reset_storage()
Expand Down
12 changes: 12 additions & 0 deletions src/easyreflectometry/calculators/wrapper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,18 @@ def get_item_value(self, name: str, key: str) -> float:
item = getattr(item, key)
return getattr(item, 'value')

def __getstate__(self) -> dict:
return {
'storage': self.storage,
'resolution_function': self._resolution_function,
'magnetism': self._magnetism,
}

def __setstate__(self, state: dict) -> None:
self.storage = state['storage']
self._resolution_function = state['resolution_function']
self._magnetism = state['magnetism']

def set_resolution_function(self, resolution_function: ResolutionFunction) -> None:
"""Set the resolution function for the calculator.

Expand Down
48 changes: 31 additions & 17 deletions src/easyreflectometry/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def fit_single_data_set_1d(self, data: DataSet1D, objective: str | None = None)
]
return result

def sample(
def mcmc_sample(
self,
data: sc.DataGroup,
samples: int = 10000,
Expand All @@ -364,6 +364,7 @@ def sample(
seed: int | None = None,
objective: str | None = None,
initializer: str | None = None,
n_workers: int | None = None,
progress_callback=None,
abort_test=None,
) -> dict:
Expand All @@ -383,12 +384,22 @@ def sample(
:param initializer: DREAM population initializer. One of ``'eps'``,
``'cov'``, ``'lhs'``, or ``'random'``. By default, None (BUMPS
uses ``'eps'``).
:param n_workers: Number of worker processes for parallel DREAM
population evaluation. ``None`` (default) and ``1`` use
sequential evaluation. Values greater than ``1`` enable
multiprocessing; the effective pool size is capped at
``min(n_workers, population)``.
:param progress_callback: Optional callback for progress updates during
sampling. Forwarded to the core MultiFitter.
:return: Dictionary with keys ``'draws'``, ``'param_names'``, ``'state'``,
and ``'logp'``.
:param abort_test: Optional callback that returns ``True`` to signal
that sampling should be aborted.
:return: Dictionary with keys ``'draws'``, ``'param_names'``,
``'internal_bumps_object'``, and ``'logp'``.
:raises RuntimeError: If the current minimizer is not a BUMPS instance.
:raises ValueError: If ``n_workers`` is not None and less than 1.
"""
if n_workers is not None and n_workers < 1:
raise ValueError(f'n_workers must be a positive integer or None, got {n_workers}')
obj = _validate_objective(objective) if objective is not None else self._objective

refl_nums = [k[3:] for k in data['coords'].keys() if 'Qz' == k[:2]]
Expand Down Expand Up @@ -417,20 +428,23 @@ def sample(
sampler_kwargs = {}
if initializer is not None:
sampler_kwargs['init'] = initializer
return self.easy_science_multi_fitter.sample(
x=x,
y=y,
weights=dy,
samples=samples,
burn=burn,
thin=thin,
chains=chains,
population=population,
seed=seed,
sampler_kwargs=sampler_kwargs or None,
progress_callback=progress_callback,
abort_test=abort_test,
)
core_sample_kwargs = {
'x': x,
'y': y,
'weights': dy,
'samples': samples,
'burn': burn,
'thin': thin,
'chains': chains,
'population': population,
'seed': seed,
'sampler_kwargs': sampler_kwargs or None,
'progress_callback': progress_callback,
'abort_test': abort_test,
}
if n_workers is not None:
core_sample_kwargs['n_workers'] = n_workers
return self.easy_science_multi_fitter.mcmc_sample(**core_sample_kwargs)

@property
def chi2(self) -> float | None:
Expand Down
48 changes: 48 additions & 0 deletions tests/calculators/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors <https://github.com/easyscience>
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for CalculatorFactory serialization."""

import pickle # noqa: S403

import numpy as np
from numpy.testing import assert_allclose

from easyreflectometry.calculators import CalculatorFactory
from easyreflectometry.model import Model
from easyreflectometry.model import PercentageFwhm
from easyreflectometry.sample import Layer
from easyreflectometry.sample import Material
from easyreflectometry.sample import Multilayer
from easyreflectometry.sample import Sample


def test_calculator_factory_pickle_preserves_active_wrapper_storage():
"""Pickled calculator factories retain model storage for worker processes."""
si = Material(sld=2.07, isld=0.0, name='Si')
film = Material(sld=2.0, isld=0.0, name='Film')
d2o = Material(sld=6.36, isld=0.0, name='D2O')

sample = Sample(
Multilayer(Layer(material=si, thickness=0.0, roughness=3.0, name='Si')),
Multilayer(Layer(material=film, thickness=250.0, roughness=3.0, name='Film')),
Multilayer(Layer(material=d2o, thickness=0.0, roughness=3.0, name='D2O')),
)
model = Model(
sample=sample,
scale=1.0,
background=1e-6,
resolution_function=PercentageFwhm(0.02),
)
interface = CalculatorFactory()
interface.switch('refnx')
model.interface = interface

restored = pickle.loads(pickle.dumps(interface)) # noqa: S301

assert model.unique_name in restored()._wrapper.storage['model']
q = np.linspace(0.01, 0.3, 10)
assert_allclose(
restored.fit_func(q, model.unique_name),
interface.fit_func(q, model.unique_name),
)
Loading
Loading