perf(multi): Simple parallel multi zone solver
Added a simple parallel multi-zone solver
This commit is contained in:
@@ -118,8 +118,7 @@ namespace gridfire::engine {
|
||||
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
|
||||
m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)),
|
||||
m_partitionFunction(partitionFunction.clone()),
|
||||
m_depth(buildDepth),
|
||||
m_state_blob_offset(0) // For a base engine the offset is always 0
|
||||
m_depth(buildDepth)
|
||||
{
|
||||
syncInternalMaps();
|
||||
}
|
||||
@@ -128,8 +127,7 @@ namespace gridfire::engine {
|
||||
const reaction::ReactionSet &reactions
|
||||
) :
|
||||
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
|
||||
m_reactions(reactions),
|
||||
m_state_blob_offset(0)
|
||||
m_reactions(reactions)
|
||||
{
|
||||
syncInternalMaps();
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "fourdst/atomic/species.h"
|
||||
#include "fourdst/composition/utils.h"
|
||||
#include "gridfire/engine/views/engine_priming.h"
|
||||
#include "gridfire/solver/solver.h"
|
||||
|
||||
#include "gridfire/engine/engine_abstract.h"
|
||||
@@ -13,7 +12,7 @@
|
||||
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
|
||||
|
||||
#include "fourdst/logging/logging.h"
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
#include "gridfire/solver/strategies/PointSolver.h"
|
||||
#include "quill/Logger.h"
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
@@ -28,13 +27,12 @@ namespace gridfire::engine {
|
||||
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
|
||||
) {
|
||||
const auto logger = LogManager::getInstance().getLogger("log");
|
||||
solver::CVODESolverStrategy integrator(engine, ctx);
|
||||
solver::PointSolver integrator(engine);
|
||||
solver::PointSolverContext solverCtx(ctx);
|
||||
solverCtx.abs_tol = 1e-3;
|
||||
solverCtx.rel_tol = 1e-3;
|
||||
solverCtx.stdout_logging = false;
|
||||
|
||||
// Do not need high precision for priming
|
||||
integrator.set_absTol(1e-3);
|
||||
integrator.set_relTol(1e-3);
|
||||
|
||||
integrator.set_stdout_logging_enabled(false);
|
||||
NetIn solverInput(netIn);
|
||||
|
||||
solverInput.tMax = 1e-15;
|
||||
@@ -43,7 +41,7 @@ namespace gridfire::engine {
|
||||
LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax);
|
||||
PrimingReport report;
|
||||
try {
|
||||
const NetOut netOut = integrator.evaluate(solverInput, false);
|
||||
const NetOut netOut = integrator.evaluate(solverCtx, solverInput);
|
||||
LOG_INFO(logger, "Network ignition completed.");
|
||||
LOG_TRACE_L2(
|
||||
logger,
|
||||
|
||||
@@ -2005,7 +2005,32 @@ namespace gridfire::engine {
|
||||
LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.",
|
||||
fnorm, ACCEPTABLE_FTOL);
|
||||
} else {
|
||||
LOG_WARNING(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Error {}", utils::kinsol_ret_code_map.at(flag), fnorm);
|
||||
LOG_CRITICAL(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Flag No.: {}, Error (fNorm): {}", utils::kinsol_ret_code_map.at(flag), flag, fnorm);
|
||||
LOG_CRITICAL(getLogger(), "State prior to failure: {}",
|
||||
[&comp, &data]() -> std::string {
|
||||
std::ostringstream oss;
|
||||
oss << "Solve species: <";
|
||||
size_t count = 0;
|
||||
for (const auto& species : data.qse_solve_species) {
|
||||
oss << species.name();
|
||||
if (count < data.qse_solve_species.size() - 1) {
|
||||
oss << ", ";
|
||||
}
|
||||
count++;
|
||||
}
|
||||
oss << "> | Abundances and rates at failure: ";
|
||||
count = 0;
|
||||
for (const auto& [species, abundance] : comp) {
|
||||
oss << species.name() << ": Y = " << abundance;
|
||||
if (count < comp.size() - 1) {
|
||||
oss << ", ";
|
||||
}
|
||||
count++;
|
||||
}
|
||||
oss << " | Temperature: " << data.T9 << ", Density: " << data.rho;
|
||||
return oss.str();
|
||||
}()
|
||||
);
|
||||
throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag));
|
||||
}
|
||||
}
|
||||
|
||||
94
src/lib/solver/strategies/GridSolver.cpp
Normal file
94
src/lib/solver/strategies/GridSolver.cpp
Normal file
@@ -0,0 +1,94 @@
|
||||
#include "gridfire/solver/strategies/GridSolver.h"
|
||||
|
||||
#include "gridfire/exceptions/error_solver.h"
|
||||
#include "gridfire/solver/strategies/PointSolver.h"
|
||||
#include "gridfire/utils/macros.h"
|
||||
#include "gridfire/utils/gf_omp.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <print>
|
||||
|
||||
namespace gridfire::solver {
|
||||
void GridSolverContext::init() {}
|
||||
void GridSolverContext::reset() {
|
||||
solver_workspaces.clear();
|
||||
timestep_callbacks.clear();
|
||||
}
|
||||
|
||||
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback) {
|
||||
for (auto &cb : timestep_callbacks) {
|
||||
cb = callback;
|
||||
}
|
||||
}
|
||||
|
||||
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback, const size_t zone_idx) {
|
||||
if (zone_idx >= timestep_callbacks.size()) {
|
||||
throw exceptions::SolverError("GridSolverContext::set_callback: zone_idx out of range.");
|
||||
}
|
||||
timestep_callbacks[zone_idx] = callback;
|
||||
}
|
||||
|
||||
void GridSolverContext::set_stdout_logging(const bool enable) {
|
||||
zone_stdout_logging = enable;
|
||||
}
|
||||
|
||||
void GridSolverContext::set_detailed_logging(const bool enable) {
|
||||
zone_detailed_logging = enable;
|
||||
}
|
||||
|
||||
GridSolverContext::GridSolverContext(
|
||||
const engine::scratch::StateBlob &ctx_template
|
||||
) :
|
||||
ctx_template(ctx_template) {}
|
||||
|
||||
|
||||
GridSolver::GridSolver(
|
||||
const engine::DynamicEngine &engine,
|
||||
const SingleZoneDynamicNetworkSolver &solver
|
||||
) :
|
||||
MultiZoneNetworkSolver(engine, solver) {
|
||||
GF_PAR_INIT();
|
||||
}
|
||||
|
||||
std::vector<NetOut> GridSolver::evaluate(
|
||||
SolverContextBase& ctx,
|
||||
const std::vector<NetIn>& netIns
|
||||
) const {
|
||||
auto* sctx_p = dynamic_cast<GridSolverContext*>(&ctx);
|
||||
if (!sctx_p) {
|
||||
throw exceptions::SolverError("GridSolver::evaluate: SolverContextBase is not of type GridSolverContext.");
|
||||
}
|
||||
|
||||
const size_t n_zones = netIns.size();
|
||||
if (n_zones == 0) { return {}; }
|
||||
|
||||
std::vector<NetOut> results(n_zones);
|
||||
|
||||
sctx_p->solver_workspaces.resize(n_zones);
|
||||
|
||||
GF_OMP(
|
||||
parallel for default(none) shared(sctx_p, n_zones),
|
||||
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
|
||||
sctx_p->solver_workspaces[zone_idx] = std::make_unique<PointSolverContext>(sctx_p->ctx_template);
|
||||
sctx_p->solver_workspaces[zone_idx]->set_stdout_logging(sctx_p->zone_stdout_logging);
|
||||
sctx_p->solver_workspaces[zone_idx]->set_detailed_logging(sctx_p->zone_detailed_logging);
|
||||
}
|
||||
|
||||
GF_OMP(
|
||||
parallel for default(none) shared(results, sctx_p, netIns, n_zones),
|
||||
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
|
||||
try {
|
||||
results[zone_idx] = m_solver.evaluate(
|
||||
*sctx_p->solver_workspaces[zone_idx],
|
||||
netIns[zone_idx]
|
||||
);
|
||||
} catch (exceptions::GridFireError& e) {
|
||||
std::println("CVODE Solver Failure in zone {}: {}", zone_idx, e.what());
|
||||
}
|
||||
if (sctx_p->zone_completion_logging) {
|
||||
std::println("Thread {} completed zone {}", GF_OMP_THREAD_NUM, zone_idx);
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
#include "gridfire/solver/strategies/PointSolver.h"
|
||||
|
||||
#include "gridfire/types/types.h"
|
||||
#include "gridfire/utils/table_format.h"
|
||||
@@ -28,7 +28,7 @@
|
||||
namespace gridfire::solver {
|
||||
using namespace gridfire::engine;
|
||||
|
||||
CVODESolverStrategy::TimestepContext::TimestepContext(
|
||||
PointSolverTimestepContext::PointSolverTimestepContext(
|
||||
const double t,
|
||||
const N_Vector &state,
|
||||
const double dt,
|
||||
@@ -58,7 +58,7 @@ namespace gridfire::solver {
|
||||
state_ctx(ctx)
|
||||
{}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
|
||||
std::vector<std::tuple<std::string, std::string>> PointSolverTimestepContext::describe() const {
|
||||
std::vector<std::tuple<std::string, std::string>> description;
|
||||
description.emplace_back("t", "Current Time");
|
||||
description.emplace_back("state", "Current State Vector (N_Vector)");
|
||||
@@ -74,36 +74,112 @@ namespace gridfire::solver {
|
||||
return description;
|
||||
}
|
||||
|
||||
void PointSolverContext::init() {
|
||||
reset_all();
|
||||
init_context();
|
||||
}
|
||||
|
||||
CVODESolverStrategy::CVODESolverStrategy(
|
||||
const DynamicEngine &engine,
|
||||
const scratch::StateBlob& ctx
|
||||
): SingleZoneNetworkSolver<DynamicEngine>(engine, ctx) {
|
||||
// PERF: In order to support MPI this function must be changed
|
||||
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
|
||||
if (flag < 0) {
|
||||
throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")");
|
||||
void PointSolverContext::set_stdout_logging(const bool enable) {
|
||||
stdout_logging = enable;
|
||||
}
|
||||
|
||||
void PointSolverContext::set_detailed_logging(const bool enable) {
|
||||
detailed_step_logging = enable;
|
||||
}
|
||||
|
||||
void PointSolverContext::reset_all() {
|
||||
reset_user();
|
||||
reset_cvode();
|
||||
}
|
||||
|
||||
void PointSolverContext::reset_user() {
|
||||
callback.reset();
|
||||
num_steps = 0;
|
||||
stdout_logging = true;
|
||||
abs_tol.reset();
|
||||
rel_tol.reset();
|
||||
detailed_step_logging = false;
|
||||
last_size = 0;
|
||||
last_composition_hash = 0ULL;
|
||||
}
|
||||
|
||||
void PointSolverContext::reset_cvode() {
|
||||
if (LS) {
|
||||
SUNLinSolFree(LS);
|
||||
LS = nullptr;
|
||||
}
|
||||
if (J) {
|
||||
SUNMatDestroy(J);
|
||||
J = nullptr;
|
||||
}
|
||||
if (Y) {
|
||||
N_VDestroy(Y);
|
||||
Y = nullptr;
|
||||
}
|
||||
if (YErr) {
|
||||
N_VDestroy(YErr);
|
||||
YErr = nullptr;
|
||||
}
|
||||
if (constraints) {
|
||||
N_VDestroy(constraints);
|
||||
constraints = nullptr;
|
||||
}
|
||||
if (cvode_mem) {
|
||||
CVodeFree(&cvode_mem);
|
||||
cvode_mem = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
CVODESolverStrategy::~CVODESolverStrategy() {
|
||||
LOG_TRACE_L1(m_logger, "Cleaning up CVODE resources...");
|
||||
cleanup_cvode_resources(true);
|
||||
|
||||
if (m_sun_ctx) {
|
||||
SUNContext_Free(&m_sun_ctx);
|
||||
void PointSolverContext::clear_context() {
|
||||
if (sun_ctx) {
|
||||
SUNContext_Free(&sun_ctx);
|
||||
sun_ctx = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) {
|
||||
return evaluate(netIn, false);
|
||||
void PointSolverContext::init_context() {
|
||||
if (!sun_ctx) {
|
||||
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
|
||||
}
|
||||
}
|
||||
|
||||
NetOut CVODESolverStrategy::evaluate(
|
||||
bool PointSolverContext::has_context() const {
|
||||
return sun_ctx != nullptr;
|
||||
}
|
||||
|
||||
PointSolverContext::PointSolverContext(
|
||||
const scratch::StateBlob& engine_ctx
|
||||
) :
|
||||
engine_ctx(engine_ctx.clone_structure())
|
||||
{
|
||||
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
|
||||
}
|
||||
|
||||
PointSolverContext::~PointSolverContext() {
|
||||
reset_cvode();
|
||||
clear_context();
|
||||
}
|
||||
|
||||
|
||||
PointSolver::PointSolver(
|
||||
const DynamicEngine &engine
|
||||
): SingleZoneNetworkSolver(engine) {}
|
||||
|
||||
NetOut PointSolver::evaluate(
|
||||
SolverContextBase& solver_ctx,
|
||||
const NetIn& netIn
|
||||
) const {
|
||||
return evaluate(solver_ctx, netIn, false);
|
||||
}
|
||||
|
||||
NetOut PointSolver::evaluate(
|
||||
SolverContextBase& solver_ctx,
|
||||
const NetIn &netIn,
|
||||
bool displayTrigger,
|
||||
bool forceReinitialize
|
||||
) {
|
||||
) const {
|
||||
auto* sctx_p = dynamic_cast<PointSolverContext*>(&solver_ctx);
|
||||
|
||||
LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density);
|
||||
LOG_TRACE_L1(m_logger, "Building engine update trigger....");
|
||||
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2);
|
||||
@@ -117,23 +193,24 @@ namespace gridfire::solver {
|
||||
// 2. If the user has set tolerances in code, those override the config
|
||||
// 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults
|
||||
|
||||
auto absTol = m_config->solver.cvode.absTol;
|
||||
auto relTol = m_config->solver.cvode.relTol;
|
||||
|
||||
if (m_absTol) {
|
||||
absTol = *m_absTol;
|
||||
if (!sctx_p->abs_tol.has_value()) {
|
||||
sctx_p->abs_tol = m_config->solver.cvode.absTol;
|
||||
}
|
||||
if (m_relTol) {
|
||||
relTol = *m_relTol;
|
||||
if (!sctx_p->rel_tol.has_value()) {
|
||||
sctx_p->rel_tol = m_config->solver.cvode.relTol;
|
||||
}
|
||||
|
||||
bool resourcesExist = (m_cvode_mem != nullptr) && (m_Y != nullptr);
|
||||
|
||||
bool inconsistentComposition = netIn.composition.hash() != m_last_composition_hash;
|
||||
bool resourcesExist = (sctx_p->cvode_mem != nullptr) && (sctx_p->Y != nullptr);
|
||||
|
||||
bool inconsistentComposition = netIn.composition.hash() != sctx_p->last_composition_hash;
|
||||
fourdst::composition::Composition equilibratedComposition;
|
||||
|
||||
if (forceReinitialize || !resourcesExist || inconsistentComposition) {
|
||||
cleanup_cvode_resources(true);
|
||||
sctx_p->reset_cvode();
|
||||
if (!sctx_p->has_context()) {
|
||||
sctx_p->init_context();
|
||||
}
|
||||
LOG_INFO(
|
||||
m_logger,
|
||||
"Preforming full CVODE initialization (Reason: {})",
|
||||
@@ -141,26 +218,24 @@ namespace gridfire::solver {
|
||||
(!resourcesExist ? "CVODE resources do not exist" :
|
||||
"Input composition inconsistent with previous state"));
|
||||
LOG_TRACE_L1(m_logger, "Starting engine update chain...");
|
||||
equilibratedComposition = m_engine.project(*m_scratch_blob, netIn);
|
||||
equilibratedComposition = m_engine.project(*sctx_p->engine_ctx, netIn);
|
||||
LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!");
|
||||
|
||||
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
||||
size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||
uint64_t N = numSpecies + 1;
|
||||
|
||||
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
|
||||
LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
|
||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
||||
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0);
|
||||
m_last_size = N;
|
||||
initialize_cvode_integration_resources(sctx_p, N, numSpecies, 0.0, equilibratedComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), 0.0);
|
||||
sctx_p->last_size = N;
|
||||
} else {
|
||||
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", m_last_size);
|
||||
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", sctx_p->last_size);
|
||||
|
||||
const size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
const size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||
sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||
for (size_t i = 0; i < numSpecies; i++) {
|
||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
||||
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||
if (netIn.composition.contains(species)) {
|
||||
y_data[i] = netIn.composition.getMolarAbundance(species);
|
||||
} else {
|
||||
@@ -168,16 +243,17 @@ namespace gridfire::solver {
|
||||
}
|
||||
}
|
||||
y_data[numSpecies] = 0.0; // Reset energy accumulator
|
||||
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit");
|
||||
utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, sctx_p->rel_tol.value(), sctx_p->abs_tol.value()), "CVodeSStolerances");
|
||||
utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, 0.0, sctx_p->Y), "CVodeReInit");
|
||||
|
||||
equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state
|
||||
}
|
||||
|
||||
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
||||
size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||
CVODEUserData user_data {
|
||||
.solver_instance = this,
|
||||
.ctx = *m_scratch_blob,
|
||||
.sctx = sctx_p,
|
||||
.ctx = *sctx_p->engine_ctx,
|
||||
.engine = &m_engine,
|
||||
};
|
||||
LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!");
|
||||
@@ -185,7 +261,7 @@ namespace gridfire::solver {
|
||||
double current_time = 0;
|
||||
// ReSharper disable once CppTooWideScope
|
||||
[[maybe_unused]] double last_callback_time = 0;
|
||||
m_num_steps = 0;
|
||||
sctx_p->num_steps = 0;
|
||||
double accumulated_energy = 0.0;
|
||||
|
||||
double accumulated_neutrino_energy_loss = 0.0;
|
||||
@@ -205,13 +281,13 @@ namespace gridfire::solver {
|
||||
while (current_time < netIn.tMax) {
|
||||
user_data.T9 = T9;
|
||||
user_data.rho = netIn.density;
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies(*m_scratch_blob);
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies(*sctx_p->engine_ctx);
|
||||
user_data.captured_exception.reset();
|
||||
|
||||
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
utils::check_cvode_flag(CVodeSetUserData(sctx_p->cvode_mem, &user_data), "CVodeSetUserData");
|
||||
|
||||
LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
|
||||
int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
||||
int flag = CVode(sctx_p->cvode_mem, netIn.tMax, sctx_p->Y, ¤t_time, CV_ONE_STEP);
|
||||
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
|
||||
|
||||
if (user_data.captured_exception){
|
||||
@@ -223,13 +299,13 @@ namespace gridfire::solver {
|
||||
|
||||
long int n_steps;
|
||||
double last_step_size;
|
||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
||||
CVodeGetNumSteps(sctx_p->cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(sctx_p->cvode_mem, &last_step_size);
|
||||
long int nliters, nlcfails;
|
||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
||||
CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nlcfails);
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||
|
||||
// TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it
|
||||
@@ -238,7 +314,7 @@ namespace gridfire::solver {
|
||||
|
||||
size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations;
|
||||
size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures;
|
||||
if (m_stdout_logging_enabled) {
|
||||
if (sctx_p->stdout_logging) {
|
||||
std::println(
|
||||
"Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})",
|
||||
total_steps + n_steps,
|
||||
@@ -253,20 +329,16 @@ namespace gridfire::solver {
|
||||
);
|
||||
}
|
||||
for (size_t i = 0; i < numSpecies; ++i) {
|
||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
||||
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||
if (y_data[i] > 0.0) {
|
||||
postStep.setMolarAbundance(species, y_data[i]);
|
||||
}
|
||||
}
|
||||
// fourdst::composition::Composition collectedComposition = m_engine.collectComposition(postStep, netIn.temperature/1e9, netIn.density);
|
||||
// for (size_t i = 0; i < numSpecies; ++i) {
|
||||
// y_data[i] = collectedComposition.getMolarAbundance(m_engine.getNetworkSpecies()[i]);
|
||||
// }
|
||||
LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy);
|
||||
LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string {
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < numSpecies; ++i) {
|
||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
||||
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||
ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")";
|
||||
if (i < numSpecies - 1) {
|
||||
ss << ", ";
|
||||
@@ -282,36 +354,44 @@ namespace gridfire::solver {
|
||||
? user_data.reaction_contribution_map.value()
|
||||
: kEmptyMap;
|
||||
|
||||
auto ctx = TimestepContext(
|
||||
auto ctx = PointSolverTimestepContext(
|
||||
current_time,
|
||||
m_Y,
|
||||
sctx_p->Y,
|
||||
last_step_size,
|
||||
last_callback_time,
|
||||
T9,
|
||||
netIn.density,
|
||||
n_steps,
|
||||
m_engine,
|
||||
m_engine.getNetworkSpecies(*m_scratch_blob),
|
||||
m_engine.getNetworkSpecies(*sctx_p->engine_ctx),
|
||||
convFail_diff,
|
||||
iter_diff,
|
||||
rcMap,
|
||||
*m_scratch_blob
|
||||
*sctx_p->engine_ctx
|
||||
);
|
||||
|
||||
prev_nonlinear_iterations = nliters + total_nonlinear_iterations;
|
||||
prev_convergence_failures = nlcfails + total_convergence_failures;
|
||||
|
||||
if (m_callback.has_value()) {
|
||||
m_callback.value()(ctx);
|
||||
if (sctx_p->callback.has_value()) {
|
||||
sctx_p->callback.value()(ctx);
|
||||
}
|
||||
trigger->step(ctx);
|
||||
|
||||
if (m_detailed_step_logging) {
|
||||
log_step_diagnostics(*m_scratch_blob, user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json");
|
||||
if (sctx_p->detailed_step_logging) {
|
||||
log_step_diagnostics(
|
||||
sctx_p,
|
||||
*sctx_p->engine_ctx,
|
||||
user_data,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
"step_" + std::to_string(total_steps + n_steps) + ".json"
|
||||
);
|
||||
}
|
||||
|
||||
if (trigger->check(ctx)) {
|
||||
if (m_stdout_logging_enabled && displayTrigger) {
|
||||
if (sctx_p->stdout_logging && displayTrigger) {
|
||||
trigger::printWhy(trigger->why(ctx));
|
||||
}
|
||||
trigger->update(ctx);
|
||||
@@ -333,20 +413,20 @@ namespace gridfire::solver {
|
||||
|
||||
fourdst::composition::Composition temp_comp;
|
||||
std::vector<double> mass_fractions;
|
||||
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*m_scratch_blob).size());
|
||||
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size());
|
||||
|
||||
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
|
||||
if (num_species_at_stop > sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1) {
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
|
||||
num_species_at_stop,
|
||||
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index
|
||||
sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1 // -1 due to energy in the last index
|
||||
);
|
||||
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
|
||||
}
|
||||
|
||||
for (const auto& species: m_engine.getNetworkSpecies(*m_scratch_blob)) {
|
||||
const size_t sid = m_engine.getSpeciesIndex(*m_scratch_blob, species);
|
||||
for (const auto& species: m_engine.getNetworkSpecies(*sctx_p->engine_ctx)) {
|
||||
const size_t sid = m_engine.getSpeciesIndex(*sctx_p->engine_ctx, species);
|
||||
temp_comp.registerSpecies(species);
|
||||
double y = end_of_step_abundances[sid];
|
||||
if (y > 0.0) {
|
||||
@@ -356,7 +436,7 @@ namespace gridfire::solver {
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (long int i = 0; i < num_species_at_stop; ++i) {
|
||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
||||
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||
if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) {
|
||||
throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification.");
|
||||
}
|
||||
@@ -391,7 +471,7 @@ namespace gridfire::solver {
|
||||
"Prior to Engine Update active reactions are: {}",
|
||||
[&]() -> std::string {
|
||||
std::stringstream ss;
|
||||
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob);
|
||||
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
|
||||
size_t count = 0;
|
||||
for (const auto& reaction : reactions) {
|
||||
ss << reaction -> id();
|
||||
@@ -403,7 +483,7 @@ namespace gridfire::solver {
|
||||
return ss.str();
|
||||
}()
|
||||
);
|
||||
fourdst::composition::Composition currentComposition = m_engine.project(*m_scratch_blob, netInTemp);
|
||||
fourdst::composition::Composition currentComposition = m_engine.project(*sctx_p->engine_ctx, netInTemp);
|
||||
LOG_DEBUG(
|
||||
m_logger,
|
||||
"After to Engine update composition is (molar abundance) {}",
|
||||
@@ -450,7 +530,7 @@ namespace gridfire::solver {
|
||||
"After Engine Update active reactions are: {}",
|
||||
[&]() -> std::string {
|
||||
std::stringstream ss;
|
||||
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob);
|
||||
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
|
||||
size_t count = 0;
|
||||
for (const auto& reaction : reactions) {
|
||||
ss << reaction -> id();
|
||||
@@ -466,34 +546,29 @@ namespace gridfire::solver {
|
||||
m_logger,
|
||||
"Due to a triggered engine update the composition was updated from size {} to {} species.",
|
||||
num_species_at_stop,
|
||||
m_engine.getNetworkSpecies(*m_scratch_blob).size()
|
||||
m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size()
|
||||
);
|
||||
|
||||
numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
||||
numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||
size_t N = numSpecies + 1;
|
||||
|
||||
LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update...");
|
||||
cleanup_cvode_resources(true);
|
||||
sctx_p->reset_cvode();
|
||||
initialize_cvode_integration_resources(sctx_p, N, numSpecies, current_time, currentComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), accumulated_energy);
|
||||
|
||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
||||
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
|
||||
|
||||
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
||||
// throw exceptions::DebugException("Debug");
|
||||
utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, current_time, sctx_p->Y), "CVodeReInit");
|
||||
LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization...");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (m_stdout_logging_enabled) { // Flush the buffer if standard out logging is enabled
|
||||
if (sctx_p->stdout_logging) { // Flush the buffer if standard out logging is enabled
|
||||
std::cout << std::flush;
|
||||
}
|
||||
|
||||
LOG_INFO(m_logger, "CVODE iteration complete");
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||
accumulated_energy += y_data[numSpecies];
|
||||
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
||||
|
||||
@@ -505,7 +580,7 @@ namespace gridfire::solver {
|
||||
|
||||
LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies);
|
||||
|
||||
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec);
|
||||
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*sctx_p->engine_ctx), y_vec);
|
||||
LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string {
|
||||
std::ostringstream ss;
|
||||
size_t i = 0;
|
||||
@@ -520,7 +595,7 @@ namespace gridfire::solver {
|
||||
}());
|
||||
|
||||
LOG_INFO(m_logger, "Collecting final composition...");
|
||||
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*m_scratch_blob, topLevelComposition, netIn.temperature/1e9, netIn.density);
|
||||
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*sctx_p->engine_ctx, topLevelComposition, netIn.temperature/1e9, netIn.density);
|
||||
|
||||
assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size());
|
||||
|
||||
@@ -541,11 +616,11 @@ namespace gridfire::solver {
|
||||
NetOut netOut;
|
||||
netOut.composition = outputComposition;
|
||||
netOut.energy = accumulated_energy;
|
||||
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
||||
utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
||||
|
||||
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
|
||||
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
|
||||
*m_scratch_blob,
|
||||
*sctx_p->engine_ctx,
|
||||
outputComposition,
|
||||
T9,
|
||||
netIn.density
|
||||
@@ -559,53 +634,13 @@ namespace gridfire::solver {
|
||||
LOG_TRACE_L2(m_logger, "Output data built!");
|
||||
LOG_TRACE_L2(m_logger, "Solver evaluation complete!.");
|
||||
|
||||
m_last_composition_hash = netOut.composition.hash();
|
||||
m_last_size = netOut.composition.size() + 1;
|
||||
CVodeGetLastStep(m_cvode_mem, &m_last_good_time_step);
|
||||
sctx_p->last_composition_hash = netOut.composition.hash();
|
||||
sctx_p->last_size = netOut.composition.size() + 1;
|
||||
CVodeGetLastStep(sctx_p->cvode_mem, &sctx_p->last_good_time_step);
|
||||
return netOut;
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::set_callback(const std::any &callback) {
|
||||
m_callback = std::any_cast<TimestepCallback>(callback);
|
||||
}
|
||||
|
||||
bool CVODESolverStrategy::get_stdout_logging_enabled() const {
|
||||
return m_stdout_logging_enabled;
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::set_stdout_logging_enabled(const bool logging_enabled) {
|
||||
m_stdout_logging_enabled = logging_enabled;
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::set_absTol(double absTol) {
|
||||
m_absTol = absTol;
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::set_relTol(double relTol) {
|
||||
m_relTol = relTol;
|
||||
}
|
||||
|
||||
double CVODESolverStrategy::get_absTol() const {
|
||||
if (m_absTol.has_value()) {
|
||||
return m_absTol.value();
|
||||
} else {
|
||||
return -1.0;
|
||||
}
|
||||
}
|
||||
|
||||
double CVODESolverStrategy::get_relTol() const {
|
||||
if (m_relTol.has_value()) {
|
||||
return m_relTol.value();
|
||||
} else {
|
||||
return -1.0;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::describe_callback_context() const {
|
||||
return {};
|
||||
}
|
||||
|
||||
int CVODESolverStrategy::cvode_rhs_wrapper(
|
||||
int PointSolver::cvode_rhs_wrapper(
|
||||
const sunrealtype t,
|
||||
const N_Vector y,
|
||||
const N_Vector ydot,
|
||||
@@ -633,7 +668,7 @@ namespace gridfire::solver {
|
||||
}
|
||||
}
|
||||
|
||||
int CVODESolverStrategy::cvode_jac_wrapper(
|
||||
int PointSolver::cvode_jac_wrapper(
|
||||
sunrealtype t,
|
||||
N_Vector y,
|
||||
N_Vector ydot,
|
||||
@@ -754,7 +789,7 @@ namespace gridfire::solver {
|
||||
return 0;
|
||||
}
|
||||
|
||||
CVODESolverStrategy::CVODERHSOutputData CVODESolverStrategy::calculate_rhs(
|
||||
PointSolver::CVODERHSOutputData PointSolver::calculate_rhs(
|
||||
const sunrealtype t,
|
||||
N_Vector y,
|
||||
N_Vector ydot,
|
||||
@@ -772,10 +807,10 @@ namespace gridfire::solver {
|
||||
}
|
||||
}
|
||||
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
||||
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec);
|
||||
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(data->ctx), y_vec);
|
||||
|
||||
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
|
||||
const auto result = m_engine.calculateRHSAndEnergy(*m_scratch_blob, composition, data->T9, data->rho, false);
|
||||
const auto result = m_engine.calculateRHSAndEnergy(data->ctx, composition, data->T9, data->rho, false);
|
||||
if (!result) {
|
||||
LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()));
|
||||
throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())));
|
||||
@@ -805,7 +840,7 @@ namespace gridfire::solver {
|
||||
}());
|
||||
|
||||
for (size_t i = 0; i < numSpecies; ++i) {
|
||||
fourdst::atomic::Species species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
||||
fourdst::atomic::Species species = m_engine.getNetworkSpecies(data->ctx)[i];
|
||||
ydot_data[i] = dydt.at(species);
|
||||
}
|
||||
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
|
||||
@@ -813,7 +848,8 @@ namespace gridfire::solver {
|
||||
return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux};
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::initialize_cvode_integration_resources(
|
||||
void PointSolver::initialize_cvode_integration_resources(
|
||||
PointSolverContext* sctx_p,
|
||||
const uint64_t N,
|
||||
const size_t numSpecies,
|
||||
const double current_time,
|
||||
@@ -821,16 +857,18 @@ namespace gridfire::solver {
|
||||
const double absTol,
|
||||
const double relTol,
|
||||
const double accumulatedEnergy
|
||||
) {
|
||||
) const {
|
||||
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
|
||||
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones
|
||||
sctx_p->reset_cvode();
|
||||
|
||||
m_Y = utils::init_sun_vector(N, m_sun_ctx);
|
||||
m_YErr = N_VClone(m_Y);
|
||||
sctx_p->cvode_mem = CVodeCreate(CV_BDF, sctx_p->sun_ctx);
|
||||
utils::check_cvode_flag(sctx_p->cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
sctx_p->Y = utils::init_sun_vector(N, sctx_p->sun_ctx);
|
||||
sctx_p->YErr = N_VClone(sctx_p->Y);
|
||||
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||
for (size_t i = 0; i < numSpecies; i++) {
|
||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
||||
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||
if (composition.contains(species)) {
|
||||
y_data[i] = composition.getMolarAbundance(species);
|
||||
} else {
|
||||
@@ -840,8 +878,8 @@ namespace gridfire::solver {
|
||||
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
|
||||
|
||||
|
||||
utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
|
||||
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||
utils::check_cvode_flag(CVodeInit(sctx_p->cvode_mem, cvode_rhs_wrapper, current_time, sctx_p->Y), "CVodeInit");
|
||||
utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||
|
||||
// Constraints
|
||||
// We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual
|
||||
@@ -854,53 +892,30 @@ namespace gridfire::solver {
|
||||
// -2.0: The corresponding component of y is constrained to be < 0
|
||||
// Here we use 1.0 for all species to ensure they remain non-negative.
|
||||
|
||||
m_constraints = N_VClone(m_Y);
|
||||
if (m_constraints == nullptr) {
|
||||
sctx_p->constraints = N_VClone(sctx_p->Y);
|
||||
if (sctx_p->constraints == nullptr) {
|
||||
LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE");
|
||||
throw std::runtime_error("Failed to create constraints vector for CVODE");
|
||||
}
|
||||
N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
|
||||
N_VConst(1.0, sctx_p->constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
|
||||
|
||||
utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints");
|
||||
utils::check_cvode_flag(CVodeSetConstraints(sctx_p->cvode_mem, sctx_p->constraints), "CVodeSetConstraints");
|
||||
|
||||
utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
||||
utils::check_cvode_flag(CVodeSetMaxStep(sctx_p->cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
||||
|
||||
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx);
|
||||
utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
||||
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
|
||||
utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
|
||||
sctx_p->J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), sctx_p->sun_ctx);
|
||||
utils::check_cvode_flag(sctx_p->J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
||||
sctx_p->LS = SUNLinSol_Dense(sctx_p->Y, sctx_p->J, sctx_p->sun_ctx);
|
||||
utils::check_cvode_flag(sctx_p->LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
|
||||
|
||||
utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver");
|
||||
utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
|
||||
utils::check_cvode_flag(CVodeSetLinearSolver(sctx_p->cvode_mem, sctx_p->LS, sctx_p->J), "CVodeSetLinearSolver");
|
||||
utils::check_cvode_flag(CVodeSetJacFn(sctx_p->cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
|
||||
LOG_TRACE_L2(m_logger, "CVODE solver initialized");
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::cleanup_cvode_resources(const bool memFree) {
|
||||
LOG_TRACE_L2(m_logger, "Cleaning up cvode resources");
|
||||
if (m_LS) SUNLinSolFree(m_LS);
|
||||
if (m_J) SUNMatDestroy(m_J);
|
||||
if (m_Y) N_VDestroy(m_Y);
|
||||
if (m_YErr) N_VDestroy(m_YErr);
|
||||
if (m_constraints) N_VDestroy(m_constraints);
|
||||
|
||||
m_LS = nullptr;
|
||||
m_J = nullptr;
|
||||
m_Y = nullptr;
|
||||
m_YErr = nullptr;
|
||||
m_constraints = nullptr;
|
||||
|
||||
if (memFree) {
|
||||
if (m_cvode_mem) CVodeFree(&m_cvode_mem);
|
||||
m_cvode_mem = nullptr;
|
||||
}
|
||||
LOG_TRACE_L2(m_logger, "Done Cleaning up cvode resources");
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::set_detailed_step_logging(const bool enabled) {
|
||||
m_detailed_step_logging = enabled;
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::log_step_diagnostics(
|
||||
void PointSolver::log_step_diagnostics(
|
||||
PointSolverContext* sctx_p,
|
||||
scratch::StateBlob &ctx,
|
||||
const CVODEUserData &user_data,
|
||||
bool displayJacobianStiffness,
|
||||
@@ -916,10 +931,10 @@ namespace gridfire::solver {
|
||||
sunrealtype hlast, hcur, tcur;
|
||||
int qlast;
|
||||
|
||||
utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep");
|
||||
utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep");
|
||||
utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder");
|
||||
utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime");
|
||||
utils::check_cvode_flag(CVodeGetLastStep(sctx_p->cvode_mem, &hlast), "CVodeGetLastStep");
|
||||
utils::check_cvode_flag(CVodeGetCurrentStep(sctx_p->cvode_mem, &hcur), "CVodeGetCurrentStep");
|
||||
utils::check_cvode_flag(CVodeGetLastOrder(sctx_p->cvode_mem, &qlast), "CVodeGetLastOrder");
|
||||
utils::check_cvode_flag(CVodeGetCurrentTime(sctx_p->cvode_mem, &tcur), "CVodeGetCurrentTime");
|
||||
|
||||
nlohmann::json j;
|
||||
{
|
||||
@@ -941,13 +956,13 @@ namespace gridfire::solver {
|
||||
// These are the CRITICAL counters for diagnosing your problem
|
||||
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
|
||||
|
||||
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps");
|
||||
utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
|
||||
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
|
||||
utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails");
|
||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
|
||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
|
||||
utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
|
||||
utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, &nsteps), "CVodeGetNumSteps");
|
||||
utils::check_cvode_flag(CVodeGetNumRhsEvals(sctx_p->cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
|
||||
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(sctx_p->cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
|
||||
utils::check_cvode_flag(CVodeGetNumErrTestFails(sctx_p->cvode_mem, &netfails), "CVodeGetNumErrTestFails");
|
||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
|
||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
|
||||
utils::check_cvode_flag(CVodeGetNumLinConvFails(sctx_p->cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
|
||||
|
||||
|
||||
{
|
||||
@@ -975,22 +990,26 @@ namespace gridfire::solver {
|
||||
}
|
||||
|
||||
// --- 3. Get Estimated Local Errors (Your Original Logic) ---
|
||||
utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
|
||||
utils::check_cvode_flag(CVodeGetEstLocalErrors(sctx_p->cvode_mem, sctx_p->YErr), "CVodeGetEstLocalErrors");
|
||||
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr);
|
||||
|
||||
const auto absTol = m_config->solver.cvode.absTol;
|
||||
const auto relTol = m_config->solver.cvode.relTol;
|
||||
sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||
sunrealtype *y_err_data = N_VGetArrayPointer(sctx_p->YErr);
|
||||
|
||||
std::vector<double> err_ratios;
|
||||
const size_t num_components = N_VGetLength(m_Y);
|
||||
const size_t num_components = N_VGetLength(sctx_p->Y);
|
||||
err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar
|
||||
|
||||
std::vector<double> Y_full(y_data, y_data + num_components - 1);
|
||||
std::vector<double> E_full(y_err_data, y_err_data + num_components - 1);
|
||||
|
||||
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file);
|
||||
if (!sctx_p->abs_tol.has_value()) {
|
||||
sctx_p->abs_tol = m_config->solver.cvode.absTol;
|
||||
}
|
||||
if (!sctx_p->rel_tol.has_value()) {
|
||||
sctx_p->rel_tol = m_config->solver.cvode.relTol;
|
||||
}
|
||||
|
||||
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, sctx_p->rel_tol.value(), sctx_p->abs_tol.value(), 10, to_file);
|
||||
if (to_file && result.has_value()) {
|
||||
j["Limiting_Species"] = result.value();
|
||||
}
|
||||
@@ -1003,8 +1022,9 @@ namespace gridfire::solver {
|
||||
0.0
|
||||
);
|
||||
|
||||
|
||||
for (size_t i = 0; i < num_components - 1; i++) {
|
||||
const double weight = relTol * std::abs(y_data[i]) + absTol;
|
||||
const double weight = sctx_p->rel_tol.value() * std::abs(y_data[i]) + sctx_p->abs_tol.value();
|
||||
if (weight == 0.0) {
|
||||
err_ratios[i] = 0.0; // Avoid division by zero
|
||||
continue;
|
||||
@@ -1013,11 +1033,11 @@ namespace gridfire::solver {
|
||||
err_ratios[i] = err_ratio;
|
||||
}
|
||||
|
||||
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*m_scratch_blob), Y_full);
|
||||
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*m_scratch_blob, composition, user_data.T9, user_data.rho);
|
||||
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*sctx_p->engine_ctx), Y_full);
|
||||
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*sctx_p->engine_ctx, composition, user_data.T9, user_data.rho);
|
||||
|
||||
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho);
|
||||
auto netTimescales = user_data.engine->getSpeciesTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho);
|
||||
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
|
||||
auto netTimescales = user_data.engine->getSpeciesTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
|
||||
|
||||
bool timescaleOkay = false;
|
||||
if (destructionTimescales && netTimescales) timescaleOkay = true;
|
||||
@@ -1037,7 +1057,7 @@ namespace gridfire::solver {
|
||||
if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp));
|
||||
else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity());
|
||||
|
||||
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*m_scratch_blob, sp)));
|
||||
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*sctx_p->engine_ctx, sp)));
|
||||
}
|
||||
|
||||
utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
#include "gridfire/solver/strategies/PointSolver.h"
|
||||
|
||||
#include "gridfire/trigger/trigger_logical.h"
|
||||
#include "gridfire/trigger/trigger_abstract.h"
|
||||
@@ -28,7 +28,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
}
|
||||
|
||||
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
bool SimulationTimeTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||
if (ctx.t - m_last_trigger_time >= m_interval) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
|
||||
@@ -38,7 +38,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
return false;
|
||||
}
|
||||
|
||||
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
void SimulationTimeTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
|
||||
if (check(ctx)) {
|
||||
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
|
||||
m_last_trigger_time = ctx.t;
|
||||
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
|
||||
void SimulationTimeTrigger::step(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) {
|
||||
// --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- //
|
||||
}
|
||||
@@ -65,7 +65,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
return "Simulation Time Trigger";
|
||||
}
|
||||
|
||||
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
if (check(ctx)) {
|
||||
@@ -99,18 +99,18 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
}
|
||||
|
||||
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
bool OffDiagonalTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||
//TODO : This currently does nothing
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
void OffDiagonalTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
|
||||
m_updates++;
|
||||
}
|
||||
|
||||
void OffDiagonalTrigger::step(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) {
|
||||
// --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- //
|
||||
}
|
||||
@@ -126,7 +126,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
return "Off-Diagonal Trigger";
|
||||
}
|
||||
|
||||
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
|
||||
@@ -173,7 +173,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
}
|
||||
|
||||
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
bool TimestepCollapseTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||
if (m_timestep_window.size() < m_windowSize) {
|
||||
m_misses++;
|
||||
return false;
|
||||
@@ -201,13 +201,13 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
return false;
|
||||
}
|
||||
|
||||
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
void TimestepCollapseTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
|
||||
m_updates++;
|
||||
m_timestep_window.clear();
|
||||
}
|
||||
|
||||
void TimestepCollapseTrigger::step(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) {
|
||||
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
|
||||
// --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- //
|
||||
@@ -226,7 +226,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
|
||||
TriggerResult TimestepCollapseTrigger::why(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
@@ -263,7 +263,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
m_windowSize(windowSize) {}
|
||||
|
||||
bool ConvergenceFailureTrigger::check(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) const {
|
||||
if (m_window.size() != m_windowSize) {
|
||||
m_misses++;
|
||||
@@ -278,13 +278,13 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
|
||||
void ConvergenceFailureTrigger::update(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) {
|
||||
m_window.clear();
|
||||
}
|
||||
|
||||
void ConvergenceFailureTrigger::step(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) {
|
||||
push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize);
|
||||
m_updates++;
|
||||
@@ -306,7 +306,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")";
|
||||
}
|
||||
|
||||
TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
|
||||
@@ -348,7 +348,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
|
||||
bool ConvergenceFailureTrigger::abs_failure(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) const {
|
||||
if (ctx.currentConvergenceFailures > m_totalFailures) {
|
||||
return true;
|
||||
@@ -357,7 +357,7 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
}
|
||||
|
||||
bool ConvergenceFailureTrigger::rel_failure(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||
) const {
|
||||
const float mean = current_mean();
|
||||
if (mean < 10) {
|
||||
@@ -369,13 +369,13 @@ namespace gridfire::trigger::solver::CVODE {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
|
||||
std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
|
||||
const double simulationTimeInterval,
|
||||
const double offDiagonalThreshold,
|
||||
const double timestepCollapseRatio,
|
||||
const size_t maxConvergenceFailures
|
||||
) {
|
||||
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext;
|
||||
using ctx_t = gridfire::solver::PointSolverTimestepContext;
|
||||
|
||||
// 1. INSTABILITY TRIGGERS (High Priority)
|
||||
auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>(
|
||||
|
||||
Reference in New Issue
Block a user