# Source code for GUIBRUSHR.Retrieval.ExofastMCMC.exofast_demc

"""
EXOFAST DEMC (Differential Evolution Markov Chain) Implementation.

This module implements the DEMC posterior sampling algorithm for parameter estimation
in exoplanet fitting. The code maintains the exact logic and sequence of operations
from the original implementation while improving readability and documentation.

PARALLEL INITIALIZATION: A multiprocessing.Pool-based chain initializer
(_initialize_first_chain_step_parallel) is kept in this module but is currently
commented out; chains are initialized serially by _initialize_first_chain_step_serial.
"""

import datetime
import traceback
from pathlib import Path
# import multiprocessing as mp

import numpy as np
import pickle
from GUIBRUSHR.Retrieval.ExofastMCMC.exofast_gelmanrubin import exofast_gelmanrubin


def time_left_units(timeleft):
    """
    Express a duration given in seconds in the most readable unit.

    The value is promoted step by step: to minutes when it exceeds 60
    seconds, then to hours when it exceeds 60 minutes, and finally to days
    when it exceeds 24 hours. Each promotion is only attempted if the
    previous one happened.

    Parameters
    ----------
    timeleft : float
        Duration in seconds.

    Returns
    -------
    tuple
        ``(value, units)`` where ``units`` is one of ``"seconds"``,
        ``"minutes"``, ``"hours"`` or ``"days"``.
    """
    # (value must exceed this, divide by this, new unit label)
    promotions = (
        (60, 60, "minutes"),
        (60, 60, "hours"),
        (24, 24, "days"),
    )
    value, label = timeleft, "seconds"
    for limit, divisor, next_label in promotions:
        if not value > limit:
            break
        value /= divisor
        label = next_label
    return value, label
def _initialize_chain_files(model_obj):
    """
    Build the paths of the files that store the burn-in chain data.

    Parameters
    ----------
    model_obj :
        Model object; only ``model_obj.retrieval_data.path_results`` is read.

    Returns
    -------
    tuple
        ``(chain_file, prob_file, partial_prob_and_chain)`` as strings.
    """
    # File names are appended by plain string concatenation, so path_results
    # presumably ends with a path separator -- TODO confirm.
    results_dir = model_obj.retrieval_data.path_results
    chain_file = results_dir + "chain_burnin.dat"
    prob_file = results_dir + "probabilities_burnin.dat"
    partial_prob_and_chain = results_dir + "partial_prob_and_chain_burnin.pkl"
    return chain_file, prob_file, partial_prob_and_chain


def _load_previous_chains(partial_prob_and_chain, pars, chi2, det_chain):
    """
    Load a previously saved run and prepend it to the given arrays.

    Parameters
    ----------
    partial_prob_and_chain : str
        Path to the pickle with the previous chain data. Both the dict
        format written by ``_save_intermediate_results`` and the legacy
        list/tuple format ``[parameters, chi2(, det_chain)]`` are accepted.
    pars : ndarray
        Parameter array; the old samples are concatenated in front of it
        along axis 1 (the step axis, as indexed in ``likelihood``).
    chi2 : ndarray
        Chi-squared array; old values are concatenated in front along axis 0.
    det_chain : ndarray
        Determinant chain array; old values are concatenated in front along
        axis 0.

    Returns
    -------
    tuple
        ``(pars, chi2, det_chain, last_index_data_resumed, naccept)``.
        ``last_index_data_resumed`` is the index of the last loaded step.
        ``naccept`` is the number of distinct values of the first parameter
        summed over chains.

    Notes
    -----
    If the loaded data has no determinant chain, or it cannot be
    concatenated, the loaded part of ``det_chain`` is filled with ones.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this module.
    with open(partial_prob_and_chain, "rb") as f:
        pars_and_chain = pickle.load(f)

    if isinstance(pars_and_chain, dict):
        pars_old = pars_and_chain["parameters"]
        chi2_old = pars_and_chain["chi2"]
        det_chain_old = pars_and_chain.get("det_chain", None)
    else:
        # Legacy format: [parameters, chi2(, det_chain)]
        pars_old = pars_and_chain[0]
        chi2_old = pars_and_chain[1]
        det_chain_old = pars_and_chain[2] if len(pars_and_chain) > 2 else None

    pars = np.concatenate((pars_old, pars), axis=1)
    chi2 = np.concatenate((chi2_old, chi2), axis=0)
    last_index_data_resumed = len(chi2_old[:, 0]) - 1
    # NOTE(review): the unique count also covers the not-yet-sampled
    # (zero-filled) tail of the new array, so it likely counts one extra
    # value per chain -- confirm this is intended.
    naccept = int(sum(len(np.unique(pars[0, :, k])) for k in range(pars.shape[2])))

    fallback_reason = None
    if det_chain_old is None:
        fallback_reason = "No det_chain in loaded data"
    else:
        try:
            det_chain = np.concatenate((det_chain_old, det_chain), axis=0)
        except (ValueError, TypeError) as err:
            # e.g. shape mismatch between the stored and the new array
            fallback_reason = err
    if fallback_reason is not None:
        print("Exception reached due to:", fallback_reason,
              "\n No previous det_chain. It will be set to 1 for all elements")
        det_chain = np.concatenate((np.ones(np.shape(chi2_old)), det_chain), axis=0)

    return pars, chi2, det_chain, last_index_data_resumed, naccept


def _extract_model_data_for_parallel(model_obj):
    """
    Collect the data a parallel worker needs to initialize a chain.

    Within this module it is only referenced by the commented-out parallel
    initialization code below.

    Parameters
    ----------
    model_obj :
        The model object containing all necessary methods and data.

    Returns
    -------
    dict
        Copies of the initial best parameters and the scale vector, plus a
        reference to ``model_obj`` itself (only usable by worker processes
        if ``model_obj`` is picklable).
    """
    model_data = {
        # Core parameter data
        "bestpars_initial": model_obj.bestpars_data.list_bestpars_initial_value.copy(),
        "scale_vector": model_obj.retrieval_data.scale_vector_params.copy(),
        # Further entries (observations, covariance matrix, parameter bounds,
        # instrument data, ...) may be added here if a standalone worker
        # implementation of create_param_full / lh_function_gib needs them.
        # Reference to the model object; works only if it is picklable.
        "model_obj_ref": model_obj,
    }
    return model_data


def _create_param_full_parallel(candidate_pars, model_data):
    """
    Parallel-safe counterpart of ``model_obj.create_param_full()``.

    Parameters
    ----------
    candidate_pars : ndarray
        Candidate parameter values.
    model_data : dict
        Data produced by ``_extract_model_data_for_parallel``.

    Returns
    -------
    list
        Parameter objects (entries may be ``None``) with ``boundaries_check``.

    Raises
    ------
    NotImplementedError
        If ``model_data`` carries no model reference; a standalone version
        (rebuilding the parameter objects from ``candidate_pars`` and data
        stored in ``model_data``) has not been written.
    """
    model_ref = model_data.get("model_obj_ref")
    if model_ref is not None:
        return model_ref.create_param_full(candidate_pars)
    raise NotImplementedError(
        "You need to implement _create_param_full_parallel() based on your "
        "model_obj.create_param_full() method. Either ensure model_obj is "
        "serializable or implement a standalone version of parameter creation."
    )


def _calculate_likelihood_parallel(param_full, model_data):
    """
    Parallel-safe counterpart of ``model_obj.lh_function_gib()``.

    Parameters
    ----------
    param_full : list
        Parameter objects created by ``_create_param_full_parallel``.
    model_data : dict
        Data produced by ``_extract_model_data_for_parallel``.

    Returns
    -------
    tuple
        ``(chi2_value, determinant_value, additional_info)``, in the same
        format as ``lh_function_gib``.

    Raises
    ------
    NotImplementedError
        If ``model_data`` carries no model reference; a standalone
        likelihood implementation has not been written.
    """
    model_ref = model_data.get("model_obj_ref")
    if model_ref is not None:
        return model_ref.lh_function_gib(param_full)
    raise NotImplementedError(
        "You need to implement _calculate_likelihood_parallel() based on your "
        "model_obj.lh_function_gib() method. Either ensure model_obj is "
        "serializable or implement a standalone version of likelihood calculation."
    )


# Disabled parallel initialization (requires `import multiprocessing as mp`).
#
# def _initialize_single_chain_worker(kr, bestpars_initial, scale_vector, nfit, seed_base, model_data):
#     """Initialize chain `kr`; returns (kr, parameters, chi2_value, determinant_value)."""
#     # Independent RNG per worker to avoid correlation between chains
#     rng = np.random.default_rng(seed=seed_base + kr)
#     param_full = None
#     cond_params = False
#     # Keep generating parameters until boundary conditions are satisfied
#     while not cond_params:
#         cond_params = True
#         a = rng.standard_normal(nfit)
#         candidate_pars = bestpars_initial + scale_vector * a
#         param_full = _create_param_full_parallel(candidate_pars, model_data)
#         for elem in param_full:
#             if elem is not None:
#                 cond_params = cond_params and elem.boundaries_check()
#     chi2_val, det_val, _ = _calculate_likelihood_parallel(param_full, model_data)
#     return kr, candidate_pars, chi2_val, det_val
#
#
# def _initialize_first_chain_step_parallel(model_obj, nchains, nfit, pars, chi2, det_chain, n_cores=None):
#     """Initialize the first step of each chain using a multiprocessing.Pool."""
#     bestpars_initial = model_obj.bestpars_data.list_bestpars_initial_value
#     scale_vector = model_obj.retrieval_data.scale_vector_params
#     seed_base = model_obj.random_obj.seed
#     model_data = _extract_model_data_for_parallel(model_obj)
#     worker_args = [
#         (kr, bestpars_initial, scale_vector, nfit, seed_base, model_data)
#         for kr in range(nchains)
#     ]
#     try:
#         with mp.Pool(processes=n_cores) as pool:
#             results = pool.starmap(_initialize_single_chain_worker, worker_args)
#         for kr, candidate_pars, chi2_val, det_val in results:
#             pars[:, 0, kr] = candidate_pars
#             chi2[0, kr] = chi2_val
#             det_chain[0, kr] = det_val
#     except Exception as e:
#         print(f"Parallel initialization failed: {e}")
#         print("Falling back to serial implementation...")
#         _initialize_first_chain_step_serial(model_obj, nchains, nfit, pars, chi2, det_chain)


def _initialize_first_chain_step_serial(model_obj, nchains, nfit, pars, chi2, det_chain):
    """
    Draw a valid first step for every chain and evaluate its likelihood.

    For each chain, parameters are drawn as ``initial + scale * N(0, 1)``
    until every non-``None`` parameter object passes ``boundaries_check()``.
    Every 100th attempt, the out-of-range parameters are printed.

    Parameters
    ----------
    model_obj :
        Model object providing the RNG, initial values, scale vector,
        ``create_param_full`` and ``lh_function_gib``.
    nchains : int
        Number of chains to initialize.
    nfit : int
        Number of fitted parameters.
    pars : ndarray
        Parameter array; column ``pars[:, 0, kr]`` is filled in place.
    chi2 : ndarray
        Chi-squared array; ``chi2[0, kr]`` is filled in place.
    det_chain : ndarray
        Determinant array; ``det_chain[0, kr]`` is filled in place.

    Notes
    -----
    NOTE(review): there is no attempt limit, so this loops forever if the
    bounds can never be satisfied.
    """
    for kr in range(nchains):
        print(f"Initializing chain {kr}")
        n_tentative = 0
        cond_params = False
        param_full = None
        while not cond_params:
            n_tentative += 1
            a = model_obj.random_obj.rng.standard_normal(nfit)
            pars[:, 0, kr] = (
                model_obj.bestpars_data.list_bestpars_initial_value
                + model_obj.retrieval_data.scale_vector_params * a
            )
            param_full = model_obj.create_param_full(pars[:, 0, kr])
            report = n_tentative % 100 == 0
            cond_params = True
            for elem in param_full:
                if elem is None:
                    continue
                # evaluate the bound check once and reuse it for the report
                in_bounds = elem.boundaries_check()
                cond_params = cond_params and in_bounds
                if report and not in_bounds:
                    print(f"Chain {kr}: tentative {n_tentative}. Param {elem.name}: {elem.value_in_retrieval}, must stay between {elem.range_min} - {elem.range_max}")
        chi2[0, kr], det_chain[0, kr], _ = model_obj.lh_function_gib(param_full)


def _write_samples(file_obj, samples):
    """Write one line per column of the 2-D ``samples`` array, space-separated."""
    file_obj.writelines(" ".join(map(str, column)) + "\n" for column in samples.T)


def _write_values(file_obj, values):
    """Write one value per line."""
    file_obj.writelines("%s\n" % value for value in values)


def _save_intermediate_results(chain_file, prob_file, partial_prob_and_chain,
                               pars, chi2, det_chain, nfit, index_position_python):
    """
    Save the chains sampled so far (steps ``0 .. index_position_python-1``).

    Parameters
    ----------
    chain_file : str
        Text file receiving one sample per line (``nfit`` space-separated
        values); samples are flattened step-major, chain-minor.
    prob_file : str
        Text file receiving one chi2 value per line, in the same order.
    partial_prob_and_chain : str
        Pickle file receiving a dict with keys ``"parameters"``, ``"chi2"``
        and ``"det_chain"``.
    pars, chi2, det_chain : ndarray
        Chain arrays.
    nfit : int
        Number of fitted parameters.
    index_position_python : int
        Number of steps to save (exclusive upper bound on the step axis).
    """
    with open(chain_file, "w") as fc:
        _write_samples(fc, pars[:, :index_position_python, :].reshape((nfit, -1)))
    with open(prob_file, "w") as fc2:
        _write_values(fc2, chi2[:index_position_python, :].reshape(-1))
    with open(partial_prob_and_chain, "wb") as f:
        pickle.dump({
            "parameters": pars[:, :index_position_python, :],
            "chi2": chi2[:index_position_python, :],
            "det_chain": det_chain[:index_position_python, :],
        }, f)


def _check_convergence_and_burnin(pars, chi2, nfit, nchains, index_position_python):
    """
    Estimate the burn-in index and run the Gelman-Rubin test.

    The burn-in index is the latest, over all chains, of each chain's first
    step whose chi2 exceeds the median chi2 of all steps so far. It is capped
    at ``index_position_python - 3`` so at least 3 steps remain for the test.

    Parameters
    ----------
    pars : ndarray
        Parameter array indexed ``[param, step, chain]``.
    chi2 : ndarray
        Chi-squared array indexed ``[step, chain]``.
    nfit : int
        Number of fitted parameters.
    nchains : int
        Number of chains.
    index_position_python : int
        Number of steps sampled so far (exclusive upper bound).

    Returns
    -------
    tuple
        ``(converged, gelmanrubin, tz, burnndx)``.
    """
    window = chi2[:index_position_python, :]
    medchi2 = np.median(window)
    above = window[:, :nchains] > medchi2
    has_any = above.any(axis=0)
    # argmax of a boolean column is its first True index
    burnndx = int(above.argmax(axis=0)[has_any].max()) if has_any.any() else 0
    # Allow the Gelman-Rubin calculation even if one chain is problematic
    burnndx = max(0, min(burnndx, index_position_python - 3))

    # All nfit parameters: the original `0:nfit-1` was an IDL inclusive range
    # that silently dropped the last parameter in Python slicing.
    converged, gelmanrubin, tz = exofast_gelmanrubin(
        pars[:nfit, burnndx:index_position_python, :]
    )
    return converged, gelmanrubin, tz, burnndx


def _update_convergence_tracking(converged, nstop, i, npass, dontstop, maxsteps,
                                 output_file, gelmanrubin, tz):
    """
    Update the consecutive-pass bookkeeping after a convergence test.

    On a passing test, the first passing step is remembered in ``nstop`` and
    the next test is scheduled at ``nstop / (1 - npass / 100)``. After the
    sixth consecutive pass, the run stops unless ``dontstop`` is set;
    otherwise no further tests are scheduled. On a failing test, the counters
    are reset and the next test is scheduled at ``i / 0.9``.

    Parameters
    ----------
    converged : int
        1 if the Gelman-Rubin test passed.
    nstop : int
        Step of the first pass in the current streak (0 if none).
    i : int
        Current loop step index.
    npass : int
        Pass counter (starts at 1).
    dontstop : bool
        Keep sampling after convergence.
    maxsteps : int
        Value assigned to ``nextrecalc`` once 6 passes are reached.
    output_file : str
        Status file overwritten with the convergence message.
    gelmanrubin, tz : ndarray
        Statistics included in the convergence message.

    Returns
    -------
    tuple
        ``(nextrecalc, npass, nstop, should_break)``.
    """
    # NOTE(review): `i` is the loop index, not the absolute step, so on
    # resumed runs nstop is relative to the resumed segment -- confirm.
    should_break = False
    if converged == 1:
        if nstop == 0:
            nstop = i
        nextrecalc = int(nstop / (1 - npass / 100))
        npass += 1
        if npass == 6:
            if dontstop == 0:
                with open(output_file, "w") as f:
                    f.write(f"Has converged: {converged} {gelmanrubin} {tz}")
                should_break = True
            nextrecalc = maxsteps
    else:
        nextrecalc = int(i / 0.9)
        nstop = 0
        npass = 1
    return nextrecalc, npass, nstop, should_break


def _finalize_chains_and_save(model_obj, pars, chi2, nfit, removeburn, burnndx, nstop):
    """
    Trim the chains and write the final chain, probability and pickle files.

    Parameters
    ----------
    model_obj :
        Model object; ``retrieval_data.path_results`` selects the output folder.
    pars : ndarray
        Parameter array indexed ``[param, step, chain]``.
    chi2 : ndarray
        Chi-squared array indexed ``[step, chain]``.
    nfit : int
        Number of fitted parameters.
    removeburn : bool
        If True keep steps ``burnndx:nstop``, else ``0:nstop``.
    burnndx : int
        Burn-in index.
    nstop : int
        Exclusive upper bound on the kept steps.

    Returns
    -------
    tuple
        The trimmed ``(pars, chi2)``.
    """
    start = burnndx if removeburn else 0
    pars = pars[:, start:nstop, :]
    chi2 = chi2[start:nstop, :]

    results_dir = model_obj.retrieval_data.path_results
    chain_file = results_dir + "chain.dat"
    prob_file = results_dir + "probabilities.dat"
    partial_prob_and_chain = str(Path(results_dir, "partial_prob_and_chain.pkl"))

    with open(chain_file, "w") as fc:
        _write_samples(fc, pars.reshape((nfit, -1)))
    with open(prob_file, "w") as fc2:
        _write_values(fc2, chi2.reshape(-1))
    # Save the final (trimmed) chains
    with open(partial_prob_and_chain, "wb") as f:
        pickle.dump({"parameters": pars, "chi2": chi2}, f)

    return pars, chi2
def likelihood(model_obj, dontstop=False, nthin=1, removeburn=True, moresteps=False):
    """
    Run a DEMC (Differential Evolution Markov Chain) posterior sampling.

    Unless resuming, the first step of every chain is drawn around the
    initial best-fit parameters. Each later step is produced by
    ``model_obj.run_multiple_processes``. Intermediate results are saved
    every 10 steps. The Gelman-Rubin convergence test runs at the steps
    given by ``nextrecalc`` (first at step 1000). Status messages overwrite
    ``model_obj.retrieval_data.table_output_file``.

    Parameters
    ----------
    model_obj :
        Model object containing all necessary data and methods for sampling.
    dontstop : bool, optional
        If True, continue sampling even after convergence. Default is False.
    nthin : int, optional
        Thinning factor; here it only enters the acceptance-rate denominator.
        Default is 1.
    removeburn : bool, optional
        If True, remove the burn-in period from the final results.
        Default is True.
    moresteps : bool, optional
        If True, resume from the previously saved burn-in pickle.
        Default is False.

    Returns
    -------
    tuple
        Final parameter chains and chi-squared values as ``(pars, chi2)``.

    Raises
    ------
    SystemExit
        With status 1 if a sampling step fails: after 5 attempts on
        ``KeyError``, or immediately on any other exception.
    """
    import time  # used only for the retry back-off

    chain_file, prob_file, partial_prob_and_chain = _initialize_chain_files(model_obj)

    nfit = model_obj.bestpars_data.nfit
    max_steps_run = model_obj.retrieval_data.maxsteps
    nchains = int(model_obj.bestpars_data.nchains)
    ncores = int(model_obj.bestpars_data.ncores)
    # Disabled options: bestpars_data.use_pool / bestpars_data.use_parallel_init

    pars = np.zeros([nfit, max_steps_run, nchains])
    chi2 = np.zeros([max_steps_run, nchains])
    det_chain = np.zeros([max_steps_run, nchains])

    t0 = datetime.datetime.now()

    if moresteps:
        pars, chi2, det_chain, last_index_data_resumed, naccept = _load_previous_chains(
            partial_prob_and_chain, pars, chi2, det_chain
        )
    else:
        naccept = 1
        last_index_data_resumed = 0

    output_file = model_obj.retrieval_data.table_output_file
    maxsteps = max_steps_run + last_index_data_resumed
    print(f"\nStart at step {last_index_data_resumed} of {maxsteps}")
    print(f"{naccept} accepted")
    print(f"{ncores} cores")
    print(f"{nchains} chains\n")
    with open(output_file, "w") as f:
        f.write("First newpars creation")

    if not moresteps:
        # A parallel initializer (_initialize_first_chain_step_parallel) exists
        # but is commented out; chains are initialized serially.
        _initialize_first_chain_step_serial(model_obj, nchains, nfit, pars, chi2, det_chain)

    nextrecalc = 1000
    npass = 1
    nstop = 0
    with open(output_file, "w") as f:
        f.write("Starting chains")

    # Report progress roughly 1000 times per run; max(1, ...) avoids a
    # modulo by zero for short runs.
    progress_every = max(1, round(max_steps_run / 1000))
    max_step_retries = 5
    # NOTE(review): assumes run_multiple_processes returns ncores groups of
    # nchains // ncores chains; any remainder chains would never be updated.
    chain_per_core = int(nchains / ncores)

    counter = 0
    index_total = last_index_data_resumed
    for i in range(1, int(max_steps_run)):
        index_total = i + last_index_data_resumed
        counter += 1

        step_retry_count = 0
        step_success = False
        while step_retry_count < max_step_retries and not step_success:
            try:
                return_dict = model_obj.run_multiple_processes(
                    pars[:, index_total - 1, :],
                    chi2[index_total - 1, :],
                    det_chain[index_total - 1, :],
                )
                # Accumulate locally so a failed (retried) attempt does not
                # add its partial acceptances to naccept.
                step_naccept = 0
                counter_chain = 0
                for j in range(ncores):
                    for k in range(chain_per_core):
                        temp_dict = return_dict[j][k]
                        step_naccept += temp_dict.naccept
                        pars[:, index_total, counter_chain] = np.squeeze(temp_dict.pars)
                        chi2[index_total, counter_chain] = temp_dict.chi2
                        det_chain[index_total, counter_chain] = temp_dict.det
                        counter_chain += 1
                naccept += step_naccept
                step_success = True
            except KeyError as e:
                step_retry_count += 1
                print(f"KeyError at step {index_total}, core/chain key {e}, "
                      f"attempt {step_retry_count}/{max_step_retries}")
                if step_retry_count < max_step_retries:
                    print(f"Re-running processes in 2 seconds...")
                    time.sleep(2)
                else:
                    print(f"Error on process {model_obj.retrieval_data.id_process}, "
                          f"max retries reached")
                    print(traceback.format_exc())
                    # non-zero status: plain exit() reported success
                    raise SystemExit(1)
            except Exception as e:
                print(f"Error on process {model_obj.retrieval_data.id_process}, "
                      f"Exception: {e}")
                print(traceback.format_exc())
                raise SystemExit(1)

        index_position_python = index_total + 1

        if counter == 10:
            _save_intermediate_results(
                chain_file, prob_file, partial_prob_and_chain,
                pars, chi2, det_chain, nfit, index_position_python
            )
            counter = 0

        if i == nextrecalc:
            converged, gelmanrubin, tz, burnndx = _check_convergence_and_burnin(
                pars, chi2, nfit, nchains, index_position_python
            )
            nextrecalc, npass, nstop, should_break = _update_convergence_tracking(
                converged, nstop, i, npass, dontstop, maxsteps,
                output_file, gelmanrubin, tz
            )
            if should_break:
                break

        # Keep the rate numeric: formatting a str with ".3" truncated it to
        # 3 characters (e.g. "45.").
        acceptancerate = float(naccept / (index_position_python * nchains * nthin)) * 100
        timeleft = (datetime.datetime.now() - t0) * (max_steps_run / (i + 1) - 1)
        timeleft, units = time_left_units(timeleft.total_seconds())

        if i % progress_every == 0:
            progress = float(100 * (i + 1) / max_steps_run)
            temp_str = (
                f"EXOFAST: {progress:.3g}%, "
                f"acceptance rate = {acceptancerate:.3g}%. "
                f"Time left = {timeleft:.3g} {units}"
            )
            with open(output_file, "w") as f:
                f.write(temp_str)

    # Not converged (or asked to keep going): use everything sampled. The
    # last filled index is maxsteps - 1, which also covers resumed runs
    # (identical to max_steps_run - 1 for fresh runs).
    if npass != 6 or dontstop == 1:
        nstop = maxsteps - 1

    # Final burn-in determination
    medchi2 = np.median(chi2[:nstop, :])
    burnndx = 0
    for j in range(nchains):
        tmpndx = np.where(chi2[:nstop, j] > medchi2)[0]
        if len(tmpndx) > 0 and tmpndx[0] > burnndx:
            burnndx = tmpndx[0]
    burnndx = min(burnndx, maxsteps - 3)

    # Final convergence check and warnings
    if npass != 6:
        converged, gelmanrubin, tz = exofast_gelmanrubin(
            pars[0:nfit, burnndx:nstop, :]
        )
        bad = np.where(np.logical_or(tz < 1000, gelmanrubin > 1.01))
        if len(bad[0]) > 0:
            temp_str = (
                f"Following parameters are not well-mixed: {bad} "
                f"GELMANRUBIN BAD: {gelmanrubin[bad]} "
                f"TZ_BAD: {tz[bad]}"
            )
        else:
            temp_str = ("WARNING: The chain did not pass 6 consecutive tests "
                        "and may be marginally well-mixed.")
        with open(output_file, "w") as f:
            f.write(temp_str)

    runtime = datetime.datetime.now() - t0
    runtime, units = time_left_units(runtime.total_seconds())
    temp_str = (
        f"EXOFAST_DEMC: done in {runtime:.3g} {units}. "
        f"Took {(index_total / maxsteps) * 100:.3g}% of the steps"
    )
    with open(output_file, "w") as f:
        f.write(temp_str)

    pars, chi2 = _finalize_chains_and_save(
        model_obj, pars, chi2, nfit, removeburn, burnndx, nstop
    )
    return pars, chi2