# Source code for GUIBRUSHR.Retrieval.ExofastMCMC.exofast_gelmanrubin

"""
Gelman-Rubin convergence diagnostic implementation.

This module implements the Gelman-Rubin convergence diagnostic for MCMC chains
as described in Ford 2006. The diagnostic is used to assess whether multiple
MCMC chains have converged to the same distribution.

References:
    Ford, E. B. (2006). Improving the efficiency of Markov chain Monte Carlo 
    for analyzing the orbits of extrasolar planets. The Astrophysical Journal, 
    642(1), 505.
"""

import numpy as np

# import pickle
# from os.path import dirname, join as pjoin
# from scipy.io import readsav

# Path of the output file used for error logging. While it is None,
# validation errors are not written anywhere.
file_output_file = None


def _log_error(message):
    """Write *message* to ``file_output_file`` if a path has been set."""
    if file_output_file is not None:
        with open(file_output_file, "w") as f:
            f.write(message)


def exofast_gelmanrubin(pars0):
    """
    Calculate the Gelman-Rubin convergence diagnostic for MCMC chains.

    The diagnostic compares within-chain and between-chain variances to
    decide whether several independent chains sample the same distribution.
    The computation follows the equations of Ford 2006:

    - Equation 21: W(z) - mean within-chain variance
    - Equation 23: B(z) - between-chain variance
    - Equation 24: varhat+(z) - pooled variance estimate
    - Equation 25: Rhat(z) - Gelman-Rubin statistic
    - Equation 26: T(z) - effective sample size factor

    Args:
        pars0: array-like with shape (npars, nsteps, nchains) where
            npars is the number of parameters,
            nsteps is the number of MCMC steps (must be greater than 1),
            nchains is the number of independent chains (must be greater
            than 2).
            The input is only read, never modified.

    Returns:
        tuple: (converged, gelmanrubin, tz) where:
            - converged: int, 1 if converged, 0 if not converged
            - gelmanrubin: numpy array of R-hat values, one per parameter
            - tz: numpy array of effective sample size factors, one per
              parameter

        If the input fails validation (not 3-dimensional, fewer than 2
        steps, or fewer than 3 chains) the error message is written to
        ``file_output_file`` when that global is set, and
        ``(None, None, None)`` is returned.

    Notes:
        Convergence is declared when:
            - all R-hat values < 1.01
            - all effective sample size factors > 1000
    """
    # asarray accepts nested lists as well as arrays and does not copy an
    # existing ndarray; nothing below writes to it.
    pars = np.asarray(pars0)

    # Validate input dimensions
    if pars.ndim != 3:
        _log_error("ERROR: pars must have 3 dimensions")
        return None, None, None

    _, nsteps, nchains = pars.shape

    # The sample variance (ddof=1) along the step axis needs >= 2 steps.
    if nsteps < 2:
        _log_error("ERROR: NSTEPS must be greater than 1")
        return None, None, None

    # Reject every chain count below 3, not only exactly 2: a single chain
    # would divide by (nchains - 1) == 0 in the variance of the chain means.
    if nchains < 3:
        _log_error("ERROR: NCHAINS must be greater than 2")
        return None, None, None

    # Within-chain sample variance for each parameter and chain,
    # shape (npars, nchains); ddof=1 applies Bessel's correction.
    variances = np.var(pars, axis=1, ddof=1)

    # Equation 21: W(z) - mean of the within-chain variances
    meanofvariances = np.mean(variances, axis=1)

    # Mean of every chain for each parameter, shape (npars, nchains)
    means = np.mean(pars, axis=1)

    # Equation 23: B(z) - sample variance of the chain means, scaled by
    # the number of steps
    varianceofmeans = np.var(means, axis=1, ddof=1)
    bz = varianceofmeans * nsteps

    # Equation 24: varhat+(z) - pooled variance estimate
    # (varianceofmeans equals B(z) / nsteps)
    varz = (nsteps - 1) / nsteps * meanofvariances + varianceofmeans

    # Equation 25: Rhat(z) - Gelman-Rubin statistic
    gelmanrubin = np.sqrt(varz / meanofvariances)

    # Equation 26: T(z) - effective sample size factor; the ratio
    # varhat+/B is capped at 1 so T(z) never exceeds nchains * nsteps.
    tz = nchains * nsteps * np.minimum(varz / bz, 1.0)

    # Converged only if every parameter passes both thresholds
    converged = int(np.min(tz) > 1000 and np.max(gelmanrubin) < 1.01)
    return converged, gelmanrubin, tz