242 lines
8.4 KiB
Python
242 lines
8.4 KiB
Python
|
|
from abc import ABC, abstractmethod
|
||
|
|
|
||
|
|
import fourdst.atomic
|
||
|
|
import scipy.integrate
|
||
|
|
from fourdst.composition import Composition
|
||
|
|
from gridfire.engine import DynamicEngine, GraphEngine
|
||
|
|
from gridfire.type import NetIn, NetOut
|
||
|
|
from gridfire.exceptions import GridFireError
|
||
|
|
from gridfire.solver import CVODESolverStrategy
|
||
|
|
from logger import StepLogger
|
||
|
|
from typing import List
|
||
|
|
import re
|
||
|
|
from typing import Dict, Tuple, Any, Union
|
||
|
|
from datetime import datetime
|
||
|
|
|
||
|
|
import pynucastro as pyna
|
||
|
|
import os
|
||
|
|
import importlib.util
|
||
|
|
import sys
|
||
|
|
import numpy as np
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
|
||
|
|
def load_network_module(filepath):
    """Import a Python source file as a fresh module and return it.

    The module is registered in ``sys.modules`` under the file's base name
    with its final extension removed (``"net.py"`` -> ``"net"``). Any module
    already registered under that name is discarded first, so a regenerated
    file is re-executed rather than served from the import cache.

    Args:
        filepath: Path to the ``.py`` file to load.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If no import spec/loader can be created for
            ``filepath``.
        Exception: Anything raised while executing the module body is
            propagated after the partially-initialised module has been
            removed from ``sys.modules``.
    """
    # splitext only strips the trailing extension; str.replace(".py", "")
    # would also remove ".py" occurring elsewhere in the name.
    module_name = os.path.splitext(os.path.basename(filepath))[0]

    # Clear any existing module with the same name so the file is re-executed.
    sys.modules.pop(module_name, None)

    spec = importlib.util.spec_from_file_location(module_name, filepath)
    if spec is None or spec.loader is None:
        raise FileNotFoundError(f"Could not find module at {filepath}")

    network_module = importlib.util.module_from_spec(spec)
    # Register before executing, as the importlib docs recommend, so code in
    # the module that looks itself up in sys.modules still works.
    sys.modules[module_name] = network_module
    try:
        spec.loader.exec_module(network_module)
    except BaseException:
        # Do not leave a half-initialised module cached under this name.
        sys.modules.pop(module_name, None)
        raise
    return network_module
|
||
|
|
|
||
|
|
def get_pyna_rate(my_rate_str, library):
|
||
|
|
match = re.match(r"([a-zA-Z0-9]+)\(([^,]+),([^)]*)\)(.*)", my_rate_str)
|
||
|
|
|
||
|
|
if not match:
|
||
|
|
print(f"Could not parse string format: {my_rate_str}")
|
||
|
|
return None
|
||
|
|
|
||
|
|
target = match.group(1)
|
||
|
|
projectile = match.group(2)
|
||
|
|
ejectiles = match.group(3)
|
||
|
|
product = match.group(4)
|
||
|
|
|
||
|
|
def expand_species(s_str):
|
||
|
|
if not s_str or s_str.strip() == "":
|
||
|
|
return []
|
||
|
|
|
||
|
|
# Split by space (handling "p a" or "2p a")
|
||
|
|
parts = s_str.split()
|
||
|
|
expanded = []
|
||
|
|
|
||
|
|
for p in parts:
|
||
|
|
# Check for multipliers like 2p, 3a
|
||
|
|
mult_match = re.match(r"(\d+)([a-zA-Z0-9]+)", p)
|
||
|
|
if mult_match:
|
||
|
|
count = int(mult_match.group(1))
|
||
|
|
spec = mult_match.group(2)
|
||
|
|
# Map common aliases if necessary (though pyna handles most)
|
||
|
|
if spec == 'a': spec = 'he4'
|
||
|
|
expanded.extend([spec] * count)
|
||
|
|
else:
|
||
|
|
spec = p
|
||
|
|
if spec == 'a': spec = 'he4'
|
||
|
|
expanded.append(spec)
|
||
|
|
return expanded
|
||
|
|
|
||
|
|
reactants_str = [target] + expand_species(projectile)
|
||
|
|
products_str = expand_species(ejectiles) + [product]
|
||
|
|
|
||
|
|
# Convert strings to pyna.Nucleus objects
|
||
|
|
try:
|
||
|
|
r_nuc = [pyna.Nucleus(r) for r in reactants_str]
|
||
|
|
p_nuc = [pyna.Nucleus(p) for p in products_str]
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Error converting nuclei for {my_rate_str}: {e}")
|
||
|
|
return None
|
||
|
|
|
||
|
|
rates = library.get_rate_by_nuclei(r_nuc, p_nuc)
|
||
|
|
|
||
|
|
if rates:
|
||
|
|
if isinstance(rates, list):
|
||
|
|
return rates[0] # Return the first match
|
||
|
|
return rates
|
||
|
|
else:
|
||
|
|
return None
|
||
|
|
|
||
|
|
class TestSuite(ABC):
    """Abstract base for one GridFire validation scenario.

    A suite bundles fixed conditions (temperature, density, integration end
    time) and an initial composition. Concrete subclasses implement
    ``__call__`` to run the scenario, for example via :meth:`evolve`. All
    result files are written relative to the current working directory.
    """

    # (pynucastro network index attribute, composition species name) pairs
    # used to seed the initial molar abundances of the pynucastro comparison
    # run. Pairs whose attribute the generated network lacks are skipped.
    _PYNA_INITIAL_SPECIES = (
        ("jp", "H-1"),
        ("jhe3", "He-3"),
        ("jhe4", "He-4"),
        ("jc12", "C-12"),
        ("jn14", "N-14"),
        ("jo16", "O-16"),
        ("jne20", "Ne-20"),
        ("jmg24", "Mg-24"),
    )

    def __init__(self, name: str, description: str, temp: float, density: float, tMax: float, composition: Composition, notes: str = ""):
        """Store the scenario definition.

        Args:
            name: Short identifier used in every output file name.
            description: Human-readable description copied into JSON metadata.
            temp: Temperature passed to the pynucastro RHS and written to
                metadata (units follow the caller's convention — not fixed here).
            density: Density, handled the same way as ``temp``.
            tMax: End time of the pynucastro integration.
            composition: Initial composition. ``getMolarAbundance`` is called
                on it to seed the pynucastro run.
            notes: Free-form notes copied into JSON metadata.
        """
        self.name: str = name
        self.description: str = description
        self.temperature: float = temp
        self.density: float = density
        self.tMax: float = tMax
        self.composition: Composition = composition
        self.notes: str = notes

    @staticmethod
    def _match_pyna_rates(engine, library):
        """Map the engine's reactions onto pynucastro rates.

        An exact name lookup is tried first. If it raises or finds nothing,
        the lookup falls back to :func:`get_pyna_rate`.

        Returns:
            ``(matched_rates, unmatched_rate_names)``.
        """
        rate_names = [
            r.id().replace("e+", "").replace("e-", "").replace(", ", ",")
            for r in engine.getNetworkReactions()
        ]

        good_rates = []
        missing_rates = []
        for r_str in rate_names:
            pyna_rate = None
            try:
                # Exact name match first (fastest).
                pyna_rate = library.get_rate_by_name(r_str)
            except Exception:
                pass  # deliberate: an unparseable name just means "use the fallback"
            if isinstance(pyna_rate, list):
                pyna_rate = pyna_rate[0] if pyna_rate else None
            if pyna_rate is None:
                # get_rate_by_name may also return None without raising.
                pyna_rate = get_pyna_rate(r_str, library)

            if pyna_rate is None:
                missing_rates.append(r_str)
            else:
                good_rates.append(pyna_rate)
        return good_rates, missing_rates

    def _pyna_initial_abundances(self, net) -> np.ndarray:
        """Build the initial molar-abundance vector for the generated network ``net``."""
        y0 = np.zeros(net.nnuc)
        for index_attr, species in self._PYNA_INITIAL_SPECIES:
            index = getattr(net, index_attr, None)
            if index is None:
                # The generated network only defines indices for the nuclei it
                # contains; skip seed species it does not know about.
                print(f"Warning: {species} is not in the pynucastro network; skipping its initial abundance.")
                continue
            y0[index] = self.composition.getMolarAbundance(species)
        return y0

    @staticmethod
    def _pyna_species_names(net) -> List[str]:
        """Return an output species label for every nucleus index of ``net``.

        Labels come from ``fourdst.atomic.az_to_species``. Nuclei it cannot
        resolve get a synthetic ``SP-A_{A}_Z_{Z}`` label instead.
        """
        names: List[str] = []
        for j in range(net.nnuc):
            A = net.A[j]
            Z = net.Z[j]
            try:
                names.append(fourdst.atomic.az_to_species(A, Z).name())
            except Exception:
                names.append(f"SP-A_{A}_Z_{Z}")
        return names

    def evolve_pynucastro(self, engine: GraphEngine):
        """Re-run this scenario with an equivalent pynucastro network.

        Steps:

        1. Look up each reaction of ``engine`` in the REACLIB library.
        2. Write the matched rates to ``{name}_pynucastro_network.py`` and
           import that module.
        3. Integrate it with SciPy's BDF method at this suite's fixed
           temperature and density from ``t = 0`` to ``tMax``.
        4. Write the abundance history to
           ``GridFireValidationSuite_{name}_pynucastro.json``.

        Reactions with no pynucastro match are left out of the network; they
        are reported on stdout and listed under ``Metadata.MissingRates``.

        Args:
            engine: Engine whose ``getNetworkReactions()`` defines the network.
        """
        print("Evolution complete. Now building equivalent pynucastro network...")
        # Build equivalent pynucastro network for comparison.
        reaclib_library: pyna.ReacLibLibrary = pyna.ReacLibLibrary()
        good_rates, missing_rates = self._match_pyna_rates(engine, reaclib_library)
        if missing_rates:
            print(
                f"Warning: {len(missing_rates)} reaction(s) have no pynucastro "
                f"equivalent and were left out: {', '.join(missing_rates)}"
            )

        network_path = f"{self.name}_pynucastro_network.py"
        pynet: pyna.PythonNetwork = pyna.PythonNetwork(rates=good_rates)
        pynet.write_network(network_path)
        net = load_network_module(network_path)

        y0 = self._pyna_initial_abundances(net)

        print("Starting pynucastro integration...")
        start_time = time.perf_counter()
        sol = scipy.integrate.solve_ivp(
            net.rhs,
            [0, self.tMax],
            y0,
            args=(self.density, self.temperature),
            method="BDF",
            jac=net.jacobian,
            rtol=1e-5,
            atol=1e-8,
        )
        elapsed = time.perf_counter() - start_time
        if not sol.success:
            print(f"Warning: pynucastro integration failed: {sol.message}")
        print("Pynucastro integration complete. Writing results to JSON...")

        # Resolve species labels once rather than at every time step.
        species_names = self._pyna_species_names(net)
        data: List[Dict[str, Union[float, Dict[str, float]]]] = [
            {
                "t": t,
                "Composition": {sp: sol.y[j, step] for j, sp in enumerate(species_names)},
            }
            for step, t in enumerate(sol.t)
        ]

        pynucastro_json: Dict[str, Any] = {
            "Metadata": {
                "Name": f"{self.name}_pynucastro",
                "Description": f"pynucastro simulation equivalent to GridFire validation suite: {self.description}",
                # Reflect the integrator's actual outcome.
                "Status": "Success" if sol.success else "Failure",
                "SolverMessage": sol.message,
                "Notes": self.notes,
                "Temperature": self.temperature,
                "Density": self.density,
                "tMax": self.tMax,
                "ElapsedTime": elapsed,
                "MissingRates": missing_rates,
                "DateCreated": datetime.now().isoformat(),
            },
            "Steps": data,
        }

        with open(f"GridFireValidationSuite_{self.name}_pynucastro.json", "w", encoding="utf-8") as f:
            json.dump(pynucastro_json, f, indent=4)

    def evolve(self, engine: GraphEngine, netIn: NetIn, pynucastro_compare: bool = True):
        """Run GridFire's CVODE solver on ``netIn`` and save the step log.

        Every solver step is recorded through a ``StepLogger`` callback.

        * On success, the log is saved via ``StepLogger.to_json`` under
          ``GridFireValidationSuite_{name}_OKAY.json``, together with the
          final energy and its temperature/density derivatives.
        * If the solver raises :class:`GridFireError`, the log is saved under
          ``GridFireValidationSuite_{name}_FAIL.json`` with the error message.
        * Any other exception propagates.

        When ``pynucastro_compare`` is true, :meth:`evolve_pynucastro` runs
        afterwards in both of the first two cases.
        """
        solver: CVODESolverStrategy = CVODESolverStrategy(engine)

        step_logger: StepLogger = StepLogger()
        solver.set_callback(lambda ctx: step_logger.log_step(ctx))

        start_time = time.perf_counter()
        try:
            netOut: NetOut = solver.evaluate(netIn)
        except GridFireError as e:
            elapsed = time.perf_counter() - start_time
            step_logger.to_json(
                f"GridFireValidationSuite_{self.name}_FAIL.json",
                Name=f"{self.name}_Failure",
                Description=self.description,
                Status="Error",
                ErrorMessage=str(e),
                Notes=self.notes,
                Temperature=netIn.temperature,
                Density=netIn.density,
                tMax=netIn.tMax,
                ElapsedTime=elapsed,
            )
        else:
            elapsed = time.perf_counter() - start_time
            step_logger.to_json(
                f"GridFireValidationSuite_{self.name}_OKAY.json",
                Name=f"{self.name}_Success",
                Description=self.description,
                Status="Success",
                Notes=self.notes,
                Temperature=netIn.temperature,
                Density=netIn.density,
                tMax=netIn.tMax,
                FinalEps=netOut.energy,
                FinaldEpsdT=netOut.dEps_dT,
                FinaldEpsdRho=netOut.dEps_dRho,
                ElapsedTime=elapsed,
            )

        if pynucastro_compare:
            self.evolve_pynucastro(engine)

    @abstractmethod
    def __call__(self):
        """Run the scenario; implemented by each concrete suite."""
        pass
|