perf(engine_multiscale): performance enhancments due to improved hashing, locality, and data structure optimization
This particular commit speeds up QSE solving for systems where reverse reactions and engine caching is disabled by about 24%
This commit is contained in:
@@ -90,6 +90,9 @@ namespace gridfire::partition {
|
||||
private:
|
||||
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
||||
std::vector<std::unique_ptr<PartitionFunction>> m_partitionFunctions; ///< Set of partition functions to use in the composite partition function.
|
||||
|
||||
mutable std::unordered_map<uint_fast32_t, const PartitionFunction&> m_supportCache; ///< Cache mapping isotope keys to supporting partition functions for fast lookup.
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Instantiate a sub-function by its type.
|
||||
|
||||
@@ -116,12 +116,6 @@ namespace gridfire::rates::weak {
|
||||
) const;
|
||||
private:
|
||||
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
||||
/**
|
||||
* @brief Pack (A,Z) into a 32-bit key used for the internal map.
|
||||
*
|
||||
* Layout: (A << 8) | Z. To unpack, use (key >> 8) for A and (key & 0xFF) for Z.
|
||||
*/
|
||||
static uint32_t pack_isotope_id(uint16_t A, uint8_t Z);
|
||||
|
||||
/**
|
||||
* @brief Per-isotope grids over (T9, log10(rho*Ye), mu_e) with payloads at lattice nodes.
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
//
|
||||
// Created by Emily Boudreaux on 10/22/25.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#ifndef GRIDFIRE_HASHING_H
|
||||
#define GRIDFIRE_HASHING_H
|
||||
#include <cstdint>
|
||||
|
||||
#endif //GRIDFIRE_HASHING_H
|
||||
namespace gridfire::utils {
|
||||
/**
|
||||
* @brief Generate a unique hash for an isotope given its mass number (A) and atomic number (Z).
|
||||
* @details This function combines the mass number and atomic number into a single 32-bit integer
|
||||
* by shifting the mass number 8 bits to the left and OR'ing it with the atomic number.
|
||||
* This ensures a unique representation for each isotope within physically possible ranges.
|
||||
* @param a The mass number (A) of the isotope.
|
||||
* @param z The atomic number (Z) of the isotope.
|
||||
* @return A unique 32-bit hash representing the isotope. This is computed as (A << 8) | Z into an uint32_t.
|
||||
*/
|
||||
inline uint_fast32_t hash_atomic(const uint16_t a, const uint8_t z) noexcept {
|
||||
return (static_cast<uint_fast32_t>(a) << 8) | static_cast<uint_fast32_t>(z);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1312,11 +1312,9 @@ namespace gridfire {
|
||||
if ( p != 0) { return false; }
|
||||
const double T9 = tx[0];
|
||||
|
||||
// This is an interesting problem because the reverse rate should only ever be computed for strong reactions
|
||||
// Which do not depend on rho or Y. However, the signature requires them...
|
||||
// For now, we just pass dummy values for rho and Y
|
||||
// We can pass a dummy comp and rho because reverse rates should only be calculated for strong reactions whose
|
||||
// rates of progression do not depend on composition or density.
|
||||
const double reverseRate = m_engine.calculateReverseRate(m_reaction, T9, 0.0, {});
|
||||
// std::cout << m_reaction.peName() << " reverseRate: " << reverseRate << " at T9: " << T9 << "\n";
|
||||
ty[0] = reverseRate; // Store the reverse rate in the output vector
|
||||
|
||||
if (vx.size() > 0) {
|
||||
@@ -1335,9 +1333,6 @@ namespace gridfire {
|
||||
const double T9 = tx[0];
|
||||
const double reverseRate = ty[0];
|
||||
|
||||
// This is an interesting problem because the reverse rate should only ever be computed for strong reactions
|
||||
// Which do not depend on rho or Y. However, the signature requires them...
|
||||
// For now, we just pass dummy values for rho and Y
|
||||
const double derivative = m_engine.calculateReverseRateTwoBodyDerivative(m_reaction, T9, 0, {}, reverseRate);
|
||||
|
||||
px[0] = py[0] * derivative; // Return the derivative of the reverse rate with respect to T9
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "gridfire/partition/partition_ground.h"
|
||||
#include "gridfire/partition/partition_rauscher_thielemann.h"
|
||||
#include "gridfire/utils/hashing.h"
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
namespace gridfire::partition {
|
||||
@@ -25,13 +26,14 @@ namespace gridfire::partition {
|
||||
}
|
||||
|
||||
double CompositePartitionFunction::evaluate(int z, int a, double T9) const {
|
||||
LOG_TRACE_L3(m_logger, "Evaluating partition function for Z={} A={} T9={}", z, a, T9);
|
||||
const uint_fast32_t hash = utils::hash_atomic(a, z);
|
||||
if (m_supportCache.contains(hash)) {
|
||||
return m_supportCache.at(hash).evaluate(z, a, T9);
|
||||
}
|
||||
for (const auto& partitionFunction : m_partitionFunctions) {
|
||||
if (partitionFunction->supports(z, a)) {
|
||||
LOG_TRACE_L3(m_logger, "Partition function of type {} supports Z={} A={}", partitionFunction->type(), z, a);
|
||||
m_supportCache.emplace(hash, *partitionFunction);
|
||||
return partitionFunction->evaluate(z, a, T9);
|
||||
} else {
|
||||
LOG_TRACE_L3(m_logger, "Partition function of type {} does not support Z={} A={}", partitionFunction->type(), z, a);
|
||||
}
|
||||
}
|
||||
LOG_ERROR(
|
||||
@@ -46,9 +48,13 @@ namespace gridfire::partition {
|
||||
}
|
||||
|
||||
double CompositePartitionFunction::evaluateDerivative(int z, int a, double T9) const {
|
||||
const uint_fast32_t hash = utils::hash_atomic(a, z);
|
||||
if (m_supportCache.contains(hash)) {
|
||||
return m_supportCache.at(hash).evaluateDerivative(z, a, T9);
|
||||
}
|
||||
for (const auto& partitionFunction : m_partitionFunctions) {
|
||||
if (partitionFunction->supports(z, a)) {
|
||||
LOG_TRACE_L3(m_logger, "Evaluating derivative of partition function for Z={} A={} T9={}", z, a, T9);
|
||||
m_supportCache.emplace(hash, *partitionFunction);
|
||||
return partitionFunction->evaluateDerivative(z, a, T9);
|
||||
}
|
||||
}
|
||||
@@ -64,9 +70,12 @@ namespace gridfire::partition {
|
||||
}
|
||||
|
||||
bool CompositePartitionFunction::supports(int z, int a) const {
|
||||
const uint_fast32_t hash = utils::hash_atomic(a, z);
|
||||
if (m_supportCache.contains(hash)) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& partitionFunction : m_partitionFunctions) {
|
||||
if (partitionFunction->supports(z, a)) {
|
||||
LOG_TRACE_L2(m_logger, "Partition function supports Z={} A={}", z, a);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "gridfire/reaction/weak/weak_interpolator.h"
|
||||
#include "gridfire/reaction/reaction.h"
|
||||
#include "gridfire/reaction/weak/weak.h"
|
||||
#include "gridfire/utils/hashing.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
@@ -19,9 +20,9 @@ namespace gridfire::rates::weak {
|
||||
|
||||
WeakRateInterpolator::WeakRateInterpolator(const RowDataTable &raw_data) {
|
||||
// Group all raw data rows by their isotope ID.
|
||||
std::map<uint32_t, std::vector<const RateDataRow*>> grouped_rows;
|
||||
std::unordered_map<uint32_t, std::vector<const RateDataRow*>> grouped_rows;
|
||||
for (const auto& row : raw_data) {
|
||||
grouped_rows[pack_isotope_id(row.A, row.Z)].push_back(&row);
|
||||
grouped_rows[utils::hash_atomic(row.A, row.Z)].push_back(&row);
|
||||
}
|
||||
|
||||
// Process each isotope's data to build a simple 2D grid.
|
||||
@@ -48,22 +49,8 @@ namespace gridfire::rates::weak {
|
||||
for (size_t i = 0; i < nt9; i++) { t9_map[grid.t9_axis[i]] = i; }
|
||||
for (size_t j = 0; j < nrhoYe; j++) { rhoYe_map[grid.rhoYe_axis[j]] = j; }
|
||||
|
||||
// Use a set to detect duplicate (T9, rhoYe) pairs, which would be a data error.
|
||||
std::set<std::pair<float, float>> seen_coords;
|
||||
|
||||
// Populate the 2D grid.
|
||||
for (const auto* row: rows) {
|
||||
if (auto [it, inserted] = seen_coords.insert({row->t9, row->log_rhoye}); !inserted) {
|
||||
auto A = static_cast<uint16_t>(isotope_id >> 8);
|
||||
auto Z = static_cast<uint8_t>(isotope_id & 0xFF);
|
||||
std::string msg = std::format(
|
||||
"Duplicate data point for isotope (A={}, Z={}) at (T9={}, log10(rho*Ye)={}) in weak rate table. This indicates corrupted or malformed input data and should be taken as an unrecoverable error.",
|
||||
A, Z, row->t9, row->log_rhoye
|
||||
);
|
||||
LOG_ERROR(m_logger, "{}", msg);
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
|
||||
size_t i_t9 = t9_map.at(row->t9);
|
||||
size_t j_rhoYe = rhoYe_map.at(row->log_rhoye);
|
||||
|
||||
@@ -107,7 +94,7 @@ namespace gridfire::rates::weak {
|
||||
const double t9,
|
||||
const double log_rhoYe
|
||||
) const {
|
||||
const auto it = m_rate_table.find(pack_isotope_id(A, Z));
|
||||
const auto it = m_rate_table.find(utils::hash_atomic(A, Z));
|
||||
if (it == m_rate_table.end()) {
|
||||
return std::unexpected(InterpolationError{InterpolationErrorType::UNKNOWN_SPECIES_ERROR});
|
||||
}
|
||||
@@ -222,9 +209,4 @@ namespace gridfire::rates::weak {
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
uint32_t WeakRateInterpolator::pack_isotope_id(const uint16_t A, const uint8_t Z) {
|
||||
return (static_cast<uint32_t>(A) << 8) | static_cast<uint32_t>(Z);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user