fix(python-bindings): Updated python bindings to new interface

The python bindings now work with the polymorphic reaction class and the CVODE solver
This commit is contained in:
2025-10-30 15:05:08 -04:00
parent 23df87f915
commit 7fded59814
27 changed files with 962 additions and 255 deletions

View File

@@ -2,71 +2,97 @@
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <pybind11/numpy.h>
#include <functional>
#include <boost/numeric/ublas/vector.hpp>
#include "bindings.h"
#include "gridfire/solver/solver.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "trampoline/py_solver.h"
namespace py = pybind11;
void register_solver_bindings(const py::module &m) {
auto py_dynamic_network_solving_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
auto py_direct_network_solver = py::class_<gridfire::solver::DirectNetworkSolver, gridfire::solver::DynamicNetworkSolverStrategy>(m, "DirectNetworkSolver");
py_direct_network_solver.def(py::init<gridfire::DynamicEngine&>(),
py::arg("engine"),
"Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network."
);
py_direct_network_solver.def("evaluate",
&gridfire::solver::DirectNetworkSolver::evaluate,
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
py_dynamic_network_solver_strategy.def(
"evaluate",
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate,
py::arg("netIn"),
"Evaluate the network for a given timestep. Returns the output conditions after the timestep."
"evaluate the dynamic engine using the dynamic engine class"
);
py_direct_network_solver.def("set_callback",
[](gridfire::solver::DirectNetworkSolver &self, const gridfire::solver::DirectNetworkSolver::TimestepCallback& cb) {
py_dynamic_network_solver_strategy.def(
"set_callback",
[](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function<void(const gridfire::solver::SolverContextBase&)> cb) {
self.set_callback(cb);
},
py::arg("callback"),
"Sets a callback function to be called at each timestep."
"Set a callback function which will run at the end of every successful timestep"
);
py::class_<gridfire::solver::DirectNetworkSolver::TimestepContext>(py_direct_network_solver, "TimestepContext")
.def_readonly("t", &gridfire::solver::DirectNetworkSolver::TimestepContext::t, "Current time in the simulation.")
.def_property_readonly(
"state", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) {
std::vector<double> state(ctx.state.size());
std::ranges::copy(ctx.state, state.begin());
return py::array_t<double>(static_cast<ssize_t>(state.size()), state.data());
})
.def_readonly("dt", &gridfire::solver::DirectNetworkSolver::TimestepContext::dt, "Current timestep size.")
.def_readonly("cached_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::cached_time, "Cached time for the last computed result.")
.def_readonly("last_observed_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_observed_time, "Last time the state was observed.")
.def_readonly("last_step_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_step_time, "Last step time taken for the integration.")
.def_readonly("T9", &gridfire::solver::DirectNetworkSolver::TimestepContext::T9, "Temperature in units of 10^9 K.")
.def_readonly("rho", &gridfire::solver::DirectNetworkSolver::TimestepContext::rho, "Temperature in units of 10^9 K.")
.def_property_readonly("cached_result", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) -> py::object {
if (ctx.cached_result.has_value()) {
const auto&[dydt, nuclearEnergyGenerationRate] = ctx.cached_result.value();
return py::make_tuple(
py::array_t<double>(static_cast<ssize_t>(dydt.size()), dydt.data()),
nuclearEnergyGenerationRate
);
}
return py::none();
}, "Cached result of the step derivatives.")
.def_readonly("num_steps", &gridfire::solver::DirectNetworkSolver::TimestepContext::num_steps, "Total number of steps taken in the simulation.")
.def_property_readonly("engine", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const gridfire::DynamicEngine & {
return ctx.engine;
}, py::return_value_policy::reference)
py_dynamic_network_solver_strategy.def(
"describe_callback_context",
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
"Get a structure representing what data is in the callback context in a human readable format"
);
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
py_cvode_solver_strategy.def(
py::init<gridfire::DynamicEngine&>(),
py::arg("engine"),
"Initialize the CVODESolverStrategy object."
);
py_cvode_solver_strategy.def(
"evaluate",
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate),
py::arg("netIn"),
py::arg("display_trigger"),
"evaluate the dynamic engine using the dynamic engine class"
);
py_cvode_solver_strategy.def(
"get_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled,
"Check if solver logging to standard output is enabled."
);
py_cvode_solver_strategy.def(
"set_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled,
py::arg("logging_enabled"),
"Enable logging to standard output."
);
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext>(m, "CVODETimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
py_cvode_timestep_context.def_property_readonly(
"state",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
const sunindextype length = N_VGetLength(self.state);
return std::vector<double>(nvec_data, nvec_data + length);
}
);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
py_cvode_timestep_context.def_property_readonly(
"engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
return self.engine;
}
);
py_cvode_timestep_context.def_property_readonly(
"networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies;
}
);
.def_property_readonly("network_species", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const std::vector<fourdst::atomic::Species> & {
return ctx.networkSpecies;
}, py::return_value_policy::reference);
}