feat(python): Python Bindings

Python Bindings are working again
This commit is contained in:
2025-12-20 16:02:52 -05:00
parent d65c237b26
commit 11a596b75b
78 changed files with 4411 additions and 1110 deletions

View File

@@ -7,125 +7,226 @@
#include "bindings.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "trampoline/py_solver.h"
namespace py = pybind11;
void register_solver_bindings(const py::module &m) {
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
auto py_cvode_timestep_context = py::class_<gridfire::solver::PointSolverTimestepContext>(m, "PointSolverTimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::PointSolverTimestepContext::t);
py_cvode_timestep_context.def_property_readonly(
"state",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<double> {
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
const sunindextype length = N_VGetLength(self.state);
return std::vector<double>(nvec_data, nvec_data + length);
return {nvec_data, nvec_data + length};
}
);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::PointSolverTimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::PointSolverTimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::PointSolverTimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::PointSolverTimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::PointSolverTimestepContext::num_steps);
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::PointSolverTimestepContext::currentConvergenceFailures);
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::PointSolverTimestepContext::currentNonlinearIterations);
py_cvode_timestep_context.def_property_readonly(
"engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::engine::DynamicEngine& {
[](const gridfire::solver::PointSolverTimestepContext& self) -> const gridfire::engine::DynamicEngine& {
return self.engine;
}
);
py_cvode_timestep_context.def_property_readonly(
"networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies;
}
);
py_cvode_timestep_context.def_property_readonly(
"state_ctx",
[](const gridfire::solver::PointSolverTimestepContext& self) {
return &(self.state_ctx);
},
py::return_value_policy::reference_internal
);
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
py_dynamic_network_solver_strategy.def(
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
auto py_point_solver_context = py::class_<gridfire::solver::PointSolverContext, gridfire::solver::SolverContextBase>(m, "PointSolverContext");
py_point_solver_context
.def_readonly(
"sun_ctx", &gridfire::solver::PointSolverContext::sun_ctx
)
.def_readonly(
"cvode_mem", &gridfire::solver::PointSolverContext::cvode_mem
)
.def_readonly(
"Y", &gridfire::solver::PointSolverContext::Y
)
.def_readonly(
"YErr", &gridfire::solver::PointSolverContext::YErr
)
.def_readonly(
"J", &gridfire::solver::PointSolverContext::J
)
.def_readonly(
"LS", &gridfire::solver::PointSolverContext::LS
)
.def_property_readonly(
"engine_ctx",
[](const gridfire::solver::PointSolverContext& self) -> gridfire::engine::scratch::StateBlob& {
return *(self.engine_ctx);
},
py::return_value_policy::reference
)
.def_readonly(
"num_steps", &gridfire::solver::PointSolverContext::num_steps
)
.def_property(
"abs_tol",
[](const gridfire::solver::PointSolverContext& self) -> double {
return self.abs_tol.value();
},
[](gridfire::solver::PointSolverContext& self, double abs_tol) -> void {
self.abs_tol = abs_tol;
}
)
.def_property(
"rel_tol",
[](const gridfire::solver::PointSolverContext& self) -> double {
return self.rel_tol.value();
},
[](gridfire::solver::PointSolverContext& self, double rel_tol) -> void {
self.rel_tol = rel_tol;
}
)
.def_property(
"stdout_logging",
[](const gridfire::solver::PointSolverContext& self) -> bool {
return self.stdout_logging;
},
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
self.stdout_logging = enable;
}
)
.def_property(
"detailed_logging",
[](const gridfire::solver::PointSolverContext& self) -> bool {
return self.detailed_step_logging;
},
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
self.detailed_step_logging = enable;
}
)
.def_property(
"callback",
[](const gridfire::solver::PointSolverContext& self) -> std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>> {
return self.callback;
},
[](gridfire::solver::PointSolverContext& self, const std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>>& cb) {
self.callback = cb;
}
)
.def("reset_all", &gridfire::solver::PointSolverContext::reset_all)
.def("reset_user", &gridfire::solver::PointSolverContext::reset_user)
.def("reset_cvode", &gridfire::solver::PointSolverContext::reset_cvode)
.def("clear_context", &gridfire::solver::PointSolverContext::clear_context)
.def("init_context", &gridfire::solver::PointSolverContext::init_context)
.def("has_context", &gridfire::solver::PointSolverContext::has_context)
.def("init", &gridfire::solver::PointSolverContext::init)
.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("engine_ctx"));
auto py_single_zone_dynamic_network_solver = py::class_<gridfire::solver::SingleZoneDynamicNetworkSolver, PySingleZoneDynamicNetworkSolver>(m, "SingleZoneDynamicNetworkSolver");
py_single_zone_dynamic_network_solver.def(
"evaluate",
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate,
&gridfire::solver::SingleZoneDynamicNetworkSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIn"),
"evaluate the dynamic engine using the dynamic engine class"
"evaluate the dynamic engine using the dynamic engine class for a single zone"
);
auto py_multi_zone_dynamic_network_solver = py::class_<gridfire::solver::MultiZoneDynamicNetworkSolver, PyMultiZoneDynamicNetworkSolver>(m, "MultiZoneDynamicNetworkSolver");
py_multi_zone_dynamic_network_solver.def(
"evaluate",
&gridfire::solver::MultiZoneDynamicNetworkSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIns"),
"evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)"
);
auto py_point_solver = py::class_<gridfire::solver::PointSolver, gridfire::solver::SingleZoneDynamicNetworkSolver>(m, "PointSolver");
py_dynamic_network_solver_strategy.def(
"describe_callback_context",
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
"Get a structure representing what data is in the callback context in a human readable format"
);
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
py_cvode_solver_strategy.def(
py_point_solver.def(
py::init<gridfire::engine::DynamicEngine&>(),
py::arg("engine"),
"Initialize the CVODESolverStrategy object."
"Initialize the PointSolver object."
);
py_cvode_solver_strategy.def(
py_point_solver.def(
"evaluate",
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate),
py::overload_cast<gridfire::solver::SolverContextBase&, const gridfire::NetIn&, bool, bool>(&gridfire::solver::PointSolver::evaluate, py::const_),
py::arg("solver_ctx"),
py::arg("netIn"),
py::arg("display_trigger") = false,
py::arg("force_reinitialization") = false,
"evaluate the dynamic engine using the dynamic engine class"
);
py_cvode_solver_strategy.def(
"get_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled,
"Check if solver logging to standard output is enabled."
auto py_grid_solver_context = py::class_<gridfire::solver::GridSolverContext, gridfire::solver::SolverContextBase>(m, "GridSolverContext");
py_grid_solver_context.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("ctx_template"));
py_grid_solver_context.def("init", &gridfire::solver::GridSolverContext::init);
py_grid_solver_context.def("reset", &gridfire::solver::GridSolverContext::reset);
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"));
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&, size_t>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"), py::arg("zone_idx"));
py_grid_solver_context.def("clear_callback", py::overload_cast<>(&gridfire::solver::GridSolverContext::clear_callback));
py_grid_solver_context.def("clear_callback", py::overload_cast<size_t>(&gridfire::solver::GridSolverContext::clear_callback), py::arg("zone_idx"));
py_grid_solver_context.def_property(
"stdout_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_stdout_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_stdout_logging = enable;
}
)
.def_property(
"detailed_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_detailed_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_detailed_logging = enable;
}
)
.def_property(
"zone_completion_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_completion_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_completion_logging = enable;
}
);
py_cvode_solver_strategy.def(
"set_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled,
py::arg("logging_enabled"),
"Enable logging to standard output."
auto py_grid_solver = py::class_<gridfire::solver::GridSolver, gridfire::solver::MultiZoneDynamicNetworkSolver>(m, "GridSolver");
py_grid_solver.def(
py::init<const gridfire::engine::DynamicEngine&, const gridfire::solver::SingleZoneDynamicNetworkSolver&>(),
py::arg("engine"),
py::arg("solver"),
"Initialize the GridSolver object."
);
py_cvode_solver_strategy.def(
"set_absTol",
&gridfire::solver::CVODESolverStrategy::set_absTol,
py::arg("absTol"),
"Set the absolute tolerance for the CVODE solver."
py_grid_solver.def(
"evaluate",
&gridfire::solver::GridSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIns"),
"evaluate the dynamic engine using the dynamic engine class"
);
py_cvode_solver_strategy.def(
"set_relTol",
&gridfire::solver::CVODESolverStrategy::set_relTol,
py::arg("relTol"),
"Set the relative tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"get_absTol",
&gridfire::solver::CVODESolverStrategy::get_absTol,
"Get the absolute tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"get_relTol",
&gridfire::solver::CVODESolverStrategy::get_relTol,
"Get the relative tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"set_callback",
[](
gridfire::solver::CVODESolverStrategy& self,
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
) {
self.set_callback(std::any(cb));
},
py::arg("cb"),
"Set a callback function which will run at the end of every successful timestep"
);
}