feat(KINSOL): Switch from Eigen to KINSOL

Previously QSE solving was done using Eigen. While this worked we were
limited in the ability to use previous iterations to speed up later
steps. We have switched to KINSOL, from SUNDIALS, for linear solving.
This has drastically speed up the process of solving for QSE abundances,
primarily because the jacobian matrix does not need to be generated
every single time time a QSE abundance is requested.
This commit is contained in:
2025-11-19 12:06:21 -05:00
parent f7fbc6c1da
commit 442d4ed86c
12 changed files with 506 additions and 386 deletions

View File

@@ -1,6 +1,7 @@
#include "gridfire/engine/views/engine_multiscale.h"
#include "gridfire/exceptions/error_engine.h"
#include "gridfire/engine/procedures/priming.h"
#include "gridfire/utils/sundials.h"
#include <stdexcept>
#include <vector>
@@ -17,12 +18,15 @@
#include "quill/LogMacros.h"
#include "quill/Logger.h"
#include "kinsol/kinsol.h"
#include "sundials/sundials_context.h"
#include "sunmatrix/sunmatrix_dense.h"
#include "sunlinsol/sunlinsol_dense.h"
#include "xxhash64.h"
#include "fourdst/composition/utils/composition_hash.h"
namespace {
constexpr double MIN_ABS_NORM_ALWAYS_CONVERGED = 1e-30;
using namespace fourdst::atomic;
@@ -159,6 +163,18 @@ namespace {
{Eigen::LevenbergMarquardtSpace::Status::GtolTooSmall, "GtolTooSmall"},
{Eigen::LevenbergMarquardtSpace::Status::UserAsked, "UserAsked"}
};
void QuietErrorRouter(int line, const char *func, const char *file, const char *msg,
SUNErrCode err_code, void *err_user_data, SUNContext sunctx) {
// LIST OF ERRORS TO IGNORE
if (err_code == KIN_LINESEARCH_NONCONV) {
return;
}
// For everything else, use the default SUNDIALS logger (or your own)
SUNLogErrHandlerFn(line, func, file, msg, err_code, err_user_data, sunctx);
}
}
namespace gridfire {
@@ -166,7 +182,23 @@ namespace gridfire {
MultiscalePartitioningEngineView::MultiscalePartitioningEngineView(
DynamicEngine& baseEngine
) : m_baseEngine(baseEngine) {}
) : m_baseEngine(baseEngine) {
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
if (flag != 0) {
LOG_CRITICAL(m_logger, "Error while creating SUNContext in MultiscalePartitioningEngineView");
throw std::runtime_error("Error creating SUNContext in MultiscalePartitioningEngineView");
}
SUNContext_PushErrHandler(m_sun_ctx, QuietErrorRouter, nullptr);
}
MultiscalePartitioningEngineView::~MultiscalePartitioningEngineView() {
LOG_TRACE_L1(m_logger, "Cleaning up MultiscalePartitioningEngineView...");
m_qse_solvers.clear();
if (m_sun_ctx) {
SUNContext_Free(&m_sun_ctx);
m_sun_ctx = nullptr;
}
}
const std::vector<Species> & MultiscalePartitioningEngineView::getNetworkSpecies() const {
return m_baseEngine.getNetworkSpecies();
@@ -598,6 +630,7 @@ namespace gridfire {
LOG_TRACE_L1(m_logger, "Partitioning network...");
LOG_TRACE_L1(m_logger, "Clearing previous state...");
m_qse_groups.clear();
m_qse_solvers.clear();
m_dynamic_species.clear();
m_algebraic_species.clear();
m_composition_cache.clear(); // We need to clear the cache now cause the same comp, temp, and density may result in a different value
@@ -790,7 +823,18 @@ namespace gridfire {
m_dynamic_species.push_back(species);
}
}
return getNormalizedEquilibratedComposition(comp, T9, rho);
for (const auto& group : m_qse_groups) {
std::vector<Species> groupAlgebraicSpecies;
for (const auto& species : group.algebraic_species) {
groupAlgebraicSpecies.push_back(species);
}
m_qse_solvers.push_back(std::make_unique<QSESolver>(groupAlgebraicSpecies, m_baseEngine, m_sun_ctx));
}
fourdst::composition::Composition result = getNormalizedEquilibratedComposition(comp, T9, rho);
return result;
}
void MultiscalePartitioningEngineView::exportToDot(
@@ -1502,98 +1546,18 @@ namespace gridfire {
const double T9,
const double rho
) const {
LOG_TRACE_L1(m_logger, "Solving for QSE abundances...");
LOG_TRACE_L2(m_logger, "Composition before QSE solving: {}", [&comp]() -> std::string {
std::stringstream ss;
size_t i = 0;
for (const auto& [sp, y] : comp) {
ss << std::format("{}: {}", sp.name(), y);
if (i < comp.size() - 1) {
ss << ", ";
}
i++;
}
return ss.str();
}());
LOG_TRACE_L2(m_logger, "Solving for QSE abundances...");
fourdst::composition::Composition outputComposition(comp);
for (const auto&[is_in_equilibrium, algebraic_species, seed_species, mean_timescale] : m_qse_groups) {
LOG_TRACE_L2(m_logger, "Working on QSE group with algebraic species: {}",
[&]() -> std::string {
std::stringstream ss;
size_t count = 0;
for (const auto& species: algebraic_species) {
ss << species.name();
if (count < algebraic_species.size() - 1) {
ss << ", ";
}
count++;
}
return ss.str();
}());
if (!is_in_equilibrium || (algebraic_species.empty() && seed_species.empty())) {
continue;
}
Eigen::VectorXd Y_scale(algebraic_species.size());
Eigen::VectorXd v_initial(algebraic_species.size());
long i = 0;
std::unordered_map<Species, size_t> species_to_index_map;
for (const auto& species : algebraic_species) {
constexpr double abundance_floor = 1.0e-100;
const double initial_abundance = comp.getMolarAbundance(species);
const double Y = std::max(initial_abundance, abundance_floor);
v_initial(i) = std::log(Y);
species_to_index_map.emplace(species, i);
LOG_TRACE_L2(m_logger, "For species {} initial molar abundance is {}, log scaled to {}. Species placed at index {}.", species.name(), Y, v_initial(i), i);
i++;
}
LOG_TRACE_L2(m_logger, "Setting up Eigen Levenberg-Marquardt solver for QSE group...");
EigenFunctor functor(*this, algebraic_species, comp, T9, rho, Y_scale, species_to_index_map);
Eigen::LevenbergMarquardt lm(functor);
lm.parameters.ftol = 1.0e-10;
lm.parameters.xtol = 1.0e-10;
LOG_TRACE_L2(m_logger, "Minimizing functor...");
Eigen::LevenbergMarquardtSpace::Status status = lm.minimize(v_initial);
LOG_TRACE_L2(m_logger, "Minimizing functor status: {}", lm_status_map.at(status));
if (status <= 0 || status > 4) {
std::stringstream msg;
msg << "While working on QSE group with algebraic species: ";
size_t count = 0;
for (const auto& species: algebraic_species) {
msg << species;
if (count < algebraic_species.size() - 1) {
msg << ", ";
}
count++;
}
msg << " the QSE solver failed to converge with status: " << lm_status_map.at(status) << " (No. " << status << ").";
LOG_ERROR(m_logger, "{}", msg.str());
throw std::runtime_error(msg.str());
}
LOG_TRACE_L1(m_logger, "QSE Group minimization succeeded with status: {}", lm_status_map.at(status));
Eigen::VectorXd Y_final_qse = v_initial.array().exp(); // Convert back to physical abundances using exponential scaling
i = 0;
for (const auto& species: algebraic_species) {
LOG_TRACE_L1(
m_logger,
"During QSE solving species {} started with a molar abundance of {} and ended with an abundance of {}.",
species.name(),
comp.getMolarAbundance(species),
Y_final_qse(i)
);
// double Xi = Y_final_qse(i) * species.mass(); // Convert from molar abundance to mass fraction
if (!outputComposition.contains (species)) {
outputComposition.registerSpecies(species);
}
outputComposition.setMolarAbundance(species, Y_final_qse(i));
i++;
for (const auto& [group, solver]: std::views::zip(m_qse_groups, m_qse_solvers)) {
const fourdst::composition::Composition groupResult = solver->solve(outputComposition, T9, rho);
for (const auto& [sp, y] : groupResult) {
outputComposition.setMolarAbundance(sp, y);
}
solver->log_diagnostics();
}
LOG_TRACE_L2(m_logger, "Done solving for QSE abundances!");
return outputComposition;
}
@@ -1795,121 +1759,249 @@ namespace gridfire {
return candidate_groups;
}
//////////////////////////////////////
/// Eigen Functor Member Functions ///
/////////////////////////////////////
//////////////////////////////////
/// QSESolver Member Functions ///
//////////////////////////////////
int MultiscalePartitioningEngineView::EigenFunctor::operator()(const InputType &v_qse, OutputType &f_qse) const {
fourdst::composition::Composition comp_trial(m_initial_comp.getRegisteredSpecies());
for (const auto& [sp, y] : m_initial_comp) {
comp_trial.setMolarAbundance(sp, y);
}
Eigen::VectorXd y_qse = v_qse.array().exp(); // Convert to physical abundances using exponential scaling
MultiscalePartitioningEngineView::QSESolver::QSESolver(
const std::vector<fourdst::atomic::Species>& species,
const DynamicEngine& engine,
const SUNContext sun_ctx
) :
m_N(species.size()),
m_engine(engine),
m_species(species),
m_sun_ctx(sun_ctx) {
m_Y = utils::init_sun_vector(m_N, m_sun_ctx);
m_scale = N_VClone(m_Y);
m_f_scale = N_VClone(m_Y);
m_constraints = N_VClone(m_Y);
m_func_tmpl = N_VClone(m_Y);
for (const auto& species: m_qse_solve_species) {
auto index = static_cast<long>(m_qse_solve_species_index_map.at(species));
comp_trial.setMolarAbundance(species, y_qse[index]);
if (!m_Y || !m_scale || !m_constraints || !m_func_tmpl) {
LOG_CRITICAL(getLogger(), "Failed to allocate SUNVectors for QSE solver.");
throw std::runtime_error("Failed to allocate SUNVectors for QSE solver.");
}
const auto result = m_view.getBaseEngine().calculateRHSAndEnergy(comp_trial, m_T9, m_rho);
if (!result) {
throw exceptions::StaleEngineError("Failed to calculate RHS and energy due to stale engine state");
for (size_t i = 0; i < m_N; ++i) {
m_speciesMap[m_species[i]] = i;
}
const auto&[dydt, nuclearEnergyGenerationRate, _] = result.value();
f_qse.resize(static_cast<long>(m_qse_solve_species.size()));
long i = 0;
// TODO: make sure that just counting up i is a valid approach, this is a possible place an indexing bug may have crept in
for (const auto& species: m_qse_solve_species) {
const double dydt_i = dydt.at(species);
f_qse(i) = dydt_i/y_qse(i); // We square the residuals to improve numerical stability in the solver
i++;
}
LOG_TRACE_L2(
m_view.m_logger,
"Functor evaluation at T9 = {}, rho = {}, y_qse (v_qse) = <{}> complete. ||f|| = {}",
m_T9,
m_rho,
[&]() -> std::string {
std::stringstream ss;
for (long j = 0; j < y_qse.size(); ++j) {
ss << y_qse(j);
ss << "(" << v_qse(j) << ")";
if (j < y_qse.size() - 1) {
ss << ", ";
}
}
return ss.str();
}(),
f_qse.norm()
);
LOG_TRACE_L3(
m_view.m_logger,
"{}",
[&]() -> std::string {
std::stringstream ss;
const std::vector species(m_qse_solve_species.begin(), m_qse_solve_species.end());
for (long j = 0; j < f_qse.size(); ++j) {
ss << "Residual for species " << species.at(j).name() << " f(" << j << ") = " << f_qse(j) << "\n";
}
return ss.str();
}()
);
return 0;
N_VConst(1.0, m_constraints);
m_kinsol_mem = KINCreate(m_sun_ctx);
utils::check_cvode_flag(m_kinsol_mem ? 0 : -1, "KINCreate");
utils::check_cvode_flag(KINInit(m_kinsol_mem, sys_func, m_func_tmpl), "KINInit");
utils::check_cvode_flag(KINSetConstraints(m_kinsol_mem, m_constraints), "KINSetConstraints");
m_J = SUNDenseMatrix(static_cast<sunindextype>(m_N), static_cast<sunindextype>(m_N), m_sun_ctx);
utils::check_cvode_flag(m_J ? 0 : -1, "SUNDenseMatrix");
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
utils::check_cvode_flag(m_LS ? 0 : -1, "SUNLinSol_Dense");
utils::check_cvode_flag(KINSetLinearSolver(m_kinsol_mem, m_LS, m_J), "KINSetLinearSolver");
utils::check_cvode_flag(KINSetJacFn(m_kinsol_mem, sys_jac), "KINSetJacFn");
utils::check_cvode_flag(KINSetMaxSetupCalls(m_kinsol_mem, 20), "KINSetMaxSetupCalls");
utils::check_cvode_flag(KINSetFuncNormTol(m_kinsol_mem, 1e-6), "KINSetFuncNormTol");
utils::check_cvode_flag(KINSetNumMaxIters(m_kinsol_mem, 200), "KINSetNumMaxIters");
// We want to effectively disable this since enormous changes in order of magnitude are realistic for this problem.
utils::check_cvode_flag(KINSetMaxNewtonStep(m_kinsol_mem, 200), "KINSetMaxNewtonStep");
}
int MultiscalePartitioningEngineView::EigenFunctor::df(const InputType &v_qse, JacobianType &J_qse) const {
fourdst::composition::Composition comp_trial(m_initial_comp.getRegisteredSpecies());
for (const auto& [sp, y] : m_initial_comp) {
comp_trial.setMolarAbundance(sp, y);
MultiscalePartitioningEngineView::QSESolver::~QSESolver() {
if (m_Y) {
N_VDestroy(m_Y);
m_Y = nullptr;
}
Eigen::VectorXd y_qse = v_qse.array().exp(); // Convert to physical abundances using exponential scaling
if (m_scale) {
N_VDestroy(m_scale);
m_scale = nullptr;
}
if (m_f_scale) {
N_VDestroy(m_f_scale);
m_f_scale = nullptr;
}
if (m_constraints) {
N_VDestroy(m_constraints);
m_constraints = nullptr;
}
if (m_func_tmpl) {
N_VDestroy(m_func_tmpl);
m_func_tmpl = nullptr;
}
if (m_kinsol_mem) {
KINFree(&m_kinsol_mem);
m_kinsol_mem = nullptr;
}
if (m_J) {
SUNMatDestroy(m_J);
m_J = nullptr;
}
if (m_LS) {
SUNLinSolFree(m_LS);
m_LS = nullptr;
}
}
for (const auto& species: m_qse_solve_species) {
const double molarAbundance = y_qse[static_cast<long>(m_qse_solve_species_index_map.at(species))];
comp_trial.setMolarAbundance(species, molarAbundance);
fourdst::composition::Composition MultiscalePartitioningEngineView::QSESolver::solve(
const fourdst::composition::Composition &comp,
const double T9,
const double rho
) const {
fourdst::composition::Composition result = comp;
UserData data {
m_engine,
T9,
rho,
result,
m_speciesMap,
m_species
};
utils::check_cvode_flag(KINSetUserData(m_kinsol_mem, &data), "KINSetUserData");
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
sunrealtype* scale_data = N_VGetArrayPointer(m_scale);
// It is more cache optimized to do a standard as opposed to range based for-loop here
for (size_t i = 0; i < m_N; ++i) {
const auto& species = m_species[i];
double Y = result.getMolarAbundance(species);
constexpr double abundance_floor = 1.0e-100;
Y = std::max(abundance_floor, Y);
y_data[i] = Y;
scale_data[i] = 1.0;
}
std::vector<Species> qse_species_vector(m_qse_solve_species.begin(), m_qse_solve_species.end());
NetworkJacobian jac = m_view.getBaseEngine().generateJacobianMatrix(comp_trial, m_T9, m_rho, qse_species_vector);
const auto result = m_view.getBaseEngine().calculateRHSAndEnergy(comp_trial, m_T9, m_rho);
auto initial_rhs = m_engine.calculateRHSAndEnergy(result, T9, rho);
if (!initial_rhs) {
throw std::runtime_error("In QSE solver failed to calculate initial RHS");
}
sunrealtype* f_scale_data = N_VGetArrayPointer(m_f_scale);
for (size_t i = 0; i < m_N; ++i) {
const auto& species = m_species[i];
double dydt = std::abs(initial_rhs.value().dydt.at(species));
f_scale_data[i] = 1.0 / (dydt + 1e-15);
}
if (m_solves > 0) {
// After the initial solution we want to allow kinsol to reuse its state
utils::check_cvode_flag(KINSetNoInitSetup(m_kinsol_mem, SUNTRUE), "KINSetNoInitSetup");
} else {
utils::check_cvode_flag(KINSetNoInitSetup(m_kinsol_mem, SUNFALSE), "KINSetNoInitSetup");
}
const int flag = KINSol(m_kinsol_mem, m_Y, KIN_LINESEARCH, m_scale, m_f_scale);
if (flag < 0) {
LOG_WARNING(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}.", utils::cvode_ret_code_map.at(flag));
return comp;
}
for (size_t i = 0; i < m_N; ++i) {
const auto& species = m_species[i];
result.setMolarAbundance(species, y_data[i]);
}
m_solves++;
return result;
}
size_t MultiscalePartitioningEngineView::QSESolver::solves() const {
return m_solves;
}
void MultiscalePartitioningEngineView::QSESolver::log_diagnostics() const {
long int nni, nfe, nje;
int flag = KINGetNumNonlinSolvIters(m_kinsol_mem, &nni);
flag = KINGetNumFuncEvals(m_kinsol_mem, &nfe);
flag = KINGetNumJacEvals(m_kinsol_mem, &nje);
LOG_INFO(getLogger(),
"QSE Stats | Iters: {} | RHS Evals: {} | Jac Evals: {} | Ratio (J/I): {:.2f}",
nni, nfe, nje, static_cast<double>(nje) / static_cast<double>(nni)
);
getLogger()->flush_log(true);
}
int MultiscalePartitioningEngineView::QSESolver::sys_func(
const N_Vector y,
const N_Vector f,
void *user_data
) {
const auto* data = static_cast<UserData*>(user_data);
const sunrealtype* y_data = N_VGetArrayPointer(y);
sunrealtype* f_data = N_VGetArrayPointer(f);
const auto& map = data->qse_solve_species_index_map;
for (size_t index = 0; index < data->qse_solve_species.size(); ++index) {
const auto& species = data->qse_solve_species[index];
data->comp.setMolarAbundance(species, y_data[index]);
}
const auto result = data->engine.calculateRHSAndEnergy(data->comp, data->T9, data->rho);
if (!result) {
throw exceptions::StaleEngineError("Failed to calculate RHS and energy due to stale engine state");
}
const auto&[dydt, nuclearEnergyGenerationRate, _] = result.value();
const long N = static_cast<long>(m_qse_solve_species.size());
J_qse.resize(N, N);
long rowID = 0;
for (const auto& rowSpecies : m_qse_solve_species) {
long colID = 0;
for (const auto& colSpecies: m_qse_solve_species) {
J_qse(rowID, colID) = jac(rowSpecies, colSpecies);
colID += 1;
LOG_TRACE_L3(m_view.m_logger, "Jacobian[{}, {}] (d(dY({}))/dY({})) = {}", rowID, colID - 1, rowSpecies.name(), colSpecies.name(), J_qse(rowID, colID - 1));
}
rowID += 1;
return 1; // Potentially recoverable error
}
for (long i = 0; i < J_qse.rows(); ++i) {
for (long j = 0; j < J_qse.cols(); ++j) {
double on_diag_correction = 0.0;
if (i == j) {
auto rowSpecies = *(std::next(m_qse_solve_species.begin(), i));
const double Fi = dydt.at(rowSpecies);
on_diag_correction = Fi / y_qse(i);
}
J_qse(i, j) = y_qse(j) * (J_qse(i, j) - on_diag_correction) / y_qse(i); // Apply chain rule J'(i,j) = y_j * (J(i,j) - δ_ij(F_i/Y_i)) / Y_i
}
}
const auto& dydt = result.value().dydt;
m_cached_jacobian = J_qse; // Cache the computed Jacobian for future use
for (const auto& [species, index] : map) {
f_data[index] = dydt.at(species);
}
return 0; // Success
}
int MultiscalePartitioningEngineView::QSESolver::sys_jac(
const N_Vector y,
N_Vector fy,
SUNMatrix J,
void *user_data,
N_Vector tmp1,
N_Vector tmp2
) {
const auto* data = static_cast<UserData*>(user_data);
const sunrealtype* y_data = N_VGetArrayPointer(y);
const auto& map = data->qse_solve_species_index_map;
for (const auto& [species, index] : map) {
data->comp.setMolarAbundance(species, y_data[index]);
}
const NetworkJacobian jac = data->engine.generateJacobianMatrix(
data->comp,
data->T9,
data->rho,
data->qse_solve_species
);
sunrealtype* J_data = SUNDenseMatrix_Data(J);
const sunindextype N = SUNDenseMatrix_Columns(J);
for (const auto& [col_species, col_idx] : map) {
for (const auto& [row_species, row_idx] : map) {
J_data[col_idx * N + row_idx] = jac(row_species, col_species);
}
}
return 0;
}
/////////////////////////////////
/// QSEGroup Member Functions ///