From ed1c5a1ac73231e0bf6230c61f74cfc94e7f0b35 Mon Sep 17 00:00:00 2001 From: Emily Boudreaux Date: Fri, 15 Aug 2025 12:11:32 -0400 Subject: [PATCH] feat(solver): added CVODE solver from SUNDIALS --- .gitignore | 1 + build-config/cvode/meson.build | 47 +++ build-config/meson.build | 2 + meson.build | 8 + .../gridfire/screening/screening_weak.h | 4 +- .../solver/strategies/CVODE_solver_strategy.h | 136 +++++++ src/lib/engine/procedures/construction.cpp | 4 +- src/lib/engine/procedures/priming.cpp | 6 +- src/lib/engine/views/engine_adaptive.cpp | 4 +- src/lib/engine/views/engine_defined.cpp | 2 +- src/lib/engine/views/engine_multiscale.cpp | 8 +- src/lib/reaction/reaclib.cpp | 8 +- .../strategies/CVODE_solver_strategy.cpp | 362 ++++++++++++++++++ src/meson.build | 37 +- subprojects/cvode.wrap | 6 + tests/graphnet_sandbox/main.cpp | 15 +- 16 files changed, 588 insertions(+), 62 deletions(-) create mode 100644 build-config/cvode/meson.build create mode 100644 src/include/gridfire/solver/strategies/CVODE_solver_strategy.h create mode 100644 src/lib/solver/strategies/CVODE_solver_strategy.cpp create mode 100644 subprojects/cvode.wrap diff --git a/.gitignore b/.gitignore index f9277727..feedac37 100644 --- a/.gitignore +++ b/.gitignore @@ -78,6 +78,7 @@ subprojects/eigen-*/ subprojects/fourdst/ subprojects/libplugin/ subprojects/minizip-ng-4.0.10/ +subprojects/cvode-*/ *.fbundle *.csv diff --git a/build-config/cvode/meson.build b/build-config/cvode/meson.build new file mode 100644 index 00000000..c0a9e74d --- /dev/null +++ b/build-config/cvode/meson.build @@ -0,0 +1,47 @@ +cmake = import('cmake') +cvode_cmake_options = cmake.subproject_options() + +cvode_cmake_options.add_cmake_defines({ + 'CMAKE_CXX_FLAGS' : '-Wno-deprecated-declarations', + 'CMAKE_C_FLAGS' : '-Wno-deprecated-declarations', + 'BUILD_SHARED_LIBS' : 'ON', + 'BUILD_STATIC_LIBS' : 'OFF', + +}) + +cvode_cmake_options.add_cmake_defines({ + 'CMAKE_INSTALL_LIBDIR': get_option('libdir'), + 'CMAKE_INSTALL_INCLUDEDIR': get_option('includedir') +}) + +cvode_sp = cmake.subproject( + 'cvode', + options: cvode_cmake_options, +) + +# For the core SUNDIALS library (SUNContext, etc.) +sundials_core_dep = cvode_sp.dependency('sundials_core_shared') + +# For the CVODE integrator library +sundials_cvode_dep = cvode_sp.dependency('sundials_cvode_shared') + +# For the serial NVector library +sundials_nvecserial_dep = cvode_sp.dependency('sundials_nvecserial_shared') + +# For the dense matrix library +sundials_sunmatrixdense_dep = cvode_sp.dependency('sundials_sunmatrixdense_shared') + +# For the dense linear solver library +sundials_sunlinsoldense_dep = cvode_sp.dependency('sundials_sunlinsoldense_shared') + +sundials_dep = declare_dependency( + dependencies: [ + sundials_core_dep, + sundials_cvode_dep, + sundials_nvecserial_dep, + sundials_sunmatrixdense_dep, + sundials_sunlinsoldense_dep, + ], +) + + diff --git a/build-config/meson.build b/build-config/meson.build index d01e9626..619b242e 100644 --- a/build-config/meson.build +++ b/build-config/meson.build @@ -3,6 +3,8 @@ cmake = import('cmake') subdir('fourdst') subdir('libplugin') +subdir('cvode') + subdir('boost') subdir('cppad') subdir('xxHash') diff --git a/meson.build b/meson.build index d5dfa2d3..e58bca8e 100644 --- a/meson.build +++ b/meson.build @@ -45,20 +45,28 @@ llevel = get_option('log_level') logbase='QUILL_COMPILE_ACTIVE_LOG_LEVEL_' if (llevel == 'traceL3') + message('Setting log level to TRACE_L3') log_argument = logbase + 'TRACE_L3' elif (llevel == 'traceL2') + message('Setting log level to TRACE_L2') log_argument = logbase + 'TRACE_L2' elif (llevel == 'traceL1') + message('Setting log level to TRACE_L1') log_argument = logbase + 'TRACE_L1' elif (llevel == 'debug') + message('Setting log level to DEBUG') log_argument = logbase + 'DEBUG' elif (llevel == 'info') + message('Setting log level to INFO') log_argument = logbase + 'INFO' elif (llevel == 'warning') + message('Setting log level to WARNING') log_argument = logbase + 'WARNING' elif (llevel == 'error') + message('Setting log level to ERROR') log_argument = logbase + 'ERROR' elif (llevel == 'critical') + message('Setting log level to CRITICAL') log_argument = logbase + 'CRITICAL' endif diff --git a/src/include/gridfire/screening/screening_weak.h b/src/include/gridfire/screening/screening_weak.h index 2b53568a..ef8cd2ad 100644 --- a/src/include/gridfire/screening/screening_weak.h +++ b/src/include/gridfire/screening/screening_weak.h @@ -187,13 +187,13 @@ namespace gridfire::screening { reactants[1] == reactants[2] ); if (reactants.size() == 2) { - LOG_TRACE_L3(m_logger, "Calculating screening factor for reaction: {}", reaction.peName()); + LOG_TRACE_L3(m_logger, "Calculating screening factor for reaction: {}", reaction->id()); const T Z1 = static_cast(reactants[0].m_z); const T Z2 = static_cast(reactants[1].m_z); H_12 = prefactor * Z1 * Z2; } else if (isTripleAlpha) { - LOG_TRACE_L3(m_logger, "Special case for triple alpha process in reaction: {}", reaction.peName()); + LOG_TRACE_L3(m_logger, "Special case for triple alpha process in reaction: {}", reaction->id()); // Special case for triple alpha process const T Z_alpha = static_cast(2.0); const T H_alpha_alpha = prefactor * Z_alpha * Z_alpha; diff --git a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h b/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h new file mode 100644 index 00000000..4b2f7e92 --- /dev/null +++ b/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h @@ -0,0 +1,136 @@ +#pragma once + +#include "gridfire/solver/solver.h" +#include "gridfire/engine/engine_abstract.h" +#include "gridfire/network.h" +#include "gridfire/exceptions/exceptions.h" + +#include "fourdst/composition/atomicSpecies.h" +#include "fourdst/config/config.h" + + +#include +#include +#include +#include +#include + +// SUNDIALS/CVODE headers +#include +#include + +// Include headers for linear solvers and N_Vectors +// We will use preprocessor directives to select the correct ones +#include // For CVDls (serial dense linear solver) +#include +#include +#include + +#ifdef SUNDIALS_HAVE_OPENMP + #include +#endif +#ifdef SUNDIALS_HAVE_PTHREADS + #include +#endif +// Default to serial if no parallelism is enabled +#ifndef SUNDIALS_HAVE_OPENMP + #ifndef SUNDIALS_HAVE_PTHREADS + #include + #endif +#endif + +namespace gridfire::solver { + class CVODESolverStrategy final : public DynamicNetworkSolverStrategy { + public: + explicit CVODESolverStrategy(DynamicEngine& engine); + ~CVODESolverStrategy() override; + + // Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers + CVODESolverStrategy(const CVODESolverStrategy&) = delete; + CVODESolverStrategy& operator=(const CVODESolverStrategy&) = delete; + CVODESolverStrategy(CVODESolverStrategy&&) = delete; + CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete; + + NetOut evaluate(const NetIn& netIn) override; + + void set_callback(const std::any &callback) override; + + bool get_stdout_logging_enabled() const; + + void set_stdout_logging_enabled(const bool value); + + [[nodiscard]] std::vector> describe_callback_context() const override; + + struct TimestepContext final : public SolverContextBase { + // This struct can be identical to the one in DirectNetworkSolver + const double t; + const N_Vector& state; // Note: state is now an N_Vector + const double dt; + const double last_step_time; + const double T9; + const double rho; + const int num_steps; + const DynamicEngine& engine; + const std::vector& networkSpecies; + + // Constructor + TimestepContext( + double t, const N_Vector& state, double dt, double last_step_time, + double t9, double rho, int num_steps, const DynamicEngine& engine, + const std::vector& networkSpecies + ); + + [[nodiscard]] std::vector> describe() const override; + }; + + using TimestepCallback = std::function; ///< Type alias for a timestep callback function. + private: + /** + * @struct CVODEUserData + * @brief A helper struct to pass C++ context to C-style CVODE callbacks. + * + * CVODE callbacks are C functions and use a void* pointer to pass user data. + * This struct bundles all the necessary C++ objects (like 'this', engine references, etc.) + * to be accessed safely within those static C wrappers. + */ + struct CVODEUserData { + CVODESolverStrategy* solver_instance; // Pointer back to the class instance + DynamicEngine* engine; + double T9; + double rho; + const std::vector* networkSpecies; + std::unique_ptr captured_exception = nullptr; + }; + + private: + Config& m_config = Config::getInstance(); + static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, void *user_data); + static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); + + int calculate_rhs(sunrealtype t, N_Vector y, N_Vector ydot, const CVODEUserData* data) const; + + void initialize_cvode_integration_resources( + uint64_t N, + size_t numSpecies, + double current_time, + const fourdst::composition::Composition& composition, + double absTol, + double relTol, + double accumulatedEnergy + ); + + void cleanup_cvode_resources(bool memFree); + private: + SUNContext m_sun_ctx = nullptr; + void* m_cvode_mem = nullptr; + N_Vector m_Y = nullptr; + SUNMatrix m_J = nullptr; + SUNLinearSolver m_LS = nullptr; + + + TimestepCallback m_callback; + int m_num_steps = 0; + + bool m_stdout_logging_enabled = true; + }; +} \ No newline at end of file diff --git a/src/lib/engine/procedures/construction.cpp b/src/lib/engine/procedures/construction.cpp index 25ee27b3..55ee8755 100644 --- a/src/lib/engine/procedures/construction.cpp +++ b/src/lib/engine/procedures/construction.cpp @@ -48,7 +48,7 @@ namespace gridfire { if (depth == static_cast(NetworkBuildDepth::Full)) { LOG_INFO(logger, "Building full nuclear network with a total of {} reactions.", allReactions.size()); const ReactionSet reactionSet(remainingReactions); - return reaction::packReactionSet(reactionSet); + return reactionSet; } std::unordered_set availableSpecies; @@ -104,7 +104,7 @@ namespace gridfire { LOG_INFO(logger, "Network construction completed with {} reactions collected.", collectedReactions.size()); const ReactionSet reactionSet(collectedReactions); - return reaction::packReactionSet(reactionSet); + return reactionSet; } } \ No newline at end of file diff --git a/src/lib/engine/procedures/priming.cpp b/src/lib/engine/procedures/priming.cpp index 163d83fc..7f3bfd79 100644 --- a/src/lib/engine/procedures/priming.cpp +++ b/src/lib/engine/procedures/priming.cpp @@ -116,7 +116,7 @@ namespace gridfire { LOG_TRACE_L3(logger, "Found equilibrium for {}: X_eq = {:.4e}", primingSpecies.name(), equilibriumMassFraction); if (const reaction::Reaction* dominantChannel = findDominantCreationChannel(primer, primingSpecies, Y, T9, rho)) { - LOG_TRACE_L3(logger, "Dominant creation channel for {}: {}", primingSpecies.name(), dominantChannel->peName()); + LOG_TRACE_L3(logger, "Dominant creation channel for {}: {}", primingSpecies.name(), dominantChannel->id()); double totalReactantMass = 0.0; for (const auto& reactant : dominantChannel->reactants()) { @@ -209,7 +209,7 @@ namespace gridfire { double calculateCreationRate( const DynamicEngine& engine, - const fourdst::atomic::Species& species, + const Species& species, const std::vector& Y, const double T9, const double rho @@ -218,6 +218,8 @@ namespace gridfire { for (const auto& reaction: engine.getNetworkReactions()) { const int stoichiometry = reaction->stoichiometry(species); if (stoichiometry > 0) { + if (engine.calculateMolarReactionFlow(*reaction, Y, T9, rho) > 0.0) { + } creationRate += stoichiometry * engine.calculateMolarReactionFlow(*reaction, Y, T9, rho); } } diff --git a/src/lib/engine/views/engine_adaptive.cpp b/src/lib/engine/views/engine_adaptive.cpp index b89cbf73..34d830ca 100644 --- a/src/lib/engine/views/engine_adaptive.cpp +++ b/src/lib/engine/views/engine_adaptive.cpp @@ -385,7 +385,7 @@ namespace gridfire { for (const auto& reaction : fullReactionSet) { const double flow = m_baseEngine.calculateMolarReactionFlow(*reaction, out_Y_Full, T9, rho); reactionFlows.push_back({reaction.get(), flow}); - LOG_TRACE_L1(m_logger, "Reaction '{}' has flow rate: {:0.3E} [mol/s/g]", reaction.id(), flow); + LOG_TRACE_L1(m_logger, "Reaction '{}' has flow rate: {:0.3E} [mol/s/g]", reaction->id(), flow); } return reactionFlows; } @@ -423,7 +423,7 @@ namespace gridfire { if (!reachable.contains(product)) { reachable.insert(product); new_species_found_in_pass = true; - LOG_TRACE_L2(m_logger, "Network Connectivity Analysis: Species '{}' is reachable via reaction '{}'.", product.name(), reaction.id()); + LOG_TRACE_L2(m_logger, "Network Connectivity Analysis: Species '{}' is reachable via reaction '{}'.", product.name(), reaction->id()); } } } diff --git a/src/lib/engine/views/engine_defined.cpp b/src/lib/engine/views/engine_defined.cpp index e2430c4d..c05d36a9 100644 --- a/src/lib/engine/views/engine_defined.cpp +++ b/src/lib/engine/views/engine_defined.cpp @@ -358,7 +358,7 @@ namespace gridfire { LOG_TRACE_L3(m_logger, "Active reactions: {}", [this]() -> std::string { std::string result; for (const auto& reaction : m_activeReactions) { - result += std::string(reaction.id()) + ", "; + result += std::string(reaction->id()) + ", "; } if (!result.empty()) { result.pop_back(); // Remove last space diff --git a/src/lib/engine/views/engine_multiscale.cpp b/src/lib/engine/views/engine_multiscale.cpp index 2f3880fe..13309ffe 100644 --- a/src/lib/engine/views/engine_multiscale.cpp +++ b/src/lib/engine/views/engine_multiscale.cpp @@ -1020,7 +1020,7 @@ namespace gridfire { LOG_TRACE_L3( m_logger, "Reaction {} is internal to the group containing {} and contributes to internal flux by {}", - reaction.id(), + reaction->id(), [&]() -> std::string { std::stringstream ss; int count = 0; @@ -1040,7 +1040,7 @@ namespace gridfire { LOG_TRACE_L3( m_logger, "Reaction {} is external to the group containing {} and contributes to external flux by {}", - reaction.id(), + reaction->id(), [&]() -> std::string { std::stringstream ss; int count = 0; @@ -1406,13 +1406,13 @@ namespace gridfire { for (const auto& reactant : reaction->reactants()) { if (std::ranges::find(pool, m_baseEngine.getSpeciesIndex(reactant)) == pool.end()) { has_external_reactant = true; - LOG_TRACE_L3(m_logger, "Found external reactant {} in reaction {} for species {}.", reactant.name(), reaction.id(), ash.name()); + LOG_TRACE_L3(m_logger, "Found external reactant {} in reaction {} for species {}.", reactant.name(), reaction->id(), ash.name()); break; // Found an external reactant, no need to check further } } if (has_external_reactant) { double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(*reaction, Y, T9, rho)); - LOG_TRACE_L3(m_logger, "Found bridge reaction {} with flow {} for species {}.", reaction.id(), flow, ash.name()); + LOG_TRACE_L3(m_logger, "Found bridge reaction {} with flow {} for species {}.", reaction->id(), flow, ash.name()); bridge_reactions.emplace_back(reaction.get(), flow); } } diff --git a/src/lib/reaction/reaclib.cpp b/src/lib/reaction/reaclib.cpp index 5382ec80..94bb50c9 100644 --- a/src/lib/reaction/reaclib.cpp +++ b/src/lib/reaction/reaclib.cpp @@ -22,11 +22,11 @@ std::string trim_whitespace(const std::string& str) { } const auto ritr = std::find_if(str.rbegin(), std::string::const_reverse_iterator(startIt), [](const unsigned char ch){ return !std::isspace(ch); }); - return std::string(startIt, ritr.base()); + return {startIt, ritr.base()}; } namespace gridfire::reaclib { - static reaction::ReactionSet* s_all_reaclib_reactions_ptr = nullptr; + static std::unique_ptr s_all_reaclib_reactions_ptr = nullptr; #pragma pack(push, 1) struct ReactionRecord { @@ -125,9 +125,7 @@ namespace gridfire::reaclib { const reaction::ReactionSet reaction_set(std::move(reaction_list)); // The LogicalReactionSet groups reactions by their peName, which is what we want. - s_all_reaclib_reactions_ptr = new reaction::ReactionSet( - reaction::packReactionSet(reaction_set) - ); + s_all_reaclib_reactions_ptr = std::make_unique(reaction::packReactionSet(reaction_set)); s_initialized = true; } diff --git a/src/lib/solver/strategies/CVODE_solver_strategy.cpp b/src/lib/solver/strategies/CVODE_solver_strategy.cpp new file mode 100644 index 00000000..22fb48b4 --- /dev/null +++ b/src/lib/solver/strategies/CVODE_solver_strategy.cpp @@ -0,0 +1,362 @@ +#include "gridfire/solver/strategies/CVODE_solver_strategy.h" + +#include "gridfire/network.h" + +#include "fourdst/composition/composition.h" + +// ReSharper disable once CppUnusedIncludeDirective +#include +#include +#include +#include +#include + + +namespace { + std::unordered_map cvode_ret_code_map { + {0, "CV_SUCCESS: The solver succeeded."}, + {1, "CV_TSTOP_RETURN: The solver reached the specified stopping time."}, + {2, "CV_ROOT_RETURN: A root was found."}, + {-99, "CV_WARNING: CVODE succeeded but in an unusual manner"}, + {-1, "CV_TOO_MUCH_WORK: The solver took too many internal steps."}, + {-2, "CV_TOO_MUCH_ACC: The solver could not satisfy the accuracy requested."}, + {-3, "CV_ERR_FAILURE: The solver encountered a non-recoverable error."}, + {-4, "CV_CONV_FAILURE: The solver failed to converge."}, + {-5, "CV_LINIT_FAIL: The linear solver's initialization function failed."}, + {-6, "CV_LSETUP_FAIL: The linear solver's setup function failed."}, + {-7, "CV_LSOLVE_FAIL: The linear solver's solve function failed."}, + {-8, "CV_RHSFUNC_FAIL: The right-hand side function failed in an unrecoverable manner."}, + {-9, "CV_FIRST_RHSFUNC_ERR: The right-hand side function failed at the first call."}, + {-10, "CV_REPTD_RHSFUNC_ERR: The right-hand side function repeatedly failed recoverable."}, + {-11, "CV_UNREC_RHSFUNC_ERR: The right-hand side function failed unrecoverably."}, + {-12, "CV_RTFUNC_FAIL: The rootfinding function failed in an unrecoverable manner."}, + {-13, "CV_NLS_INIT_FAIL: The nonlinear solver's initialization function failed."}, + {-14, "CV_NLS_SETUP_FAIL: The nonlinear solver's setup function failed."}, + {-15, "CV_CONSTR_FAIL : The inequality constraint was violated and the solver was unable to recover."}, + {-16, "CV_NLS_FAIL: The nonlinear solver's solve function failed."}, + {-20, "CV_MEM_FAIL: Memory allocation failed."}, + {-21, "CV_MEM_NULL: The CVODE memory structure is NULL."}, + {-22, "CV_ILL_INPUT: An illegal input was detected."}, + {-23, "CV_NO_MALLOC: The CVODE memory structure has not been allocated."}, + {-24, "CV_BAD_K: The value of k is invalid."}, + {-25, "CV_BAD_T: The value of t is invalid."}, + {-26, "CV_BAD_DKY: The value of dky is invalid."}, + {-27, "CV_TOO_CLOSE: The time points are too close together."}, + {-28, "CV_VECTOROP_ERR: A vector operation failed."}, + {-29, "CV_PROJ_MEM_NULL: The projection memory structure is NULL."}, + {-30, "CV_PROJFUNC_FAIL: The projection function failed in an unrecoverable manner."}, + {-31, "CV_REPTD_PROJFUNC_ERR: THe projection function has repeated recoverable errors."} + }; + void check_cvode_flag(const int flag, const std::string& func_name) { + if (flag < 0) { + if (!cvode_ret_code_map.contains(flag)) { + throw std::runtime_error("CVODE error in " + func_name + ": Unknown error code: " + std::to_string(flag)); + } + throw std::runtime_error("CVODE error in " + func_name + ": " + cvode_ret_code_map.at(flag)); + } + } + + N_Vector init_sun_vector(uint64_t size, SUNContext sun_ctx) { +#ifdef SUNDIALS_HAVE_OPENMP + N_Vector vec = N_VNew_OpenMP(size, 0, sun_ctx); +#elif SUNDIALS_HAVE_PTHREADS + N_Vector vec = N_VNew_Pthreads(size, sun_ctx); +#else + N_Vector vec = N_VNew_Serial(size, sun_ctx); +#endif + check_cvode_flag(vec == nullptr ? -1 : 0, "N_VNew"); + return vec; + } +} + +namespace gridfire::solver { + + CVODESolverStrategy::CVODESolverStrategy(DynamicEngine &engine): NetworkSolverStrategy(engine) { + // TODO: In order to support MPI this function must be changed + const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx); + if (flag < 0) { + throw std::runtime_error("Failed to create SUNDIALS context (Errno: " + std::to_string(flag) + ")"); + } + } + + CVODESolverStrategy::~CVODESolverStrategy() { + cleanup_cvode_resources(true); + + if (m_sun_ctx) { + SUNContext_Free(&m_sun_ctx); + } + } + + NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) { + const double T9 = netIn.temperature / 1e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9) + + const auto absTol = m_config.get("gridfire:solver:CVODESolverStrategy:absTol", 1.0e-8); + const auto relTol = m_config.get("gridfire:solver:CVODESolverStrategy:relTol", 1.0e-8); + + fourdst::composition::Composition equilibratedComposition = m_engine.update(netIn); + + size_t numSpecies = m_engine.getNetworkSpecies().size(); + uint64_t N = numSpecies + 1; + + m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx); + check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate"); + + initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0); + + CVODEUserData user_data; + user_data.solver_instance = this; + user_data.engine = &m_engine; + + double current_time = 0; + [[maybe_unused]] double last_callback_time = 0; + m_num_steps = 0; + double accumulated_energy = 0.0; + + while (current_time < netIn.tMax) { + try { + user_data.T9 = T9; + user_data.rho = netIn.density; + user_data.networkSpecies = &m_engine.getNetworkSpecies(); + user_data.captured_exception.reset(); + + check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData"); + + int flag = -1; + if (m_stdout_logging_enabled) { + flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP); + } else { + flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_NORMAL); + } + + if (user_data.captured_exception){ + std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception)); + } + + check_cvode_flag(flag, "CVode"); + + long int n_steps; + double last_step_size; + CVodeGetNumSteps(m_cvode_mem, &n_steps); + CVodeGetLastStep(m_cvode_mem, &last_step_size); + std::cout << std::scientific << std::setprecision(3) << "Step: " << std::setw(6) << n_steps << " | Time: " << current_time << " | Last Step Size: " << last_step_size << std::endl; + + } catch (const exceptions::StaleEngineTrigger& e) { + exceptions::StaleEngineTrigger::state staleState = e.getState(); + accumulated_energy += e.energy(); // Add the specific energy rate to the accumulated energy + // total_update_stages_triggered++; + + fourdst::composition::Composition temp_comp; + std::vector mass_fractions; + size_t num_species_at_stop = e.numSpecies(); + + if (num_species_at_stop != m_engine.getNetworkSpecies().size()) { + throw std::runtime_error( + "StaleEngineError state has a different number of species than the engine. This should not happen." + ); + } + mass_fractions.reserve(num_species_at_stop); + + for (size_t i = 0; i < num_species_at_stop; ++i) { + const auto& species = m_engine.getNetworkSpecies()[i]; + temp_comp.registerSpecies(species); + mass_fractions.push_back(e.getMolarAbundance(i) * species.mass()); // Convert from molar abundance to mass fraction + } + temp_comp.setMassFraction(m_engine.getNetworkSpecies(), mass_fractions); + temp_comp.finalize(true); + + NetIn netInTemp = netIn; + netInTemp.temperature = e.temperature(); + netInTemp.density = e.density(); + netInTemp.composition = temp_comp; + + fourdst::composition::Composition currentComposition = m_engine.update(netInTemp); + + numSpecies = m_engine.getNetworkSpecies().size(); + N = numSpecies + 1; + + initialize_cvode_integration_resources(N, numSpecies, current_time, temp_comp, absTol, relTol, accumulated_energy); + + check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit"); + } + } + + sunrealtype* y_data = N_VGetArrayPointer(m_Y); + accumulated_energy += y_data[numSpecies]; + + std::vector finalMassFractions(numSpecies); + for (size_t i = 0; i < numSpecies; ++i) { + const double molarMass = m_engine.getNetworkSpecies()[i].mass(); + finalMassFractions[i] = y_data[i] * molarMass; // Convert from molar abundance to mass fraction + if (finalMassFractions[i] < MIN_ABUNDANCE_THRESHOLD) { + finalMassFractions[i] = 0.0; + } + } + + std::vector speciesNames; + speciesNames.reserve(numSpecies); + for (const auto& species : m_engine.getNetworkSpecies()) { + speciesNames.push_back(std::string(species.name())); + } + + fourdst::composition::Composition outputComposition(speciesNames); + outputComposition.setMassFraction(speciesNames, finalMassFractions); + outputComposition.finalize(true); + + NetOut netOut; + netOut.composition = std::move(outputComposition); + netOut.energy = accumulated_energy; + check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast(&netOut.num_steps)), "CVodeGetNumSteps"); + return netOut; + } + + void CVODESolverStrategy::set_callback(const std::any &callback) { + m_callback = std::any_cast(callback); + } + + bool CVODESolverStrategy::get_stdout_logging_enabled() const { + return m_stdout_logging_enabled; + } + + void CVODESolverStrategy::set_stdout_logging_enabled(const bool value) { + m_stdout_logging_enabled = value; + } + + std::vector> CVODESolverStrategy::describe_callback_context() const { + return {}; + } + + int CVODESolverStrategy::cvode_rhs_wrapper( + sunrealtype t, + N_Vector y, + N_Vector ydot, + void *user_data + ) { + auto* data = static_cast(user_data); + auto* instance = data->solver_instance; + + try { + return instance->calculate_rhs(t, y, ydot, data); + } catch (const exceptions::StaleEngineTrigger& e) { + data->captured_exception = std::make_unique(e); + return 1; // 1 Indicates a recoverable error, CVODE will retry the step + } catch (...) { + return -1; // Some unrecoverable error + } + } + + int CVODESolverStrategy::cvode_jac_wrapper( + sunrealtype t, + N_Vector y, + N_Vector ydot, + SUNMatrix J, + void *user_data, + N_Vector tmp1, + N_Vector tmp2, + N_Vector tmp3 + ) { + const auto* data = static_cast(user_data); + const auto* engine = data->engine; + + const size_t numSpecies = engine->getNetworkSpecies().size(); + + sunrealtype* J_data = SUNDenseMatrix_Data(J); + const long int N = SUNDenseMatrix_Columns(J); + + for (size_t j = 0; j < numSpecies; ++j) { + for (size_t i = 0; i < numSpecies; ++i) { + // J(i,j) = d(f_i)/d(y_j) + // Column-major order format for SUNDenseMatrix: J_data[j*N + i] + J_data[j * N + i] = engine->getJacobianMatrixEntry(i, j); + } + } + + // For now assume that the energy derivatives wrt. abundances are zero + for (size_t i = 0; i < N; ++i) { + J_data[(N - 1) * N + i] = 0.0; // df(energy_dot)/df(y_i) + J_data[i * N + (N - 1)] = 0.0; // df(f_i)/df(energy_dot) + } + + return 0; + } + + int CVODESolverStrategy::calculate_rhs( + const sunrealtype t, + const N_Vector y, + const N_Vector ydot, + const CVODEUserData *data + ) const { + const size_t numSpecies = m_engine.getNetworkSpecies().size(); + sunrealtype* y_data = N_VGetArrayPointer(y); + + std::vector y_vec(y_data, y_data + numSpecies); + + std::ranges::replace_if(y_vec, [](const double val) { return val < 0.0; }, 0.0); + + const auto result = m_engine.calculateRHSAndEnergy(y_vec, data->T9, data->rho); + if (!result) { + throw exceptions::StaleEngineTrigger({data->T9, data->rho, y_vec, t, m_num_steps, y_data[numSpecies]}); + } + + sunrealtype* ydot_data = N_VGetArrayPointer(ydot); + const auto& [dydt, nuclearEnergyGenerationRate] = result.value(); + + for (size_t i = 0; i < numSpecies; ++i) { + ydot_data[i] = dydt[i]; + } + ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate + return 0; + } + + void CVODESolverStrategy::initialize_cvode_integration_resources( + const uint64_t N, + const size_t numSpecies, + const double current_time, + const fourdst::composition::Composition &composition, + const double absTol, + const double relTol, + const double accumulatedEnergy + ) { + cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones + + m_Y = init_sun_vector(N, m_sun_ctx); + + sunrealtype *y_data = N_VGetArrayPointer(m_Y); + for (size_t i = 0; i < numSpecies; i++) { + const auto& species = m_engine.getNetworkSpecies()[i]; + if (composition.contains(species)) { + y_data[i] = composition.getMolarAbundance(species); + } else { + y_data[i] = std::numeric_limits::min(); // Species not in the composition, set to a small value + } + } + y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero + + + check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit"); + check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); + + m_J = SUNDenseMatrix(static_cast(N), static_cast(N), m_sun_ctx); + check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix"); + m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx); + check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense"); + + check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver"); + check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn"); + } + + void CVODESolverStrategy::cleanup_cvode_resources(const bool memFree) { + if (m_LS) SUNLinSolFree(m_LS); + if (m_J) SUNMatDestroy(m_J); + if (m_Y) N_VDestroy(m_Y); + + m_LS = nullptr; + m_J = nullptr; + m_Y = nullptr; + + if (memFree) { + if (m_cvode_mem) CVodeFree(&m_cvode_mem); + m_cvode_mem = nullptr; + } + } + +} diff --git a/src/meson.build b/src/meson.build index 21785ee3..76320efd 100644 --- a/src/meson.build +++ b/src/meson.build @@ -13,6 +13,7 @@ gridfire_sources = files( 'lib/reaction/reaclib.cpp', 'lib/io/network_file.cpp', 'lib/solver/solver.cpp', + 'lib/solver/strategies/CVODE_solver_strategy.cpp', 'lib/screening/screening_types.cpp', 'lib/screening/screening_weak.cpp', 'lib/screening/screening_bare.cpp', @@ -33,6 +34,7 @@ gridfire_build_dependencies = [ xxhash_dep, eigen_dep, plugin_dep, + sundials_dep, ] # Define the libnetwork library so it can be linked against by other parts of the build system @@ -49,39 +51,6 @@ gridfire_dep = declare_dependency( dependencies: gridfire_build_dependencies, ) -# Make headers accessible -gridfire_headers = files( - 'include/gridfire/network.h', - 'include/gridfire/engine/engine_abstract.h', - 'include/gridfire/engine/views/engine_view_abstract.h', - 'include/gridfire/engine/engine_approx8.h', - 'include/gridfire/engine/engine_graph.h', - 'include/gridfire/engine/views/engine_adaptive.h', - 'include/gridfire/engine/views/engine_defined.h', - 'include/gridfire/engine/views/engine_multiscale.h', - 'include/gridfire/engine/views/engine_priming.h', - 'include/gridfire/engine/procedures/priming.h', - 'include/gridfire/engine/procedures/construction.h', - 'include/gridfire/reaction/reaction.h', - 'include/gridfire/reaction/reaclib.h', - 'include/gridfire/io/network_file.h', - 'include/gridfire/solver/solver.h', - 'include/gridfire/screening/screening_abstract.h', - 'include/gridfire/screening/screening_bare.h', - 'include/gridfire/screening/screening_weak.h', - 'include/gridfire/screening/screening_types.h', - 'include/gridfire/partition/partition_abstract.h', - 'include/gridfire/partition/partition_rauscher_thielemann.h', - 'include/gridfire/partition/partition_ground.h', - 'include/gridfire/partition/composite/partition_composite.h', - 'include/gridfire/utils/logging.h', -) -install_headers(gridfire_headers, subdir : 'gridfire') - -solver_interface_headers = files( - 'include/gridfire/interfaces/solver/solver_interfaces.h', -) - -install_headers(solver_interface_headers, subdir : 'gridfire/interfaces/solver') +install_subdir('include/gridfire', install_dir: get_option('includedir')) subdir('python') diff --git a/subprojects/cvode.wrap b/subprojects/cvode.wrap new file mode 100644 index 00000000..99ef269e --- /dev/null +++ b/subprojects/cvode.wrap @@ -0,0 +1,6 @@ +[wrap-file] +directory = cvode-7.3.0 + +source_url = https://github.com/LLNL/sundials/releases/download/v7.3.0/cvode-7.3.0.tar.gz +source_filename = cvode-7.3.0.tar.gz +source_hash = 8b15a646882f2414b1915cad4d53136717a077539e7cfc480f2002c5898ae568 diff --git a/tests/graphnet_sandbox/main.cpp b/tests/graphnet_sandbox/main.cpp index 7249ab43..5bba4689 100644 --- a/tests/graphnet_sandbox/main.cpp +++ b/tests/graphnet_sandbox/main.cpp @@ -11,6 +11,7 @@ #include "gridfire/io/network_file.h" #include "gridfire/solver/solver.h" +#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/network.h" @@ -42,8 +43,6 @@ void callback(const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) std::cout << "Time: " << ctx.t << ", H-1: " << ctx.state(H1Index) << ", He-4: " << ctx.state(He4Index) << "\n"; - size_t i = 0; - } void measure_execution_time(const std::function& callback, const std::string& name) @@ -100,8 +99,8 @@ int main(int argc, char* argv[]){ g_previousHandler = std::set_terminate(quill_terminate_handler); quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log"); - logger->set_log_level(quill::LogLevel::Info); - LOG_DEBUG(logger, "Starting Adaptive Engine View Example..."); + logger->set_log_level(quill::LogLevel::TraceL1); + LOG_INFO(logger, "Starting Adaptive Engine View Example..."); using namespace gridfire; const std::vector comp = {0.708, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4}; @@ -129,14 +128,10 @@ int main(int argc, char* argv[]){ GraphEngine ReaclibEngine(composition, partitionFunction, NetworkBuildDepth::SecondOrder); ReaclibEngine.setUseReverseReactions(false); - // ReaclibEngine.setScreeningModel(screening::ScreeningType::WEAK); - // MultiscalePartitioningEngineView partitioningView(ReaclibEngine); AdaptiveEngineView adaptiveView(partitioningView); - // - solver::DirectNetworkSolver solver(adaptiveView); - // consumptionFile << "t,X,a,b,c\n"; - solver.set_callback(callback); + + solver::CVODESolverStrategy solver(adaptiveView); NetOut netOut;