feat(python): Python Bindings
Python Bindings are working again
This commit is contained in:
@@ -7,125 +7,226 @@
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
#include "gridfire/solver/strategies/PointSolver.h"
|
||||
#include "gridfire/engine/scratchpads/blob.h"
|
||||
#include "trampoline/py_solver.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_solver_bindings(const py::module &m) {
|
||||
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
|
||||
|
||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
|
||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
|
||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::PointSolverTimestepContext>(m, "PointSolverTimestepContext");
|
||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::PointSolverTimestepContext::t);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"state",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<double> {
|
||||
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
|
||||
const sunindextype length = N_VGetLength(self.state);
|
||||
return std::vector<double>(nvec_data, nvec_data + length);
|
||||
return {nvec_data, nvec_data + length};
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
|
||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
|
||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
|
||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
|
||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
|
||||
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
|
||||
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
|
||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::PointSolverTimestepContext::dt);
|
||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::PointSolverTimestepContext::last_step_time);
|
||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::PointSolverTimestepContext::T9);
|
||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::PointSolverTimestepContext::rho);
|
||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::PointSolverTimestepContext::num_steps);
|
||||
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::PointSolverTimestepContext::currentConvergenceFailures);
|
||||
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::PointSolverTimestepContext::currentNonlinearIterations);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"engine",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::engine::DynamicEngine& {
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) -> const gridfire::engine::DynamicEngine& {
|
||||
return self.engine;
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"networkSpecies",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||
return self.networkSpecies;
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"state_ctx",
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) {
|
||||
return &(self.state_ctx);
|
||||
},
|
||||
py::return_value_policy::reference_internal
|
||||
);
|
||||
|
||||
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
|
||||
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
|
||||
auto py_point_solver_context = py::class_<gridfire::solver::PointSolverContext, gridfire::solver::SolverContextBase>(m, "PointSolverContext");
|
||||
|
||||
py_point_solver_context
|
||||
.def_readonly(
|
||||
"sun_ctx", &gridfire::solver::PointSolverContext::sun_ctx
|
||||
)
|
||||
.def_readonly(
|
||||
"cvode_mem", &gridfire::solver::PointSolverContext::cvode_mem
|
||||
)
|
||||
.def_readonly(
|
||||
"Y", &gridfire::solver::PointSolverContext::Y
|
||||
)
|
||||
.def_readonly(
|
||||
"YErr", &gridfire::solver::PointSolverContext::YErr
|
||||
)
|
||||
.def_readonly(
|
||||
"J", &gridfire::solver::PointSolverContext::J
|
||||
)
|
||||
.def_readonly(
|
||||
"LS", &gridfire::solver::PointSolverContext::LS
|
||||
)
|
||||
.def_property_readonly(
|
||||
"engine_ctx",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> gridfire::engine::scratch::StateBlob& {
|
||||
return *(self.engine_ctx);
|
||||
},
|
||||
py::return_value_policy::reference
|
||||
)
|
||||
.def_readonly(
|
||||
"num_steps", &gridfire::solver::PointSolverContext::num_steps
|
||||
)
|
||||
.def_property(
|
||||
"abs_tol",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> double {
|
||||
return self.abs_tol.value();
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, double abs_tol) -> void {
|
||||
self.abs_tol = abs_tol;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"rel_tol",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> double {
|
||||
return self.rel_tol.value();
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, double rel_tol) -> void {
|
||||
self.rel_tol = rel_tol;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"stdout_logging",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> bool {
|
||||
return self.stdout_logging;
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
|
||||
self.stdout_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"detailed_logging",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> bool {
|
||||
return self.detailed_step_logging;
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
|
||||
self.detailed_step_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"callback",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>> {
|
||||
return self.callback;
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, const std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>>& cb) {
|
||||
self.callback = cb;
|
||||
}
|
||||
)
|
||||
.def("reset_all", &gridfire::solver::PointSolverContext::reset_all)
|
||||
.def("reset_user", &gridfire::solver::PointSolverContext::reset_user)
|
||||
.def("reset_cvode", &gridfire::solver::PointSolverContext::reset_cvode)
|
||||
.def("clear_context", &gridfire::solver::PointSolverContext::clear_context)
|
||||
.def("init_context", &gridfire::solver::PointSolverContext::init_context)
|
||||
.def("has_context", &gridfire::solver::PointSolverContext::has_context)
|
||||
.def("init", &gridfire::solver::PointSolverContext::init)
|
||||
.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("engine_ctx"));
|
||||
|
||||
|
||||
|
||||
auto py_single_zone_dynamic_network_solver = py::class_<gridfire::solver::SingleZoneDynamicNetworkSolver, PySingleZoneDynamicNetworkSolver>(m, "SingleZoneDynamicNetworkSolver");
|
||||
py_single_zone_dynamic_network_solver.def(
|
||||
"evaluate",
|
||||
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate,
|
||||
&gridfire::solver::SingleZoneDynamicNetworkSolver::evaluate,
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIn"),
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
"evaluate the dynamic engine using the dynamic engine class for a single zone"
|
||||
);
|
||||
auto py_multi_zone_dynamic_network_solver = py::class_<gridfire::solver::MultiZoneDynamicNetworkSolver, PyMultiZoneDynamicNetworkSolver>(m, "MultiZoneDynamicNetworkSolver");
|
||||
py_multi_zone_dynamic_network_solver.def(
|
||||
"evaluate",
|
||||
&gridfire::solver::MultiZoneDynamicNetworkSolver::evaluate,
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIns"),
|
||||
"evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)"
|
||||
);
|
||||
|
||||
auto py_point_solver = py::class_<gridfire::solver::PointSolver, gridfire::solver::SingleZoneDynamicNetworkSolver>(m, "PointSolver");
|
||||
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
"describe_callback_context",
|
||||
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
|
||||
"Get a structure representing what data is in the callback context in a human readable format"
|
||||
);
|
||||
|
||||
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
py_point_solver.def(
|
||||
py::init<gridfire::engine::DynamicEngine&>(),
|
||||
py::arg("engine"),
|
||||
"Initialize the CVODESolverStrategy object."
|
||||
"Initialize the PointSolver object."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
py_point_solver.def(
|
||||
"evaluate",
|
||||
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate),
|
||||
py::overload_cast<gridfire::solver::SolverContextBase&, const gridfire::NetIn&, bool, bool>(&gridfire::solver::PointSolver::evaluate, py::const_),
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIn"),
|
||||
py::arg("display_trigger") = false,
|
||||
py::arg("force_reinitialization") = false,
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"get_stdout_logging_enabled",
|
||||
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled,
|
||||
"Check if solver logging to standard output is enabled."
|
||||
auto py_grid_solver_context = py::class_<gridfire::solver::GridSolverContext, gridfire::solver::SolverContextBase>(m, "GridSolverContext");
|
||||
py_grid_solver_context.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("ctx_template"));
|
||||
py_grid_solver_context.def("init", &gridfire::solver::GridSolverContext::init);
|
||||
py_grid_solver_context.def("reset", &gridfire::solver::GridSolverContext::reset);
|
||||
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"));
|
||||
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&, size_t>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"), py::arg("zone_idx"));
|
||||
py_grid_solver_context.def("clear_callback", py::overload_cast<>(&gridfire::solver::GridSolverContext::clear_callback));
|
||||
py_grid_solver_context.def("clear_callback", py::overload_cast<size_t>(&gridfire::solver::GridSolverContext::clear_callback), py::arg("zone_idx"));
|
||||
py_grid_solver_context.def_property(
|
||||
"stdout_logging",
|
||||
[](const gridfire::solver::GridSolverContext& self) -> bool {
|
||||
return self.zone_stdout_logging;
|
||||
},
|
||||
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
|
||||
self.zone_stdout_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"detailed_logging",
|
||||
[](const gridfire::solver::GridSolverContext& self) -> bool {
|
||||
return self.zone_detailed_logging;
|
||||
},
|
||||
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
|
||||
self.zone_detailed_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"zone_completion_logging",
|
||||
[](const gridfire::solver::GridSolverContext& self) -> bool {
|
||||
return self.zone_completion_logging;
|
||||
},
|
||||
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
|
||||
self.zone_completion_logging = enable;
|
||||
}
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_stdout_logging_enabled",
|
||||
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled,
|
||||
py::arg("logging_enabled"),
|
||||
"Enable logging to standard output."
|
||||
auto py_grid_solver = py::class_<gridfire::solver::GridSolver, gridfire::solver::MultiZoneDynamicNetworkSolver>(m, "GridSolver");
|
||||
py_grid_solver.def(
|
||||
py::init<const gridfire::engine::DynamicEngine&, const gridfire::solver::SingleZoneDynamicNetworkSolver&>(),
|
||||
py::arg("engine"),
|
||||
py::arg("solver"),
|
||||
"Initialize the GridSolver object."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_absTol",
|
||||
&gridfire::solver::CVODESolverStrategy::set_absTol,
|
||||
py::arg("absTol"),
|
||||
"Set the absolute tolerance for the CVODE solver."
|
||||
py_grid_solver.def(
|
||||
"evaluate",
|
||||
&gridfire::solver::GridSolver::evaluate,
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIns"),
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_relTol",
|
||||
&gridfire::solver::CVODESolverStrategy::set_relTol,
|
||||
py::arg("relTol"),
|
||||
"Set the relative tolerance for the CVODE solver."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"get_absTol",
|
||||
&gridfire::solver::CVODESolverStrategy::get_absTol,
|
||||
"Get the absolute tolerance for the CVODE solver."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"get_relTol",
|
||||
&gridfire::solver::CVODESolverStrategy::get_relTol,
|
||||
"Get the relative tolerance for the CVODE solver."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_callback",
|
||||
[](
|
||||
gridfire::solver::CVODESolverStrategy& self,
|
||||
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
|
||||
) {
|
||||
self.set_callback(std::any(cb));
|
||||
},
|
||||
py::arg("cb"),
|
||||
"Set a callback function which will run at the end of every successful timestep"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user