# 2026-04-20 12:41:10 -04:00
import numpy as np
import matplotlib . pyplot as plt
import scipy . integrate
import pynucastro as pyna
import os
import sys
import importlib . util
import time
import matplotlib . lines as mlines
import re
import json
import argparse
from fourdst . composition import Composition
from gridfire . type import NetIn
from gridfire . engine import GraphEngine
from gridfire . solver import PointSolver , PointSolverContext
from tqdm import tqdm
from fourdst . composition . utils import buildCompositionFromMassFractions
def T9(age):
    """Return ``10.0 / sqrt(age)`` for the given age.

    ``main`` multiplies this value by 1e9 before handing it to the solver
    as a temperature, i.e. the result is used as a temperature in units of
    1e9.

    Args:
        age: Time value (scalar or NumPy array); must be positive for a
            finite result.

    Returns:
        The value ``10.0 / np.sqrt(age)`` (NumPy scalar or array).
    """
    root_age = np.sqrt(age)
    return 10.0 / root_age
def get_density(age):
    """Return the density used at the given age: ``4e-5 * T9(age) ** 3``.

    Args:
        age: Time value forwarded unchanged to ``T9``.

    Returns:
        ``4e-5`` times the cube of ``T9(age)``.
    """
    temperature = T9(age)
    return 4e-5 * (temperature ** 3)
# Particles that are dropped from a library rate before comparing its
# reactant/product names with the parsed reaction.
_NON_NUCLEAR_SPECIES = frozenset({"e-", "e+", "g", "nu", "anu"})


def _expand_species(s_str):
    """Expand a whitespace-separated species field into one name per particle.

    A leading integer is treated as a multiplicity (``"2he4"`` becomes
    ``["he4", "he4"]``), ``"g"`` entries are dropped and ``"a"`` is spelled
    out as ``"he4"``. An empty or ``None`` field yields an empty list.
    """
    expanded = []
    for token in (s_str or "").split():
        mult_match = re.match(r"^(\d+)([a-zA-Z0-9]+)$", token)
        if mult_match:
            count = int(mult_match.group(1))
            spec = mult_match.group(2)
        else:
            count, spec = 1, token
        if spec == "g":
            # Photons are not nuclei; skip them (with or without multiplicity).
            continue
        if spec == "a":
            spec = "he4"
        expanded.extend([spec] * count)
    return expanded


def get_pyna_rate(my_rate_str, library):
    """Look up the library rates that correspond to a reaction string.

    The string is expected to look like ``target(projectile,ejectiles)product``
    where the projectile and ejectile fields may hold several
    whitespace-separated species, each optionally prefixed by a multiplicity.

    The lookup is attempted twice: first through
    ``library.get_rate_by_nuclei``, then, if that yields nothing, by scanning
    ``library.get_rates()`` for rates whose reactant and product names match
    once the non-nuclear particles in ``_NON_NUCLEAR_SPECIES`` are ignored.

    Args:
        my_rate_str: Reaction identifier to parse.
        library: Rate library exposing ``get_rate_by_nuclei(reactants,
            products)`` and ``get_rates()``.

    Returns:
        A list of matching rate objects. The list is empty when the string
        cannot be parsed, when a species cannot be converted to a nucleus,
        or when the library has no matching rate; the first two cases also
        print a diagnostic message.
    """
    match = re.match(
        r"([a-zA-Z0-9]+)\(([^,]*),([^)]*)\)(.*)", my_rate_str.strip()
    )
    if not match:
        print(f"Could not parse string format: {my_rate_str}")
        return []

    target, projectile, ejectiles, product = match.groups()
    reactants_str = [target] + _expand_species(projectile)
    products_str = _expand_species(ejectiles) + [product]

    try:
        r_nuc = [pyna.Nucleus(r) for r in reactants_str]
        p_nuc = [pyna.Nucleus(p) for p in products_str]
    except Exception as e:
        # Best effort: an unknown species just means "no rate for this one".
        print(f"Error converting nuclei for {my_rate_str}: {e}")
        return []

    rates = library.get_rate_by_nuclei(r_nuc, p_nuc)
    if rates:
        return rates if isinstance(rates, list) else [rates]

    # Fallback: compare by sorted names, ignoring leptons/photons that the
    # library may list explicitly but the parsed string does not.
    r_nuc_names = sorted(str(n) for n in r_nuc)
    p_nuc_names = sorted(str(n) for n in p_nuc)

    def nuclear_names(particles):
        return sorted(
            str(n) for n in particles if str(n) not in _NON_NUCLEAR_SPECIES
        )

    return [
        rate
        for rate in library.get_rates()
        if nuclear_names(rate.reactants) == r_nuc_names
        and nuclear_names(rate.products) == p_nuc_names
    ]
def load_network_module(filepath):
    """Import a Python source file as a fresh module and return it.

    The module is named after the file's base name without its extension.
    Any module already registered under that name is discarded first so the
    file is always re-executed. The module is registered in ``sys.modules``
    before execution and removed again if execution fails, so a broken file
    never leaves a half-initialised module behind.

    Args:
        filepath: Path to the Python file to load.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If no import spec (or loader) can be created for
            ``filepath``.
        Exception: Whatever the module's own top-level code raises.
    """
    # splitext only removes the final extension; str.replace(".py", "")
    # would also eat ".py" occurring in the middle of the file name.
    module_name = os.path.splitext(os.path.basename(filepath))[0]
    sys.modules.pop(module_name, None)

    spec = importlib.util.spec_from_file_location(module_name, filepath)
    if spec is None or spec.loader is None:
        raise FileNotFoundError(f"Error: could not find module at {filepath}")

    network_module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = network_module
    try:
        spec.loader.exec_module(network_module)
    except BaseException:
        # Do not leave a partially executed module importable by name.
        sys.modules.pop(module_name, None)
        raise
    return network_module
def _normalized_reaction_id(reaction):
    """Return ``reaction.id()`` with the ``e+``/``e-`` tokens removed.

    The result is the key used for every per-reaction dictionary in ``main``
    and the string handed to the pynucastro lookups.
    """
    # NOTE(review): the whitespace inside the last pair of literals could not
    # be recovered from the reviewed source; as written the pattern and its
    # replacement are identical, so this final replace is a no-op. Confirm the
    # intended comma normalisation.
    return (
        reaction.id()
        .replace("e+", "")
        .replace("e-", "")
        .replace(" , ", " , ")
    )


def _as_rate_list(result):
    """Coerce a lookup result (``None``, one rate, or a list) into a list."""
    if result is None:
        return []
    if isinstance(result, list):
        return result
    return [result]


def main(args):
    """Run the GridFire and pynucastro BBN integrations and compare them.

    Steps, in order:

    1. Integrate the network with GridFire from t = 180 to t >= 3600 using
       steps of ``h * t``, with temperature and density averaged over each
       step from ``T9`` and ``get_density``.
    2. Map every GridFire reaction to pynucastro ReacLib rates (by name, then
       by nuclei through ``get_pyna_rate``).
    3. Evaluate both sets of rates at every recorded step temperature and
       print the reactions whose relative difference exceeds 1e-4.
    4. Write a pynucastro Python network to ``pynuc_bbn_network.py``, load it
       and integrate it over the same step conditions with SciPy's Radau
       solver.
    5. Dump everything to ``bbn_simulation_data.json`` and plot selected
       abundances to ``bbn_comparison.pdf``.

    All output files are written to the current working directory.

    Args:
        args: Parsed command-line namespace. ``args.depth`` (``int`` or
            ``None``) is forwarded to ``GraphEngine`` when not ``None``;
            ``args.filter_photo`` (``bool``) enables the photodisintegration
            filter on the pynucastro rate list.
    """
    # Only main needs these; keeping them local avoids touching file imports.
    import contextlib
    import io

    # ---- Run parameters --------------------------------------------------
    tMax = 3600.0
    h = 0.01  # each step advances time by a fraction h of the current time
    current_time = 180.0
    XpXn = 7.17
    Xn = 1.0 / (1.0 + XpXn)
    Xp = 1.0 - Xn

    comp: Composition = buildCompositionFromMassFractions(
        ["H-1", "n-1"], [Xp, Xn]
    )
    netIn = NetIn()
    netIn.composition = comp
    netIn.dt0 = 1e-12

    if args.depth is not None:
        print(
            f"Initializing GridFire GraphEngine with restricted depth = {args.depth}"
        )
        engine = GraphEngine(comp, args.depth)
    else:
        print(
            "Initializing full-depth GridFire GraphEngine (Note: pynucastro may take a long time to run JIT, set NUMBA_DISABLE_JIT=1 as an eviromental variable to disable JIT, this makes per timestep time increase but may still be faster for large networks due to the lack of upfront compilation time)"
        )
        engine = GraphEngine(comp)

    blob = engine.constructStateBlob()
    solver_ctx = PointSolverContext(blob)
    solver_ctx.stdout_logging = False
    solver = PointSolver(engine)

    # Initial molar abundances keyed by GridFire species name; species that
    # are not in the starting composition begin at zero.
    gf_initial_Y = {
        sp.name(): (comp.getMolarAbundance(sp) if comp.contains(sp) else 0.0)
        for sp in engine.getNetworkSpecies(solver_ctx.engine_ctx)
    }

    # ---- GridFire integration --------------------------------------------
    gf_time = []
    gf_results = {}
    step_conditions = []
    gf_start_time = time.time()
    gf_current_time = current_time
    # Time grows geometrically by (1 + h) per step, hence the log ratio.
    total_steps = int(np.ceil(np.log(tMax / current_time) / np.log(1 + h)))

    with tqdm(total=total_steps, desc="GridFire BBN", unit="step") as pbar:
        while gf_current_time < tMax:
            current_dt = h * gf_current_time
            next_time = gf_current_time + current_dt

            # Conditions are held constant over the step at the mean of the
            # values at both ends of the step.
            burn_temp = (T9(gf_current_time) + T9(next_time)) / 2.0 * 1e9
            burn_density = (
                get_density(gf_current_time) + get_density(next_time)
            ) / 2.0

            netIn.temperature = burn_temp
            netIn.density = burn_density
            netIn.tMax = current_dt
            netOut = solver.evaluate(solver_ctx, netIn)
            # The output composition seeds the next step.
            netIn.composition = netOut.composition

            pbar.update(1)
            pbar.set_postfix(
                t=f"{gf_current_time:.2e}",
                T=f"{burn_temp:.2e}",
                rho=f"{burn_density:.2e}",
            )

            step_conditions.append(
                {
                    "dt": current_dt,
                    "T": burn_temp,
                    "rho": burn_density,
                    "t": gf_current_time,
                }
            )
            gf_time.append(gf_current_time)
            for sp in engine.getNetworkSpecies(solver_ctx.engine_ctx):
                gf_results.setdefault(sp.name(), []).append(
                    netOut.composition.getMolarAbundance(sp)
                )

            gf_current_time += current_dt

    gf_end_time = time.time()
    print(
        f"GridFire integration finished in {gf_end_time - gf_start_time:.4f} seconds."
    )

    # ---- Map GridFire reactions onto pynucastro rates --------------------
    print("Building Pynucastro BBN Network...")
    reaclib_library = pyna.ReacLibLibrary()

    # Query the reaction set once and normalise each id once; the container
    # itself is kept so the reaction objects stay valid while we iterate it
    # again below.
    network_reactions = engine.getNetworkReactions(solver_ctx.engine_ctx)
    reaction_ids = [_normalized_reaction_id(r) for r in network_reactions]

    goodRates = []
    missingRates = []
    skipped_photo_rates = 0
    pyna_rate_mapping = {}

    for r_str in reaction_ids:
        pyna_rates_for_reaction = []
        # Both lookups are noisy on misses; silence them while they run.
        with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
            io.StringIO()
        ):
            try:
                pyna_rates_for_reaction.extend(
                    _as_rate_list(reaclib_library.get_rate_by_name(r_str))
                )
            except Exception:
                # Deliberate best effort: a failed name lookup falls through
                # to the nucleus-based lookup below.
                pass
            if not pyna_rates_for_reaction:
                pyna_rates_for_reaction.extend(
                    _as_rate_list(get_pyna_rate(r_str, reaclib_library))
                )

        if not pyna_rates_for_reaction:
            missingRates.append(r_str)
            continue

        pyna_rate_mapping[r_str] = pyna_rates_for_reaction
        for rate in pyna_rates_for_reaction:
            # NOTE(review): this only matches rates whose reactants stringify
            # to 'g'/'gamma'; confirm pynucastro lists photons as reactants,
            # otherwise the filter never triggers.
            if args.filter_photo and any(
                str(r).lower() in ("g", "gamma") for r in rate.reactants
            ):
                skipped_photo_rates += 1
                continue
            goodRates.append(rate)

    if missingRates:
        print(
            f"Warning: Could not map {len(missingRates)} rates to Pynucastro (likely absent from default ReacLib)."
        )
        print(f"Missing sample: {missingRates[:10]}...")
    if args.filter_photo:
        print(
            f"Info: Skipped {skipped_photo_rates} photodisintegration rates due to --filter-photo flag."
        )

    # ---- Evaluate both rate sets at every step temperature ---------------
    print("--- Evaluating reaction rates over all temperatures ---")
    gf_rates_history = {}
    py_rates_history = {}
    gf_rate_labels = {}
    py_rate_labels = {}

    for reaction, r_str in zip(network_reactions, reaction_ids):
        gf_rates_history[r_str] = []
        py_rates_history[r_str] = []
        # The label accessor differs between reaction types; try both.
        try:
            gf_rate_labels[r_str] = reaction.sources()
        except AttributeError:
            try:
                gf_rate_labels[r_str] = reaction.sourceLabel()
            except AttributeError:
                gf_rate_labels[r_str] = "Unknown"
        py_rate_labels[r_str] = [
            getattr(pr, "label", "Unknown")
            for pr in pyna_rate_mapping.get(r_str, [])
        ]

    failed_gf_evals = set()
    for step in tqdm(step_conditions, desc="Calculating Rates", unit="step"):
        T_K = step["T"]
        T9_val = T_K / 1e9
        for reaction, r_str in zip(network_reactions, reaction_ids):
            # Two call signatures are tried in turn; if both fail the rate is
            # recorded as 0 and the reaction is reported after the loop.
            try:
                gf_rate_val = reaction.calculate_rate(T9_val, 0, [])
            except Exception:
                try:
                    gf_rate_val = reaction.calculate_rate(
                        T9_val, 0, 0, 0, [], dict()
                    )
                except Exception:
                    gf_rate_val = 0.0
                    failed_gf_evals.add(r_str)
            gf_rates_history[r_str].append(gf_rate_val)

            py_rate_val = 0.0
            for pr in pyna_rate_mapping.get(r_str, ()):
                py_rate_val += pr.eval(T_K)
            py_rates_history[r_str].append(py_rate_val)

    if failed_gf_evals:
        print(
            f"Warning: GridFire rate evaluation failed for {len(failed_gf_evals)} reactions; their rates were recorded as 0."
        )

    # ---- Report rate mismatches ------------------------------------------
    print("--- Rate Comparison Summary ---")
    threshold = 1e-4
    mismatches = {}
    for r_str in gf_rates_history:
        gf_arr = np.array(gf_rates_history[r_str])
        py_arr = np.array(py_rates_history[r_str])
        with np.errstate(divide="ignore", invalid="ignore"):
            # Normalise by the pynucastro value, falling back to the GridFire
            # value and finally to a tiny constant so 0/0 never occurs.
            denom = np.where(py_arr != 0, py_arr, gf_arr)
            denom = np.where(denom == 0, 1e-30, denom)
            rel_diffs = np.abs(gf_arr - py_arr) / denom
        max_diff = np.max(rel_diffs)
        if max_diff > threshold:
            max_idx = np.argmax(rel_diffs)
            mismatches[r_str] = {
                "max_diff": max_diff,
                "temp": step_conditions[max_idx]["T"],
                "gf_val": gf_arr[max_idx],
                "py_val": py_arr[max_idx],
            }

    if mismatches:
        print(
            f"Found {len(mismatches)} rates with differences > {threshold:.2%}"
        )
        for r_str, info in mismatches.items():
            gf_lbl = gf_rate_labels.get(r_str, "Unknown")
            py_lbl = py_rate_labels.get(r_str, [])
            print(
                f"{r_str:20}: Max Diff = {info['max_diff']:.2%}, at T = {info['temp']:.2e} K"
            )
            print(f"GF = {info['gf_val']:.4e} (Source: {gf_lbl})")
            print(f"Py = {info['py_val']:.4e} (Sources: {py_lbl})")
    else:
        print(
            f"All rates match within the {threshold:.2%} threshold across all temperatures."
        )
    print("-------------------------------")

    # ---- Build and load the pynucastro network ---------------------------
    pynet = pyna.PythonNetwork(rates=goodRates)
    network_file = "pynuc_bbn_network.py"
    pynet.write_network(network_file)
    net = load_network_module(network_file)

    # GridFire species name -> (pynucastro name, plot colour).
    mapping = {
        "H-1": ("p", "tab:blue"),
        "n-1": ("n", "tab:orange"),
        "He-4": ("he4", "tab:green"),
        "H-2": ("d", "tab:red"),
        "H-3": ("t", "tab:purple"),
        "He-3": ("he3", "tab:brown"),
        "Li-7": ("li7", "tab:pink"),
        "Be-7": ("be7", "tab:gray"),
    }
    pyna_to_gf = {py_name: gf_name for gf_name, (py_name, _) in mapping.items()}

    nuclei = pynet.get_nuclei()
    pyna_nuc_names = [str(n) for n in nuclei]

    # Seed the pynucastro state vector with the GridFire initial abundances.
    Y0 = np.zeros(net.nnuc)
    for i, nuc_name in enumerate(pyna_nuc_names):
        gf_name = pyna_to_gf.get(nuc_name)
        if not gf_name:
            # Generic fallback: "li7" -> "Li-7".
            match = re.match(r"([a-zA-Z]+)(\d+)", nuc_name)
            if match:
                gf_name = f"{match.group(1).capitalize()}-{match.group(2)}"
        if gf_name and gf_name in gf_initial_Y:
            Y0[i] = gf_initial_Y[gf_name]

    # ---- pynucastro integration over the same step conditions ------------
    pyna_time = []
    pyna_results = {nuc: [] for nuc in pyna_nuc_names}
    pyna_start_time = time.time()

    for step in tqdm(step_conditions, unit="step", desc="pynucastro Integration"):
        sol = scipy.integrate.solve_ivp(
            net.rhs,
            [0, step["dt"]],
            Y0,
            args=(step["rho"], step["T"]),
            method="Radau",
            jac=net.jacobian,
            rtol=1e-8,
            atol=1e-20,
        )
        if not sol.success:
            # Keep going with the last state, but do not hide the failure.
            tqdm.write(
                f"Warning: solve_ivp did not converge at t = {step['t']:.4e}: {sol.message}"
            )
        Y0 = sol.y[:, -1]
        pyna_time.append(step["t"])
        for nuc_name, abundance in zip(pyna_nuc_names, Y0):
            pyna_results[nuc_name].append(float(abundance))

    pyna_end_time = time.time()
    print(
        f"Pynucastro integration finished in {pyna_end_time - pyna_start_time:.4f} seconds."
    )

    # ---- Export ----------------------------------------------------------
    export_data = {
        "metadata": {
            "tMax": tMax,
            "h": h,
            "initial_time": current_time,
            "initial_XpXn_ratio": XpXn,
            "initial_mass_fractions": {
                "Xp": Xp,
                "Xn": Xn,
            },
            "execution_times_seconds": {
                "gridfire": gf_end_time - gf_start_time,
                "pynucastro": pyna_end_time - pyna_start_time,
            },
            "missing_pynucastro_rates": missingRates,
            "skipped_photodisintegration_rates": (
                skipped_photo_rates if args.filter_photo else 0
            ),
            "rate_labels": {
                "gridfire": gf_rate_labels,
                "pynucastro": py_rate_labels,
            },
        },
        "thermodynamic_conditions": step_conditions,
        "data": {
            "gridfire": {
                "time": gf_time,
                "molar_abundances": gf_results,
                "reaction_rates": gf_rates_history,
            },
            "pynucastro": {
                "time": pyna_time,
                "molar_abundances": pyna_results,
                "reaction_rates": py_rates_history,
            },
        },
    }
    json_out_file = "bbn_simulation_data.json"
    with open(json_out_file, "w", encoding="utf-8") as f:
        json.dump(export_data, f, indent=4)

    # ---- Plot ------------------------------------------------------------
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(10, 7))
    for gf_name, (pyna_name, color) in mapping.items():
        if gf_name in gf_results:
            ax.plot(
                gf_time,
                gf_results[gf_name],
                color=color,
                linestyle="-",
                linewidth=2.5,
                label=f"GF {gf_name}",
            )
        if pyna_name in pyna_results:
            ax.plot(
                pyna_time,
                pyna_results[pyna_name],
                color=color,
                linestyle="--",
                linewidth=1.5,
                label=f"Pyna {pyna_name}",
            )

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_ylim(1e-12, 2)
    ax.set_xlabel("Time (s)", fontsize=14)
    ax.set_ylabel("Molar Abundance (Y)", fontsize=14)

    # Legend: two black proxies explain the line styles, then one coloured
    # proxy per species.
    line_gf = mlines.Line2D(
        [], [], color="black", linestyle="-", linewidth=2.5, label="GridFire"
    )
    line_py = mlines.Line2D(
        [], [], color="black", linestyle="--", linewidth=1.5, label="Pynucastro"
    )
    sp_handles = [
        mlines.Line2D([], [], color=color, linestyle="-", linewidth=2, label=gf_name)
        for gf_name, (_, color) in mapping.items()
    ]
    ax.legend(
        handles=[line_gf, line_py] + sp_handles,
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        fontsize=12,
    )

    out_file = "bbn_comparison.pdf"
    # The legend sits outside the axes; a tight bounding box keeps it from
    # being cropped out of the saved figure.
    plt.savefig(out_file, bbox_inches="tight")
    plt.close(fig)
# Script entry point: parse the command-line flags and run the comparison.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="GridFire vs Pynucastro BBN Comparison"
    )
    # Optional: drop photodisintegration rates from the pynucastro network.
    parser.add_argument(
        "--filter-photo",
        action="store_true",
        help="Filter out photodisintegration (reverse) rates to mimic GridFire's forward-only mechanics.",
    )
    # Optional: build a shallower (smaller) GridFire network.
    parser.add_argument(
        "--depth",
        type=int,
        default=None,
        help="Limit the assembly depth of GridFire's GraphEngine. E.g., setting '--depth 3' shrinks the network size from 5000+ reactions to ~100, which reduces Pynucastro's Numba JIT compile time from hours to seconds.",
    )
    args = parser.parse_args()
    main(args)