"""
EXOFAST DEMC (Differential Evolution Markov Chain) Implementation.
This module implements the DEMC posterior sampling algorithm for parameter estimation
in exoplanet fitting. The code maintains the exact logic and sequence of operations
from the original implementation while improving readability and documentation.
PARALLEL INITIALIZATION: a multiprocessing.Pool-based version of the first-step chain
initialization is kept in this module but is currently commented out; chains are
initialized with the serial routine _initialize_first_chain_step_serial.
"""
import datetime
import pickle
import sys
import time
import traceback
from pathlib import Path
# import multiprocessing as mp

import numpy as np

from GUIBRUSHR.Retrieval.ExofastMCMC.exofast_gelmanrubin import exofast_gelmanrubin
[docs]
def time_left_units(timeleft):
    """
    Express a duration given in seconds in the most readable time unit.

    The value is promoted one unit at a time (seconds -> minutes -> hours ->
    days). Promotion stops at the first unit whose threshold is not exceeded,
    so a value is only considered for "hours" once it has become "minutes",
    and for "days" once it has become "hours".

    Parameters
    ----------
    timeleft : float
        Duration in seconds.

    Returns
    -------
    tuple
        ``(value, units)`` where ``units`` is one of ``"seconds"``,
        ``"minutes"``, ``"hours"`` or ``"days"``.
    """
    units = "seconds"
    # Each entry: (threshold that is also the divisor, unit reached afterwards).
    promotions = ((60, "minutes"), (60, "hours"), (24, "days"))
    for factor, larger_units in promotions:
        if not timeleft > factor:
            break
        timeleft /= factor
        units = larger_units
    return timeleft, units
def _initialize_chain_files(model_obj):
"""
Initialize file paths for chain and probability data storage.
Parameters
----------
model_obj :
Model object containing retrieval data and configuration.
Returns
-------
tuple
Tuple containing (chain_file, prob_file, partial_prob_and_chain) paths.
"""
chain_file = "chain_burnin.dat"
prob_file = "probabilities_burnin.dat"
partial_prob_and_chain = "partial_prob_and_chain_burnin.pkl"
# Construct full file paths using results directory
chain_file = model_obj.retrieval_data.path_results + chain_file
prob_file = model_obj.retrieval_data.path_results + prob_file
partial_prob_and_chain = (
model_obj.retrieval_data.path_results + partial_prob_and_chain
)
return chain_file, prob_file, partial_prob_and_chain
def _load_previous_chains(partial_prob_and_chain, pars, chi2, det_chain):
"""
Load and concatenate previous chain data when resuming sampling.
Parameters
----------
partial_prob_and_chain : str
Path to the pickle file containing previous chain data.
pars : ndarray
Current parameter array to concatenate with.
chi2 : ndarray
Current chi-squared array to concatenate with.
det_chain : ndarray
Current determinant chain array to concatenate with.
Returns
-------
tuple
Updated (pars, chi2, det_chain, last_index_data_resumed, naccept).
"""
with open(partial_prob_and_chain, "rb") as f:
pars_and_chain = pickle.load(f)
# Handle both dictionary and legacy list formats
if isinstance(pars_and_chain, dict):
pars_old = pars_and_chain["parameters"]
chi2_old = pars_and_chain["chi2"]
det_chain_old = pars_and_chain.get("det_chain", None)
else:
# Legacy format
pars_old = pars_and_chain[0]
chi2_old = pars_and_chain[1]
det_chain_old = pars_and_chain[2] if len(pars_and_chain) > 2 else None
# Concatenate old and new arrays
pars = np.concatenate((pars_old, pars), axis=1)
chi2 = np.concatenate((chi2_old, chi2), axis=0)
last_index_data_resumed = len(chi2_old[:, 0]) - 1
naccept = int(np.sum([len(np.unique(pars[0, :, i]))
for i in range(len(pars[0, 0, :]))]))
# Handle determinant chain data with error handling
try:
if det_chain_old is not None:
det_chain = np.concatenate((det_chain_old, det_chain), axis=0)
else:
raise ValueError("No det_chain in loaded data")
except Exception as e:
print("Exception reached due to:", e,
"\n No previous det_chain. It will be set to 1 for all elements")
det_chain = np.concatenate((np.ones(np.shape(chi2_old)), det_chain), axis=0)
return pars, chi2, det_chain, last_index_data_resumed, naccept
def _extract_model_data_for_parallel(model_obj):
"""
Extract necessary data from model_obj for parallel workers.
This function extracts all data needed by the parallel workers to recreate
param_full objects and calculate likelihoods without access to the full model_obj.
Parameters
----------
model_obj :
The model object containing all necessary methods and data
Returns
-------
dict
Dictionary containing serializable model data needed by workers
"""
# Extract the essential data for parameter creation and likelihood calculation
model_data = {
# Core parameter data
'bestpars_initial': model_obj.bestpars_data.list_bestpars_initial_value.copy(),
'scale_vector': model_obj.retrieval_data.scale_vector_params.copy(),
# Add any other data your model_obj needs for create_param_full and lh_function_gib
# You may need to customize this based on your specific model structure
# Examples (uncomment and adapt as needed):
# 'observations': getattr(model_obj, 'observations', None),
# 'covariance_matrix': getattr(model_obj, 'covariance_matrix', None),
# 'parameter_bounds': getattr(model_obj, 'parameter_bounds', None),
# 'instrument_data': getattr(model_obj, 'instrument_data', None),
# Add a reference to the model object itself if it's serializable
# Otherwise, extract specific methods or data structures needed
'model_obj_ref': model_obj # This works if model_obj is pickleable
}
return model_data
def _create_param_full_parallel(candidate_pars, model_data):
"""
Parallel-safe version of model_obj.create_param_full().
This function recreates the param_full object using the extracted model data,
enabling parallel workers to check parameter boundaries without the full model_obj.
Parameters
----------
candidate_pars : ndarray
Candidate parameter values
model_data : dict
Dictionary containing extracted model data
Returns
-------
list
List of parameter objects with boundary checking capability
"""
# If the full model_obj is available and serializable, use it directly
if 'model_obj_ref' in model_data and model_data['model_obj_ref'] is not None:
return model_data['model_obj_ref'].create_param_full(candidate_pars)
# Otherwise, implement a standalone version based on your parameter structure
# This is where you would implement the logic from your create_param_full method
# without requiring the full model_obj
# Example implementation (customize based on your actual parameter structure):
# param_full = []
# for i, par_value in enumerate(candidate_pars):
# # Create parameter objects based on your specific parameter classes
# param_obj = YourParameterClass(
# value=par_value,
# bounds=model_data.get('parameter_bounds', {}).get(i, None),
# # Add other necessary parameter attributes
# )
# param_full.append(param_obj)
# return param_full
raise NotImplementedError(
"You need to implement _create_param_full_parallel() based on your "
"model_obj.create_param_full() method. Either ensure model_obj is "
"serializable or implement a standalone version of parameter creation."
)
def _calculate_likelihood_parallel(param_full, model_data):
"""
Parallel-safe version of model_obj.lh_function_gib().
This function calculates the likelihood using the extracted model data,
enabling parallel workers to compute chi2 and determinant values.
Parameters
----------
param_full : list
List of parameter objects created by _create_param_full_parallel
model_data : dict
Dictionary containing extracted model data
Returns
-------
tuple
(chi2_value, determinant_value, additional_info) - same format as lh_function_gib
"""
# If the full model_obj is available and serializable, use it directly
if 'model_obj_ref' in model_data and model_data['model_obj_ref'] is not None:
return model_data['model_obj_ref'].lh_function_gib(param_full)
# Otherwise, implement a standalone version based on your likelihood calculation
# This is where you would implement the logic from your lh_function_gib method
# without requiring the full model_obj
# Example implementation (customize based on your actual likelihood calculation):
# param_values = np.array([p.value for p in param_full if p is not None])
#
# # Implement your specific likelihood calculation here
# chi2_val = your_chi2_calculation(param_values, model_data)
# det_val = your_determinant_calculation(param_values, model_data)
# additional_info = None # or whatever your lh_function_gib returns as third element
#
# return chi2_val, det_val, additional_info
raise NotImplementedError(
"You need to implement _calculate_likelihood_parallel() based on your "
"model_obj.lh_function_gib() method. Either ensure model_obj is "
"serializable or implement a standalone version of likelihood calculation."
)
# def _initialize_single_chain_worker(kr, bestpars_initial, scale_vector, nfit, seed_base, model_data):
# """
# Worker function to initialize a single chain with valid parameter values.
#
# This function replicates the logic of the original for loop iteration,
# generating initial parameter values for one chain and ensuring they
# satisfy boundary conditions through iterative sampling.
#
# Parameters
# ----------
# kr : int
# Chain index
# bestpars_initial : ndarray
# Initial best parameter values
# scale_vector : ndarray
# Parameter scaling vector
# nfit : int
# Number of fitted parameters
# seed_base : int
# Base seed for random number generation
# model_data : dict
# Dictionary containing model data needed for param_full creation and likelihood
#
# Returns
# -------
# tuple
# (chain_index, parameters, chi2_value, determinant_value)
# """
# # Create independent RNG for this worker to avoid correlation between chains
# rng = np.random.default_rng(seed=seed_base + kr)
#
# param_full = None
# cond_params = False
#
# # Keep generating parameters until boundary conditions are satisfied
# while not cond_params:
# cond_params = True
#
# # Generate random normal deviates
# a = rng.standard_normal(nfit)
#
# # Create parameter values using initial values and scaling
# candidate_pars = bestpars_initial + scale_vector * a
#
# # Create full parameter object and check boundaries
# param_full = _create_param_full_parallel(candidate_pars, model_data)
# for elem in param_full:
# if elem is not None:
# cond_params = cond_params and elem.boundaries_check()
#
# # Calculate likelihood for initial parameters
# chi2_val, det_val, _ = _calculate_likelihood_parallel(param_full, model_data)
#
# return kr, candidate_pars, chi2_val, det_val
# def _initialize_first_chain_step_parallel(model_obj, nchains, nfit, pars, chi2, det_chain, n_cores=None):
# """
# Initialize the first step of each chain with valid parameter values.
#
# This function generates initial parameter values for each chain, ensuring
# they satisfy boundary conditions through iterative sampling.
#
# Parameters
# ----------
# model_obj :
# Model object containing parameter generation methods.
# nchains : int
# Number of chains to initialize.
# nfit : int
# Number of fitted parameters.
# pars : ndarray
# Parameter array to populate.
# chi2 : ndarray
# Chi-squared array to populate.
# det_chain : ndarray
# Determinant chain array to populate.
# n_cores : int,
# Number of worker processes.
# """
#
# # Extract necessary data from model_obj for parallel workers
# bestpars_initial = model_obj.bestpars_data.list_bestpars_initial_value
# scale_vector = model_obj.retrieval_data.scale_vector_params
# seed_base = model_obj.random_obj.seed
#
# # Prepare model data for workers
# model_data = _extract_model_data_for_parallel(model_obj)
#
# # Prepare arguments for each worker
# worker_args = [
# (kr, bestpars_initial, scale_vector, nfit, seed_base, model_data)
# for kr in range(nchains)
# ]
#
# try:
# # Execute in parallel using Pool
# with mp.Pool(processes=n_cores) as pool:
# results = pool.starmap(_initialize_single_chain_worker, worker_args)
#
# # Populate output arrays with results
# for kr, candidate_pars, chi2_val, det_val in results:
# pars[:, 0, kr] = candidate_pars
# chi2[0, kr] = chi2_val
# det_chain[0, kr] = det_val
#
# except Exception as e:
# print(f"Parallel initialization failed: {e}")
# print("Falling back to serial implementation...")
# # Fallback to serial implementation if parallel fails
# _initialize_first_chain_step_serial(model_obj, nchains, nfit, pars, chi2, det_chain)
def _initialize_first_chain_step_serial(model_obj, nchains, nfit, pars, chi2, det_chain):
"""
Original serial implementation of chain initialization.
This function maintains the original logic as a fallback when parallel
processing fails or is not beneficial.
Parameters
----------
model_obj :
Model object containing parameter generation methods.
nchains : int
Number of chains to initialize.
nfit : int
Number of fitted parameters.
pars : ndarray
Parameter array to populate.
chi2 : ndarray
Chi-squared array to populate.
det_chain : ndarray
Determinant chain array to populate.
"""
param_full = None
# Initialize each chain
for kr in range(nchains):
print(f"Initializing chain {kr}")
cond_params = False
# Keep generating parameters until boundary conditions are satisfied
n_tentative = 0
while not cond_params:
n_tentative += 1
cond_params = True
# Generate random normal deviates
a = model_obj.random_obj.rng.standard_normal(nfit)
# Create parameter values using initial values and scaling
pars[:, 0, kr] = (
model_obj.bestpars_data.list_bestpars_initial_value +
model_obj.retrieval_data.scale_vector_params * a
)
# Create full parameter object and check boundaries
param_full = model_obj.create_param_full(pars[:, 0, kr])
for elem in param_full:
if elem is not None:
cond_params = cond_params and elem.boundaries_check()
if n_tentative % 100 == 0 and not elem.boundaries_check():
print(f"Chain {kr}: tentative {n_tentative}. Param {elem.name}: {elem.value_in_retrieval}, must stay between {elem.range_min} - {elem.range_max}")
# Calculate likelihood for initial parameters
chi2[0, kr], det_chain[0, kr], _ = model_obj.lh_function_gib(param_full)
def _save_intermediate_results(chain_file, prob_file, partial_prob_and_chain,
pars, chi2, det_chain, nfit, index_position_python):
"""
Save intermediate chain results to files.
Parameters
----------
chain_file : str
Path to chain data file.
prob_file : str
Path to probability data file.
partial_prob_and_chain : str
Path to pickle file for partial results.
pars : ndarray
Parameter array.
chi2 : ndarray
Chi-squared array.
det_chain : ndarray
Determinant chain array.
nfit : int
Number of fitted parameters.
index_position_python : int
Current position index (Python indexing).
"""
# Save chain data
with open(chain_file, "w") as fc:
samples_temp = pars[:, :index_position_python, :].reshape((nfit, -1))
for kk in range(samples_temp.shape[1]):
fc.write("%s\n" % " ".join(str(x) for x in samples_temp[:, kk]))
# Save probability data
with open(prob_file, "w") as fc2:
lnpmflat = chi2[:index_position_python, :].reshape(-1)
for kk in range(lnpmflat.shape[0]):
fc2.write("%s\n" % lnpmflat[kk])
# Save partial results as pickle: saving partial burnin chains
with open(partial_prob_and_chain, "wb") as f:
pickle.dump({
"parameters": pars[:, :index_position_python, :],
"chi2": chi2[:index_position_python, :],
"det_chain": det_chain[:index_position_python, :]
}, f)
def _check_convergence_and_burnin(
    pars, chi2, nfit, nchains, index_position_python,
):
    """
    Determine the burn-in index and run the Gelman-Rubin convergence test.

    The burn-in index is the largest, over chains, of the first step at which
    a chain's ``chi2`` value exceeds the median of all values sampled so far.
    It is capped so that at least three steps remain (and never negative).
    The Gelman-Rubin test is then run on all ``nfit`` parameters from the
    burn-in index up to ``index_position_python``.

    NOTE(review): the ``>`` comparison presumes that larger values stored in
    ``chi2`` are better (e.g. a log-probability) -- confirm against
    ``lh_function_gib``.

    Parameters
    ----------
    pars : ndarray
        Parameter array, indexed as ``[parameter, step, chain]``.
    chi2 : ndarray
        Chi-squared array, indexed as ``[step, chain]``.
    nfit : int
        Number of fitted parameters.
    nchains : int
        Number of chains.
    index_position_python : int
        Number of steps sampled so far (exclusive upper bound on the step axis).

    Returns
    -------
    tuple
        Tuple containing (converged, gelmanrubin, tz, burnndx).
    """
    medchi2 = np.median(chi2[:index_position_python, :])

    burnndx = 0
    for jj in range(nchains):
        tmpndx = np.where(chi2[:index_position_python, jj] > medchi2)[0]
        if len(tmpndx) > 0:
            burnndx = max(burnndx, tmpndx[0])

    # Cap the burn-in so the Gelman-Rubin calculation still has data when one
    # chain is being problematic; never let the index go negative, which
    # would silently wrap around to the end of the array.
    burnndx = max(0, min(burnndx, index_position_python - 3))

    # Python slices exclude the upper bound, so ``:nfit`` covers every fitted
    # parameter (``0:nfit-1`` would leave the last one untested).
    converged, gelmanrubin, tz = exofast_gelmanrubin(
        pars[:nfit, burnndx:index_position_python, :]
    )
    return converged, gelmanrubin, tz, burnndx
def _update_convergence_tracking(converged, nstop, i, npass, dontstop,
maxsteps, output_file, gelmanrubin, tz):
"""
Update convergence tracking variables and determine next recalculation step.
Parameters
----------
converged : int
Convergence flag (1 if converged, 0 otherwise).
nstop : int
Step at which convergence was first achieved.
i : int
Current step index.
npass : int
Number of consecutive convergence passes.
dontstop : bool
Flag to continue even after convergence.
maxsteps : int
Maximum number of steps.
output_file : str
Path to output file for status messages.
gelmanrubin : ndarray
Gelman-Rubin statistics.
tz : ndarray
Independent draws statistics.
Returns
-------
tuple
Updated (nextrecalc, npass, nstop, should_break).
"""
should_break = False
if converged == 1:
if nstop == 0:
nstop = i
nextrecalc = int(nstop / (1 - npass / 100))
npass += 1
if npass == 6:
if dontstop == 0:
temp_str = f"Has converged: {converged} {gelmanrubin} {tz}"
with open(output_file, "w") as f:
f.write(temp_str)
should_break = True
nextrecalc = maxsteps
else:
nextrecalc = int(i / 0.9)
nstop = 0
npass = 1
return nextrecalc, npass, nstop, should_break
def _finalize_chains_and_save(model_obj, pars, chi2, nfit, removeburn,
burnndx, nstop):
"""
Finalize chains by removing burn-in and save final results.
Parameters
----------
model_obj :
Model object containing configuration.
pars : ndarray
Parameter array.
chi2 : ndarray
Chi-squared array.
nfit : int
Number of fitted parameters.
removeburn : bool
Whether to remove burn-in period from final results.
burnndx : int
Burn-in index.
nstop : int
Final step index.
Returns
-------
tuple
Final (pars, chi2) arrays.
"""
# Remove burn-in period if requested
if removeburn:
pars = pars[:, burnndx:nstop, :]
chi2 = chi2[burnndx:nstop, :]
else:
pars = pars[:, 0:nstop, :]
chi2 = chi2[0:nstop, :]
# Set up final output file paths
chain_file = "chain.dat"
prob_file = "probabilities.dat"
partial_prob_and_chain = "partial_prob_and_chain.pkl"
chain_file = model_obj.retrieval_data.path_results + chain_file
prob_file = model_obj.retrieval_data.path_results + prob_file
partial_prob_and_chain = str(Path(model_obj.retrieval_data.path_results,
partial_prob_and_chain))
# Save final chain data
with open(chain_file, "w") as fc:
samples_temp = pars[:, :, :].reshape((nfit, -1))
for kk in range(samples_temp.shape[1]):
fc.write("%s\n" % " ".join(str(x) for x in samples_temp[:, kk]))
# Save final probability data
with open(prob_file, "w") as fc2:
lnpmflat = chi2[:, :].reshape(-1)
for kk in range(lnpmflat.shape[0]):
fc2.write("%s\n" % lnpmflat[kk])
# Save final partial results: saving partial burnin chains
with open(partial_prob_and_chain, "wb") as f:
pickle.dump({
"parameters": pars,
"chi2": chi2
}, f)
return pars, chi2
[docs]
def likelihood(model_obj, dontstop=False, nthin=1, removeburn=True, moresteps=False):
    """
    Run a DEMC (Differential Evolution Markov Chain) posterior sampling.

    Every step is delegated to ``model_obj.run_multiple_processes``. The
    chains are checkpointed every 10 steps, convergence is tested with the
    Gelman-Rubin statistic at adaptively spaced steps, and short status
    messages overwrite ``model_obj.retrieval_data.table_output_file``.

    Parameters
    ----------
    model_obj :
        Model object containing all necessary data and methods for sampling.
    dontstop : bool, optional
        If True, continue sampling even after convergence. Default is False.
    nthin : int, optional
        Thinning factor; only used to normalize the reported acceptance
        rate. Default is 1.
    removeburn : bool, optional
        If True, remove burn-in period from final results. Default is True.
    moresteps : bool, optional
        If True, resume from the burn-in checkpoint of a previous run.
        Default is False.

    Returns
    -------
    tuple
        Final parameter chains and chi-squared values as (pars, chi2).

    Raises
    ------
    SystemExit
        With exit status 1 if a sampling step raises an unexpected exception,
        or raises ``KeyError`` on five consecutive attempts.
    """
    chain_file, prob_file, partial_prob_and_chain = _initialize_chain_files(model_obj)

    # Configuration
    nfit = model_obj.bestpars_data.nfit
    max_steps_run = model_obj.retrieval_data.maxsteps
    nchains = int(model_obj.bestpars_data.nchains)
    ncores = int(model_obj.bestpars_data.ncores)

    # Storage for this run: [parameter, step, chain] and [step, chain]
    pars = np.zeros([nfit, max_steps_run, nchains])
    chi2 = np.zeros([max_steps_run, nchains])
    det_chain = np.zeros([max_steps_run, nchains])

    t0 = datetime.datetime.now()

    if moresteps:
        # Previous steps are prepended; new steps are written after them.
        pars, chi2, det_chain, last_index_data_resumed, naccept = _load_previous_chains(
            partial_prob_and_chain, pars, chi2, det_chain
        )
    else:
        naccept = 1
        last_index_data_resumed = 0

    output_file = model_obj.retrieval_data.table_output_file
    maxsteps = max_steps_run + last_index_data_resumed
    print(f"\nStart at step {last_index_data_resumed} of {maxsteps}")
    print(f"{naccept} accepted")
    print(f"{ncores} cores")
    print(f"{nchains} chains\n")
    with open(output_file, "w") as f:
        f.write("First newpars creation")

    # NOTE: the parallel first-step initialization is currently disabled;
    # a fresh run always uses the serial routine.
    if not moresteps:
        _initialize_first_chain_step_serial(model_obj, nchains, nfit, pars, chi2, det_chain)

    # Convergence tracking
    nextrecalc = 1000
    npass = 1
    nstop = 0
    with open(output_file, "w") as f:
        f.write("Starting chains")

    max_step_retries = 5
    # Status is refreshed roughly 1000 times per run; never let the interval
    # be 0, which would make the modulo below fail for short runs.
    progress_every = max(1, round(max_steps_run / 1000))

    counter = 0
    index_total = last_index_data_resumed
    for i in range(1, int(max_steps_run)):
        # i counts steps of this run; index_total is the absolute array index.
        index_total = i + last_index_data_resumed
        counter += 1

        step_retry_count = 0
        step_success = False
        while step_retry_count < max_step_retries and not step_success:
            try:
                return_dict = model_obj.run_multiple_processes(
                    pars[:, index_total-1, :],
                    chi2[index_total-1, :],
                    det_chain[index_total-1, :],
                )
                # Results are keyed by core, then by chain within that core.
                # NOTE(review): chains beyond chain_per_core * ncores are never
                # updated when nchains is not a multiple of ncores.
                counter_chain = 0
                chain_per_core = int(nchains / ncores)
                # Accumulate locally so a failed attempt that is retried does
                # not leave its partial count in naccept.
                step_naccept = 0
                for j in range(ncores):
                    for k in range(chain_per_core):
                        temp_dict = return_dict[j][k]
                        step_naccept += temp_dict.naccept
                        pars[:, index_total, counter_chain] = np.squeeze(temp_dict.pars)
                        chi2[index_total, counter_chain] = temp_dict.chi2
                        det_chain[index_total, counter_chain] = temp_dict.det
                        counter_chain += 1
                naccept += step_naccept
                step_success = True
            except KeyError as e:
                # A missing core/chain key is treated as transient: re-run.
                step_retry_count += 1
                print(f"KeyError at step {index_total}, core/chain key {e}, "
                      f"attempt {step_retry_count}/{max_step_retries}")
                if step_retry_count < max_step_retries:
                    print("Re-running processes in 2 seconds...")
                    time.sleep(2)
                else:
                    print(f"Error on process {model_obj.retrieval_data.id_process}, "
                          f"max retries reached")
                    print(traceback.format_exc())
                    sys.exit(1)
            except Exception as e:
                print(f"Error on process {model_obj.retrieval_data.id_process}, "
                      f"Exception: {e}")
                print(traceback.format_exc())
                sys.exit(1)

        index_position_python = index_total + 1

        # Checkpoint every 10 steps
        if counter == 10:
            _save_intermediate_results(
                chain_file, prob_file, partial_prob_and_chain,
                pars, chi2, det_chain, nfit, index_position_python
            )
            counter = 0

        # Convergence test at the scheduled step
        if i == nextrecalc:
            converged, gelmanrubin, tz, burnndx = _check_convergence_and_burnin(
                pars, chi2, nfit, nchains, index_position_python
            )
            nextrecalc, npass, nstop, should_break = _update_convergence_tracking(
                converged, nstop, i, npass, dontstop, maxsteps,
                output_file, gelmanrubin, tz
            )
            if should_break:
                break

        # Periodic status line
        if i % progress_every == 0:
            acceptancerate = 100.0 * naccept / (index_position_python * nchains * nthin)
            timeleft = (datetime.datetime.now() - t0) * (max_steps_run / (i + 1) - 1)
            timeleft, units = time_left_units(timeleft.total_seconds())
            progress = float(100 * (i + 1) / max_steps_run)
            # 'g' keeps three significant digits without switching to
            # exponent notation for values such as 100.
            temp_str = (
                f"EXOFAST: {progress:.3g}%, "
                f"acceptance rate = {acceptancerate:.3g}%. "
                f"Time left = {timeleft:.3g} {units}"
            )
            with open(output_file, "w") as f:
                f.write(temp_str)

    # Post-loop processing
    if npass != 6 or dontstop == 1:
        nstop = max_steps_run - 1
    # nstop was tracked relative to this run; the arrays are indexed
    # absolutely, so shift it past the steps loaded from a previous run
    # (a no-op for a fresh run).
    nstop += last_index_data_resumed

    # Final burn-in determination
    medchi2 = np.median(chi2[:nstop, :])
    burnndx = 0
    for j in range(nchains):
        tmpndx = np.where(chi2[:nstop, j] > medchi2)[0]
        if len(tmpndx) > 0:
            burnndx = max(burnndx, tmpndx[0])
    burnndx = max(0, min(burnndx, maxsteps - 3))

    # Final convergence check and warnings
    if npass != 6:
        converged, gelmanrubin, tz = exofast_gelmanrubin(
            pars[0:nfit, burnndx:nstop, :]
        )
        bad = np.where(np.logical_or(tz < 1000, gelmanrubin > 1.01))
        if len(bad[0]) > 0:
            temp_str = (
                f"Following parameters are not well-mixed: {bad} "
                f"GELMANRUBIN BAD: {gelmanrubin[bad]} "
                f"TZ_BAD: {tz[bad]}"
            )
        else:
            temp_str = ("WARNING: The chain did not pass 6 consecutive tests "
                        "and may be marginally well-mixed.")
        with open(output_file, "w") as f:
            f.write(temp_str)

    # Final runtime statistics
    runtime = datetime.datetime.now() - t0
    runtime, units = time_left_units(runtime.total_seconds())
    temp_str = (
        f"EXOFAST_DEMC: done in {runtime:.3g} {units}. "
        f"Took {(index_total / maxsteps) * 100:.3g}% of the steps"
    )
    with open(output_file, "w") as f:
        f.write(temp_str)

    pars, chi2 = _finalize_chains_and_save(
        model_obj, pars, chi2, nfit, removeburn, burnndx, nstop
    )
    return pars, chi2