Files
GridFire/validation/vv/testsuite.py

242 lines
8.4 KiB
Python
Raw Permalink Normal View History

from abc import ABC, abstractmethod
import fourdst.atomic
import scipy.integrate
from fourdst.composition import Composition
from gridfire.engine import DynamicEngine, GraphEngine
from gridfire.type import NetIn, NetOut
from gridfire.exceptions import GridFireError
from gridfire.solver import CVODESolverStrategy
from logger import StepLogger
from typing import List
import re
from typing import Dict, Tuple, Any, Union
from datetime import datetime
import pynucastro as pyna
import os
import importlib.util
import sys
import numpy as np
import json
import time
def load_network_module(filepath):
    """Import the Python source file at ``filepath`` and return the module.

    The module is registered in ``sys.modules`` under the file's base name
    with its extension removed. Any module already registered under that
    name is dropped first so the file is executed afresh.

    Args:
        filepath: Path to the Python source file to load.

    Returns:
        The newly executed module object.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing file, or no
            import spec/loader can be created for it.
        Exception: Anything raised by the module's own top-level code. The
            partially initialised module is removed from ``sys.modules``
            before the exception propagates.
    """
    # splitext strips only the trailing extension; str.replace(".py", "")
    # would also delete a ".py" occurring in the middle of the file name.
    module_name = os.path.splitext(os.path.basename(filepath))[0]
    # Clear any existing module with the same name.
    sys.modules.pop(module_name, None)
    # spec_from_file_location does not check that a ".py" path exists, so
    # test explicitly to fail early with a clear message.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Could not find module at {filepath}")
    spec = importlib.util.spec_from_file_location(module_name, filepath)
    if spec is None or spec.loader is None:
        raise FileNotFoundError(f"Could not find module at {filepath}")
    network_module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be found by name while its
    # own top-level code runs.
    sys.modules[module_name] = network_module
    try:
        spec.loader.exec_module(network_module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return network_module
def get_pyna_rate(my_rate_str, library):
    """Look up a rate in ``library`` from a ``target(projectile,ejectiles)product`` string.

    The string is split into its four parts with a regular expression. The
    projectile and ejectile fields may each hold several whitespace-separated
    species, optionally prefixed by an integer multiplier (e.g. ``2p``), and
    the alias ``a`` in those two fields is rewritten to ``he4``. The
    resulting names are converted to ``pyna.Nucleus`` objects and passed to
    ``library.get_rate_by_nuclei``.

    Args:
        my_rate_str: Rate identifier to parse.
        library: Object providing ``get_rate_by_nuclei(reactants, products)``.

    Returns:
        The matching rate (the first one when the library returns a list),
        or ``None`` when the string cannot be parsed, a nucleus name cannot
        be converted, or the library reports no match.
    """
    parsed = re.match(r"([a-zA-Z0-9]+)\(([^,]+),([^)]*)\)(.*)", my_rate_str)
    if parsed is None:
        print(f"Could not parse string format: {my_rate_str}")
        return None
    target, projectile, ejectiles, product = parsed.groups()

    def expand_species(s_str):
        # Whitespace-separated tokens such as "p a" or "2p a"; an empty or
        # blank field simply yields no tokens.
        names = []
        for token in s_str.split():
            # A leading integer is a multiplier, e.g. 2p or 3a.
            counted = re.match(r"(\d+)([a-zA-Z0-9]+)", token)
            if counted:
                repeat, name = int(counted.group(1)), counted.group(2)
            else:
                repeat, name = 1, token
            # Spell out the alpha-particle alias before handing it to pyna.
            if name == 'a':
                name = 'he4'
            names.extend([name] * repeat)
        return names

    reactant_names = [target, *expand_species(projectile)]
    product_names = [*expand_species(ejectiles), product]

    # Build the pyna.Nucleus objects; any failure here means "no rate".
    try:
        reactant_nuclei = [pyna.Nucleus(name) for name in reactant_names]
        product_nuclei = [pyna.Nucleus(name) for name in product_names]
    except Exception as e:
        print(f"Error converting nuclei for {my_rate_str}: {e}")
        return None

    found = library.get_rate_by_nuclei(reactant_nuclei, product_nuclei)
    if not found:
        return None
    # A list result means several candidates; take the first match.
    return found[0] if isinstance(found, list) else found
class TestSuite(ABC):
    """Abstract base for a single GridFire validation case.

    An instance stores the name, description, temperature, density, maximum
    time, initial composition and free-form notes given to the constructor.
    ``evolve`` runs the GridFire solver and writes its step log to JSON;
    ``evolve_pynucastro`` builds a pynucastro network from the same engine's
    reactions, integrates it with SciPy and writes a comparable JSON file.
    Subclasses must implement ``__call__``.
    """

    def __init__(self, name: str, description: str, temp: float, density: float, tMax: float, composition: Composition, notes: str = ""):
        """Store the configuration of the validation case.

        Args:
            name: Identifier used in output file names and JSON metadata.
            description: Human-readable description copied into the metadata.
            temp: Temperature handed to the pynucastro right-hand side.
            density: Density handed to the pynucastro right-hand side.
            tMax: Upper bound of the pynucastro integration interval.
            composition: Initial composition; queried via ``getMolarAbundance``.
            notes: Optional free-form notes copied into the metadata.
        """
        self.name: str = name
        self.description: str = description
        self.temperature: float = temp
        self.density: float = density
        self.tMax: float = tMax
        self.composition: Composition = composition
        self.notes: str = notes

    def evolve_pynucastro(self, engine: GraphEngine):
        """Integrate an equivalent pynucastro network and write the result to JSON.

        Every reaction id reported by ``engine.getNetworkReactions()`` is
        looked up in the ReacLib library, first by name and then through
        ``get_pyna_rate``. The rates found are written out as a Python
        network module, loaded back, and integrated with
        ``scipy.integrate.solve_ivp`` (BDF). The trajectory is written to
        ``GridFireValidationSuite_{name}_pynucastro.json``.

        Args:
            engine: Engine whose reaction list defines the network.

        Raises:
            AttributeError: If the generated network lacks one of the eight
                species indices that are seeded below.
        """
        print("Evolution complete. Now building equivalent pynucastro network...")
        # Build equivalent pynucastro network for comparison
        reaclib_library: pyna.ReacLibLibrary = pyna.ReacLibLibrary()
        rate_names = [r.id().replace("e+", "").replace("e-", "").replace(", ", ",") for r in engine.getNetworkReactions()]
        goodRates: List[pyna.rates.reaclib_rate.ReacLibRate] = []
        missingRates = []
        for r_str in rate_names:
            # Try the exact name match first (fastest).
            try:
                pyna_rate = reaclib_library.get_rate_by_name(r_str)
            except Exception:
                # Narrowed from a bare except so Ctrl-C / SystemExit still
                # propagate; any lookup error falls through to the parser.
                pyna_rate = None
            if isinstance(pyna_rate, list):
                pyna_rate = pyna_rate[0] if pyna_rate else None
            if pyna_rate is None:
                # Fallback to the smart parser
                pyna_rate = get_pyna_rate(r_str, reaclib_library)
            if pyna_rate is not None:
                goodRates.append(pyna_rate)
            else:
                missingRates.append(r_str)
        if missingRates:
            # Surface the mismatch instead of silently comparing against a
            # smaller network.
            print(f"Warning: {len(missingRates)} rate(s) not found in pynucastro and omitted: {missingRates}")

        pynet: pyna.PythonNetwork = pyna.PythonNetwork(rates=goodRates)
        pynet.write_network(f"{self.name}_pynucastro_network.py")
        net = load_network_module(f"{self.name}_pynucastro_network.py")

        # NOTE: only these eight species are seeded; all other abundances
        # start at zero, and the generated network must define every index
        # listed here.
        initial_species = (
            ("jp", "H-1"),
            ("jhe3", "He-3"),
            ("jhe4", "He-4"),
            ("jc12", "C-12"),
            ("jn14", "N-14"),
            ("jo16", "O-16"),
            ("jne20", "Ne-20"),
            ("jmg24", "Mg-24"),
        )
        Y0 = np.zeros(net.nnuc)
        for index_name, species_name in initial_species:
            Y0[getattr(net, index_name)] = self.composition.getMolarAbundance(species_name)

        print("Starting pynucastro integration...")
        startTime = time.perf_counter()
        sol = scipy.integrate.solve_ivp(
            net.rhs,
            [0, self.tMax],
            Y0,
            args=(self.density, self.temperature),
            method="BDF",
            jac=net.jacobian,
            rtol=1e-5,
            atol=1e-8
        )
        endTime = time.perf_counter()

        # solve_ivp reports failure through the result object rather than by
        # raising, so the status has to be read explicitly.
        succeeded = bool(sol.success)
        if succeeded:
            print("Pynucastro integration complete. Writing results to JSON...")
        else:
            print(f"Pynucastro integration failed ({sol.message}). Writing partial results to JSON...")

        # Species labels depend only on (A, Z), so resolve them once instead
        # of once per time step.
        species_names: List[str] = []
        for j in range(net.nnuc):
            A = net.A[j]
            Z = net.Z[j]
            try:
                species_names.append(fourdst.atomic.az_to_species(A, Z).name())
            except Exception:
                # Placeholder label for an (A, Z) pair fourdst cannot name.
                species_names.append(f"SP-A_{A}_Z_{Z}")

        data: List[Dict[str, Union[float, Dict[str, float]]]] = []
        for time_step, t in enumerate(sol.t):
            step_composition: Dict[str, float] = {}
            for j, species in enumerate(species_names):
                step_composition[species] = float(sol.y[j, time_step])
            data.append({"t": float(t), "Composition": step_composition})

        metadata: Dict[str, Any] = {
            "Name": f"{self.name}_pynucastro",
            "Description": f"pynucastro simulation equivalent to GridFire validation suite: {self.description}",
            "Status": "Success" if succeeded else "Error",
            "Notes": self.notes,
            "Temperature": self.temperature,
            "Density": self.density,
            "tMax": self.tMax,
            "ElapsedTime": endTime - startTime,
            "DateCreated": datetime.now().isoformat()
        }
        if not succeeded:
            # Same key the GridFire failure record uses for its message.
            metadata["ErrorMessage"] = str(sol.message)
        pynucastro_json: Dict[str, Any] = {
            "Metadata": metadata,
            "Steps": data
        }
        with open(f"GridFireValidationSuite_{self.name}_pynucastro.json", "w", encoding="utf-8") as f:
            json.dump(pynucastro_json, f, indent=4)

    def evolve(self, engine: GraphEngine, netIn: NetIn, pynucastro_compare: bool = True):
        """Run the GridFire solver on ``netIn`` and record the outcome as JSON.

        A ``CVODESolverStrategy`` is built on ``engine`` with a callback that
        logs every step to a ``StepLogger``. On success the log is written to
        ``GridFireValidationSuite_{name}_OKAY.json``; when the solver raises
        ``GridFireError`` it is written to
        ``GridFireValidationSuite_{name}_FAIL.json`` with the error message.
        Other exception types propagate to the caller.

        Args:
            engine: Engine driving the solver (and the pynucastro comparison).
            netIn: Solver input; its temperature, density and tMax are copied
                into the JSON metadata.
            pynucastro_compare: When true, ``evolve_pynucastro`` is run
                afterwards, whether or not the GridFire evolution succeeded.
        """
        solver: CVODESolverStrategy = CVODESolverStrategy(engine)
        stepLogger: StepLogger = StepLogger()
        solver.set_callback(lambda ctx: stepLogger.log_step(ctx))

        startTime = time.perf_counter()
        try:
            netOut: NetOut = solver.evaluate(netIn)
        except GridFireError as e:
            endTime = time.perf_counter()
            stepLogger.to_json(
                f"GridFireValidationSuite_{self.name}_FAIL.json",
                Name=f"{self.name}_Failure",
                Description=self.description,
                Status="Error",
                ErrorMessage=str(e),
                Notes=self.notes,
                Temperature=netIn.temperature,
                Density=netIn.density,
                tMax=netIn.tMax,
                ElapsedTime=endTime - startTime
            )
        else:
            endTime = time.perf_counter()
            stepLogger.to_json(
                f"GridFireValidationSuite_{self.name}_OKAY.json",
                Name=f"{self.name}_Success",
                Description=self.description,
                Status="Success",
                Notes=self.notes,
                Temperature=netIn.temperature,
                Density=netIn.density,
                tMax=netIn.tMax,
                FinalEps=netOut.energy,
                FinaldEpsdT=netOut.dEps_dT,
                FinaldEpsdRho=netOut.dEps_dRho,
                ElapsedTime=endTime - startTime
            )

        if pynucastro_compare:
            self.evolve_pynucastro(engine)

    @abstractmethod
    def __call__(self):
        """Run the validation case; concrete suites must implement this."""
        pass