Source code for GUIBRUSHR.Retrieval.ExofastMCMC.exofast_gelmanrubin
"""
Gelman-Rubin convergence diagnostic implementation.
This module implements the Gelman-Rubin convergence diagnostic for MCMC chains
as described in Ford 2006. The diagnostic is used to assess whether multiple
MCMC chains have converged to the same distribution.
References:
Ford, E. B. (2006). Improving the efficiency of Markov chain Monte Carlo
for analyzing the orbits of extrasolar planets. The Astrophysical Journal,
642(1), 505.
"""
import numpy as np
# import pickle
# from os.path import dirname, join as pjoin
# from scipy.io import readsav
# Global variable for output file (used for error logging).
# When set to a file path, exofast_gelmanrubin writes its input-validation
# error message to that file before returning (None, None, None). The file is
# opened in "w" mode, so each error overwrites the previous contents.
# When left as None, validation errors are not written anywhere.
file_output_file = None
[docs]
def _report_error(error_msg):
    """Write an input-validation error message to the module's error log file.

    Nothing is written (and nothing is raised) when the module-level
    ``file_output_file`` is None. Otherwise the file is opened in "w" mode,
    so the message replaces any previous contents.

    Args:
        error_msg: The message text to record.
    """
    if file_output_file is not None:
        with open(file_output_file, "w") as f:
            f.write(error_msg)


def exofast_gelmanrubin(pars0, *, mintz=1000.0, maxgr=1.01):
    """
    Calculate the Gelman-Rubin convergence diagnostic for MCMC chains.

    This function implements the Gelman-Rubin diagnostic (R-hat) to assess
    convergence of multiple MCMC chains. The diagnostic compares within-chain
    and between-chain variances to determine if chains have converged.

    The function follows the equations from Ford 2006:
    - Equation 21: W(z) - within-chain variance
    - Equation 23: B(z) - between-chain variance
    - Equation 24: varhat+(z) - pooled variance estimate
    - Equation 25: Rhat(z) - Gelman-Rubin statistic
    - Equation 26: T(z) - effective number of independent draws

    Args:
        pars0: Array-like with shape (npars, nsteps, nchains), where npars is
            the number of parameters, nsteps is the number of MCMC steps per
            chain, and nchains is the number of independent chains. It is
            read but never modified.
        mintz: Keyword-only. Every parameter's T(z) must be strictly greater
            than this for convergence (default 1000).
        maxgr: Keyword-only. Every parameter's R-hat must be strictly less
            than this for convergence (default 1.01).

    Returns:
        tuple: (converged, gelmanrubin, tz) where:
            - converged: int, 1 if converged, 0 if not converged
            - gelmanrubin: numpy array of R-hat values for each parameter
            - tz: numpy array of T(z) values for each parameter
        On invalid input (not 3-D, fewer than 2 steps, or fewer than 2
        chains) returns (None, None, None) after passing the error message
        to the module's error log (see ``_report_error``).

    Notes:
        Convergence is declared when all R-hat values < maxgr and all
        T(z) values > mintz. Parameters with zero within-chain variance
        produce NaN R-hat values and are therefore never counted as
        converged.
    """
    # The input is only read, so no copy is needed; asarray also allows
    # nested lists to be passed in.
    pars = np.asarray(pars0)
    sz = np.shape(pars)

    # Validate input dimensions
    if len(sz) != 3:
        _report_error("ERROR: pars must have 3 dimensions")
        return None, None, None

    _, nsteps, nchains = sz

    # Sample variances (ddof=1) need at least two steps per chain.
    if nsteps < 2:
        _report_error("ERROR: NSTEPS must be greater than 1")
        return None, None, None
    # The between-chain variance divides by (nchains - 1), so at least two
    # chains are required; exactly two chains is a valid input.
    if nchains < 2:
        _report_error("ERROR: NCHAINS must be greater than 1")
        return None, None, None

    # Equation 21: W(z) - mean over chains of the within-chain sample
    # variances (Bessel's correction, ddof=1). variances is (npars, nchains).
    variances = np.var(pars, axis=1, ddof=1)
    meanofvariances = np.mean(variances, axis=1)

    # Equation 23: B(z) = nsteps * (sample variance of the chain means).
    # means is (npars, nchains).
    means = np.mean(pars, axis=1)
    varianceofmeans = np.var(means, axis=1, ddof=1)
    bz = varianceofmeans * nsteps

    # Equation 24: varhat+(z) = (n-1)/n * W + B/n, and B/n == varianceofmeans.
    varz = (nsteps - 1) / nsteps * meanofvariances + varianceofmeans

    # bz == 0 (identical chain means) gives varz/bz = inf, which the cap
    # below turns into 1; W == 0 gives NaN. Silence the expected warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        # Equation 25: Rhat(z) - Gelman-Rubin statistic
        gelmanrubin = np.sqrt(varz / meanofvariances)
        # Equation 26: T(z) = nchains * nsteps * min(varhat+/B, 1).
        # (The expression `varz/bz < 1` seen in the presumably-IDL original
        # is IDL's minimum operator, not a comparison.)
        tz = nchains * nsteps * np.minimum(varz / bz, 1.0)

    # Converged only if every parameter passes both thresholds.
    if np.min(tz) > mintz and np.max(gelmanrubin) < maxgr:
        return 1, gelmanrubin, tz
    return 0, gelmanrubin, tz