feat(solver): added callback functions to solver in C++ and python
This commit is contained in:
@@ -29,6 +29,21 @@
|
||||
|
||||
static std::terminate_handler g_previousHandler = nullptr;
|
||||
|
||||
static std::ofstream consumptionFile("consumption.txt");
|
||||
|
||||
void callback(const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) {
|
||||
const auto H1IndexPtr = std::ranges::find(ctx.engine.getNetworkSpecies(), fourdst::atomic::H_1);
|
||||
const auto He4IndexPtr = std::ranges::find(ctx.engine.getNetworkSpecies(), fourdst::atomic::He_4);
|
||||
|
||||
const size_t H1Index = H1IndexPtr != ctx.engine.getNetworkSpecies().end() ? std::distance(ctx.engine.getNetworkSpecies().begin(), H1IndexPtr) : -1;
|
||||
const size_t He4Index = He4IndexPtr != ctx.engine.getNetworkSpecies().end() ? std::distance(ctx.engine.getNetworkSpecies().begin(), He4IndexPtr) : -1;
|
||||
|
||||
if (H1Index != -1 && He4Index != -1) {
|
||||
std::cout << "Found H-1 at index: " << H1Index << ", He-4 at index: " << He4Index << "\n";
|
||||
consumptionFile << ctx.t << "," << ctx.state(H1Index) << "," << ctx.state(He4Index) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void measure_execution_time(const std::function<void()>& callback, const std::string& name)
|
||||
{
|
||||
const auto startTime = std::chrono::steady_clock::now();
|
||||
@@ -71,31 +86,34 @@ int main() {
|
||||
|
||||
NetIn netIn;
|
||||
netIn.composition = composition;
|
||||
netIn.temperature = 5e9;
|
||||
netIn.density = 1.6e6;
|
||||
netIn.temperature = 1.5e7;
|
||||
netIn.density = 1.6e2;
|
||||
netIn.energy = 0;
|
||||
// netIn.tMax = 3.1536e17; // ~ 10Gyr
|
||||
netIn.tMax = 1e-14;
|
||||
netIn.tMax = 5e17;
|
||||
// netIn.tMax = 1e-14;
|
||||
netIn.dt0 = 1e-12;
|
||||
|
||||
GraphEngine ReaclibEngine(composition, partitionFunction, NetworkBuildDepth::SecondOrder);
|
||||
ReaclibEngine.setUseReverseReactions(true);
|
||||
ReaclibEngine.setUseReverseReactions(false);
|
||||
// ReaclibEngine.setScreeningModel(screening::ScreeningType::WEAK);
|
||||
//
|
||||
MultiscalePartitioningEngineView partitioningView(ReaclibEngine);
|
||||
AdaptiveEngineView adaptiveView(partitioningView);
|
||||
//
|
||||
solver::DirectNetworkSolver solver(adaptiveView);
|
||||
consumptionFile << "t,X,a,b,c\n";
|
||||
solver.set_callback(callback);
|
||||
NetOut netOut;
|
||||
|
||||
|
||||
netOut = solver.evaluate(netIn);
|
||||
consumptionFile.close();
|
||||
std::cout << "Initial H-1: " << netIn.composition.getMassFraction("H-1") << std::endl;
|
||||
std::cout << "NetOut H-1: " << netOut.composition.getMassFraction("H-1") << std::endl;
|
||||
std::cout << "Consumed " << (netIn.composition.getMassFraction("H-1") - netOut.composition.getMassFraction("H-1")) * 100 << " % H-1 by mass" << std::endl;
|
||||
// measure_execution_time([&](){netOut = solver.evaluate(netIn);}, "DirectNetworkSolver Evaluation");
|
||||
// std::cout << "DirectNetworkSolver completed in " << netOut.num_steps << " steps.\n";
|
||||
// std::cout << "Final composition:\n";
|
||||
// for (const auto& [symbol, entry] : netOut.composition) {
|
||||
// std::cout << symbol << ": " << entry.mass_fraction() << "\n";
|
||||
// }
|
||||
|
||||
double initialHydrogen = netIn.composition.getMassFraction("H-1");
|
||||
double finalHydrogen = netOut.composition.getMassFraction("H-1");
|
||||
double fractionalConsumedHydrogen = (initialHydrogen - finalHydrogen) / initialHydrogen * 100.0;
|
||||
std::cout << "Fractional consumed hydrogen: " << fractionalConsumedHydrogen << "%" << std::endl;
|
||||
|
||||
}
|
||||
@@ -3,6 +3,7 @@ from gridfire.solver import DirectNetworkSolver
|
||||
from gridfire.type import NetIn
|
||||
|
||||
from fourdst.composition import Composition
|
||||
from fourdst.atomic import species
|
||||
|
||||
symbols : list[str] = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
|
||||
X : list[float] = [0.708, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4]
|
||||
@@ -19,7 +20,7 @@ netIn = NetIn()
|
||||
netIn.composition = comp
|
||||
netIn.temperature = 1.5e7
|
||||
netIn.density = 1.6e2
|
||||
netIn.tMax = 1e-9
|
||||
netIn.tMax = 4e17
|
||||
netIn.dt0 = 1e-12
|
||||
|
||||
baseEngine = GraphEngine(netIn.composition, 2)
|
||||
@@ -31,6 +32,15 @@ adaptiveEngine = AdaptiveEngineView(qseEngine)
|
||||
|
||||
solver = DirectNetworkSolver(adaptiveEngine)
|
||||
|
||||
|
||||
def callback(context):
|
||||
H1Index = context.engine.getSpeciesIndex(species["H-1"])
|
||||
He4Index = context.engine.getSpeciesIndex(species["He-4"])
|
||||
C12ndex = context.engine.getSpeciesIndex(species["C-12"])
|
||||
Mgh24ndex = context.engine.getSpeciesIndex(species["Mg-24"])
|
||||
print(f"Time: {context.t}, H-1: {context.state[H1Index]}, He-4: {context.state[He4Index]}, C-12: {context.state[C12ndex]}, Mg-24: {context.state[Mgh24ndex]}")
|
||||
|
||||
# solver.set_callback(callback)
|
||||
results = solver.evaluate(netIn)
|
||||
|
||||
print(f"Final H-1 mass fraction {results.composition.getMassFraction("H-1")}")
|
||||
|
||||
Reference in New Issue
Block a user