Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
9256086
feat(ggm): log Z(G) closed-form primitives for hierarchical-spec MH
MaartenMarsman May 19, 2026
c2e7511
feat(ggm): general-alpha incremental log_Z_NLO_gamma (Phase 1b)
MaartenMarsman May 19, 2026
a3acba5
perf(ggm): V_a-pivot enumeration for log_Z_NLO_gamma partial sum
MaartenMarsman May 19, 2026
7372617
feat(ggm): DEGORD-permuted Bartlett-Cholesky importance sampler (Phas…
MaartenMarsman May 19, 2026
e6d8112
feat(ggm): V(Γ, U) Russian-Roulette estimator of 1/Z(Γ) (Phase 3)
MaartenMarsman May 19, 2026
29293bc
feat(ggm): hierarchical-spec MH hook in update_edge_indicator_paramet…
MaartenMarsman May 19, 2026
e0f9b80
feat(bgm): R API surface for graph_prior_spec = "hierarchical" (Phase…
MaartenMarsman May 19, 2026
f162e2d
fix(ci): ship live testthat fixtures + strip lint semicolons
MaartenMarsman May 19, 2026
843f63d
test(ggm): SBC for hierarchical-spec at q = 5 across the delta sweep …
MaartenMarsman May 19, 2026
c6279d3
feat(ggm): log-space V with within-toggle cache reuse (F5 + F6)
MaartenMarsman May 20, 2026
fd4b270
test(ggm): NUTS + hierarchical 2x2 API smoke (F1)
MaartenMarsman May 20, 2026
0c7cdd9
perf(ggm): SMW rank-2 cov update + drift-triggered refresh (F8)
MaartenMarsman May 20, 2026
effe8b0
feat(ggm): sign(V) and log|V| per-iteration diagnostics (F2)
MaartenMarsman May 20, 2026
5f09826
feat(bgms): sign-corrected posterior-mean helper (F3 + F4)
MaartenMarsman May 20, 2026
4cf5201
fix(ggm): correct sign of V-correction in hierarchical-spec MH ratio
MaartenMarsman May 20, 2026
2055394
feat(ggm): MH-on-U fix + manuscript NLO behind A/B flags
MaartenMarsman May 22, 2026
679a79f
Merge remote-tracking branch 'origin/main' into feat/log-z-nlo-gamma
MaartenMarsman May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
^paper$
^LICENSE$

# ---- Legacy test fixtures (GitHub CI only, not shipped to CRAN) ----
^tests/testthat/fixtures$
# ---- Legacy test fixtures: only the `legacy` subdirectory is excluded
# so it does not bloat the tarball; live fixtures used by tests in
# R CMD check must ship so the matching tests can run.
^tests/testthat/fixtures/legacy$

# ---- C/C++ build artifacts (REQUIRED) ----
^src/.*\.o$
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export(beta_bernoulli_prior)
export(beta_prime_prior)
export(bgm)
export(bgmCompare)
export(bgms_posterior_mean)
export(cauchy_prior)
export(exponential_prior)
export(extract_arguments)
Expand Down
76 changes: 74 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,78 @@ sample_ggm_prior_cpp <- function(p, n_samples, n_warmup = 1000L, pairwise_scale
.Call(`_bgms_sample_ggm_prior`, p, n_samples, n_warmup, pairwise_scale, interaction_prior_type, scale_prior_type, gamma_shape, gamma_rate, step_size, max_depth, seed, verbose, edge_indicators_nullable, delta)
}

log_Z_NLO_gamma_cpp <- function(G, alpha, beta, sigma, include_F = FALSE, delta = 0.0) {
.Call(`_bgms_log_Z_NLO_gamma_cpp`, G, alpha, beta, sigma, include_F, delta)
}

log_Z_manuscript_NLO_alpha1_cpp <- function(G, beta, sigma, delta) {
.Call(`_bgms_log_Z_manuscript_NLO_alpha1_cpp`, G, beta, sigma, delta)
}

log_Z_manuscript_NLO_alpha1_degord_cpp <- function(G, i, j, beta, sigma, delta) {
.Call(`_bgms_log_Z_manuscript_NLO_alpha1_degord_cpp`, G, i, j, beta, sigma, delta)
}

log_Z_NLO_gamma_degord_cpp <- function(G, i, j, alpha, beta, sigma, include_F = FALSE, delta = 0.0) {
.Call(`_bgms_log_Z_NLO_gamma_degord_cpp`, G, i, j, alpha, beta, sigma, include_F, delta)
}

log_Z_NLO_gamma_delta_incr_alpha1_cpp <- function(G_before, i, j, beta, sigma, delta, include_F = FALSE) {
.Call(`_bgms_log_Z_NLO_gamma_delta_incr_alpha1_cpp`, G_before, i, j, beta, sigma, delta, include_F)
}

log_Z_NLO_gamma_delta_incr_alphaN_cpp <- function(G_before, i, j, alpha, beta, sigma, delta, include_F = FALSE) {
.Call(`_bgms_log_Z_NLO_gamma_delta_incr_alphaN_cpp`, G_before, i, j, alpha, beta, sigma, delta, include_F)
}

degord_chain_aux_cpp <- function(q, alpha, beta, sigma, delta) {
.Call(`_bgms_degord_chain_aux_cpp`, q, alpha, beta, sigma, delta)
}

degord_pi_aux_cpp <- function(G_pi, alpha, beta, sigma, delta) {
.Call(`_bgms_degord_pi_aux_cpp`, G_pi, alpha, beta, sigma, delta)
}

degord_permute_graph_cpp <- function(G, i, j) {
.Call(`_bgms_degord_permute_graph_cpp`, G, i, j)
}

degord_log_Zhat_pi_from_pool_cpp <- function(noise_pool_t, G_pi, alpha, beta, sigma, delta, slab_tilt_mode = 0L) {
.Call(`_bgms_degord_log_Zhat_pi_from_pool_cpp`, noise_pool_t, G_pi, alpha, beta, sigma, delta, slab_tilt_mode)
}

degord_delta_log_Zhat_pi_toggle_cpp <- function(noise_pool, noise_pool_t, G_curr, i, j, alpha, beta, sigma, delta, slab_tilt_mode = 0L) {
.Call(`_bgms_degord_delta_log_Zhat_pi_toggle_cpp`, noise_pool, noise_pool_t, G_curr, i, j, alpha, beta, sigma, delta, slab_tilt_mode)
}

degord_draw_bartlett_pool_cpp <- function(q, M_inner, seed) {
.Call(`_bgms_degord_draw_bartlett_pool_cpp`, q, M_inner, seed)
}

degord_V_at_Gamma_pi_cpp <- function(K_depth, pools_t, G_pi, alpha, beta, sigma, delta, c_val, rho, slab_tilt_mode = 0L) {
.Call(`_bgms_degord_V_at_Gamma_pi_cpp`, K_depth, pools_t, G_pi, alpha, beta, sigma, delta, c_val, rho, slab_tilt_mode)
}

degord_V_log_at_Gamma_pi_cpp <- function(K_depth, pools_t, G_pi, alpha, beta, sigma, delta, log_c, rho, slab_tilt_mode = 0L) {
.Call(`_bgms_degord_V_log_at_Gamma_pi_cpp`, K_depth, pools_t, G_pi, alpha, beta, sigma, delta, log_c, rho, slab_tilt_mode)
}

degord_V_log_pair_at_Gamma_curr_star_cpp <- function(K_depth, pools_t, G_pi_curr, G_pi_star, alpha, beta, sigma, delta, log_c_curr, log_c_star, rho, slab_tilt_mode = 0L) {
.Call(`_bgms_degord_V_log_pair_at_Gamma_curr_star_cpp`, K_depth, pools_t, G_pi_curr, G_pi_star, alpha, beta, sigma, delta, log_c_curr, log_c_star, rho, slab_tilt_mode)
}

degord_log_Zhat_star_from_cache_cpp <- function(noise_pool_t, G_pi_curr, G_pi_star, alpha, beta, sigma, delta, slab_tilt_mode = 0L) {
.Call(`_bgms_degord_log_Zhat_star_from_cache_cpp`, noise_pool_t, G_pi_curr, G_pi_star, alpha, beta, sigma, delta, slab_tilt_mode)
}

degord_draw_U_rr_cpp <- function(M_inner, q, rho, seed) {
.Call(`_bgms_degord_draw_U_rr_cpp`, M_inner, q, rho, seed)
}

ggm_hierarchical_smoke_cpp <- function(observations, inclusion_prob, interaction_scale, diagonal_shape, diagonal_rate, delta, M_inner, kappa, rho, n_sweeps, seed, use_manuscript_nlo = FALSE) {
.Call(`_bgms_ggm_hierarchical_smoke_cpp`, observations, inclusion_prob, interaction_scale, diagonal_shape, diagonal_rate, delta, M_inner, kappa, rho, n_sweeps, seed, use_manuscript_nlo)
}

.compute_ess_cpp <- function(array3d) {
.Call(`_bgms_compute_ess_cpp`, array3d)
}
Expand Down Expand Up @@ -113,8 +185,8 @@ ggm_test_logp_and_gradient_prior <- function(theta, suf_stat, n, edge_indicators
.Call(`_bgms_ggm_test_logp_and_gradient_prior`, theta, suf_stat, n, edge_indicators, interaction_prior_type, interaction_scale, interaction_alpha, interaction_beta, diagonal_prior_type, diagonal_shape, diagonal_rate)
}

sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL, delta = 0.0) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, delta)
sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL, delta = 0.0, graph_prior_spec = "joint", z_ratio_M_inner = 100L, z_ratio_kappa = 1.0, z_ratio_rho = 0.5, use_manuscript_nlo = FALSE, mh_U = FALSE) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, delta, graph_prior_spec, z_ratio_M_inner, z_ratio_kappa, z_ratio_rho, use_manuscript_nlo, mh_U)
}

sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL, delta = 0.0) {
Expand Down
28 changes: 28 additions & 0 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@
#' apply the tilt. Not allowed for pure ordinal models (no precision
#' matrix to tilt).
#'
#' @param graph_prior_spec Character; one of \code{"joint"} (default)
#' or \code{"hierarchical"}. Controls the marginal prior on the graph
#' indicators \eqn{\Gamma}. Under \code{"joint"} the implicit
#' \eqn{\Gamma}-marginal is \eqn{\pi(\Gamma) \cdot Z(\Gamma)}, where
#' \eqn{Z(\Gamma)} is the normalising constant of the precision-matrix
#' prior conditional on the graph. Under \code{"hierarchical"} the
#' chain compensates with an unbiased estimator of
#' \eqn{Z(\Gamma_\text{curr}) / Z(\Gamma_\text{star})}, recovering
#' \eqn{\pi(\Gamma)} as the \eqn{\Gamma}-marginal. Only supported when
#' the interaction prior is \code{normal_prior(...)} and the precision-
#' scale prior is \code{gamma_prior(...)} (the prior families for which
#' the closed-form Laplace-NLO normaliser approximation is implemented).
#' Default: \code{"joint"}.
#'
#' @param z_ratio_tuning Named list with components \code{M_inner}
#' (positive integer, default 100), \code{kappa} (positive numeric,
#' default 1.0), and \code{rho} (numeric in (0, 1), default 0.5).
#' Tuning knobs for the V/Russian-Roulette estimator used when
#' \code{graph_prior_spec = "hierarchical"}; ignored otherwise.
#' \code{M_inner} is the number of inner Bartlett-Cholesky importance
#' samples per Russian-Roulette pool, \code{kappa} sets the analytic
#' centring \eqn{c = \kappa \exp(\log Z_\text{NLO}(\Gamma))}, and
#' \code{rho} is the geometric-truncation continuation probability.
#'
#' @param pairwise_scale `r lifecycle::badge("deprecated")` Double.
#' Scale of the Cauchy prior for pairwise
#' interaction parameters. Use \code{interaction_prior} instead.
Expand Down Expand Up @@ -343,6 +367,8 @@ bgm = function(
means_prior = normal_prior(scale = 1),
precision_scale_prior = gamma_prior(shape = 1, rate = 1),
delta = NULL,
graph_prior_spec = c("joint", "hierarchical"),
z_ratio_tuning = list(M_inner = 100L, kappa = 1.0, rho = 0.5),
edge_selection = TRUE,
edge_prior = bernoulli_prior(0.5),
na_action = c("listwise", "impute"),
Expand Down Expand Up @@ -511,6 +537,8 @@ bgm = function(
scale_shape = sp$scale_shape,
scale_rate = sp$scale_rate,
delta = delta,
graph_prior_spec = graph_prior_spec,
z_ratio_tuning = z_ratio_tuning,
standardize = standardize,
edge_selection = edge_selection,
edge_prior = edge_prior,
Expand Down
71 changes: 71 additions & 0 deletions R/bgm_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ bgm_spec = function(x,
scale_shape = 1,
scale_rate = 1,
delta = NULL,
graph_prior_spec = c("joint", "hierarchical"),
z_ratio_tuning = list(M_inner = 100L,
kappa = 1.0,
rho = 0.5),
standardize = FALSE,
edge_selection = TRUE,
edge_prior = bernoulli_prior(0.5),
Expand Down Expand Up @@ -364,6 +368,65 @@ bgm_spec = function(x,
!is.finite(delta) || delta < 0) {
stop("'delta' must be a single finite non-negative numeric, or NULL.")
}
# Validate hierarchical-spec args (only meaningful for ggm/mixed_mrf).
graph_prior_spec = if(is.character(graph_prior_spec) &&
length(graph_prior_spec) > 1L) {
match.arg(graph_prior_spec)
} else {
if(!(length(graph_prior_spec) == 1L &&
is.character(graph_prior_spec) &&
graph_prior_spec %in% c("joint", "hierarchical"))) {
stop("'graph_prior_spec' must be \"joint\" or \"hierarchical\".")
}
graph_prior_spec
}
if(graph_prior_spec == "hierarchical" &&
!(model_type %in% c("ggm", "mixed_mrf"))) {
stop(
"'graph_prior_spec = \"hierarchical\"' requires continuous data; ",
"the current model_type is '", model_type, "', which has no ",
"continuous precision block. Use \"joint\" or supply continuous data."
)
}
if(graph_prior_spec == "hierarchical" &&
interaction_prior_type != "normal") {
stop(
"'graph_prior_spec = \"hierarchical\"' requires a Normal slab ",
"prior (interaction_prior_type = \"normal\"). Re-fit with ",
"interaction_prior = normal_prior(scale = ...)."
)
}
if(graph_prior_spec == "hierarchical" &&
scale_prior_type != "gamma") {
stop(
"'graph_prior_spec = \"hierarchical\"' requires a Gamma diagonal ",
"prior (scale_prior_type = \"gamma\")."
)
}
# Validate z_ratio_tuning shape (only enforced if hierarchical; for joint
# the defaults pass through unused).
if(!is.list(z_ratio_tuning))
stop("'z_ratio_tuning' must be a list with components M_inner, kappa, rho.")
zrt_M_inner = z_ratio_tuning$M_inner %||% 100L
zrt_kappa = z_ratio_tuning$kappa %||% 1.0
zrt_rho = z_ratio_tuning$rho %||% 0.5
zrt_use_manuscript_nlo = isTRUE(z_ratio_tuning$use_manuscript_nlo)
zrt_mh_U = isTRUE(z_ratio_tuning$mh_U)
if(!is.numeric(zrt_M_inner) || length(zrt_M_inner) != 1L ||
!is.finite(zrt_M_inner) || zrt_M_inner < 1L)
stop("'z_ratio_tuning$M_inner' must be a positive integer.")
if(!is.numeric(zrt_kappa) || length(zrt_kappa) != 1L ||
!is.finite(zrt_kappa) || zrt_kappa <= 0)
stop("'z_ratio_tuning$kappa' must be a positive number.")
if(!is.numeric(zrt_rho) || length(zrt_rho) != 1L ||
!is.finite(zrt_rho) || zrt_rho <= 0 || zrt_rho >= 1)
stop("'z_ratio_tuning$rho' must be in (0, 1).")
z_ratio_tuning = list(M_inner = as.integer(zrt_M_inner),
kappa = as.numeric(zrt_kappa),
rho = as.numeric(zrt_rho),
use_manuscript_nlo = zrt_use_manuscript_nlo,
mh_U = zrt_mh_U)

if(delta > 0 && model_type %in% c("omrf", "compare")) {
stop(
"'delta' (determinant tilt) requires continuous variables; the ",
Expand Down Expand Up @@ -444,6 +507,8 @@ bgm_spec = function(x,
scale_shape = scale_shape,
scale_rate = scale_rate,
delta = delta,
graph_prior_spec = graph_prior_spec,
z_ratio_tuning = z_ratio_tuning,
edge_prior_flat = ep_flat
)
} else if(model_type == "mixed_mrf") {
Expand Down Expand Up @@ -536,6 +601,10 @@ build_spec_ggm = function(x, data_columnnames, num_variables,
interaction_alpha, interaction_beta,
scale_prior_type, scale_shape, scale_rate,
delta = 0,
graph_prior_spec = "joint",
z_ratio_tuning = list(M_inner = 100L,
kappa = 1.0,
rho = 0.5),
edge_prior_flat) {
# Missing data
md = validate_missing_data(
Expand Down Expand Up @@ -577,6 +646,8 @@ build_spec_ggm = function(x, data_columnnames, num_variables,
scale_shape = scale_shape,
scale_rate = scale_rate,
delta = delta,
graph_prior_spec = graph_prior_spec,
z_ratio_tuning = z_ratio_tuning,
edge_selection = ep$edge_selection,
edge_prior = ep$edge_prior,
inclusion_probability = ep$inclusion_probability,
Expand Down
Loading
Loading