#include
#include // Needed for vectors, maps, sets, strings
#include // Needed for binding std::vector, std::map etc. if needed directly
#include
#include
#include

#include "bindings.h"

#include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/engine/scratchpads/blob.h"

#include "trampoline/py_solver.h"

namespace py = pybind11;

// Registers the solver-related Python bindings on module `m`. It registers:
//   * the per-step context type PointSolverTimestepContext;
//   * the context types SolverContextBase, PointSolverContext and GridSolverContext;
//   * the solver types SingleZoneDynamicNetworkSolver, MultiZoneDynamicNetworkSolver,
//     PointSolver and GridSolver.
// Each Python class and attribute name below is chosen explicitly; several differ from
// the underlying C++ member names (noted inline).
void register_solver_bindings(const py::module &m) {
    // ---- PointSolverTimestepContext: read-only view of the solver state at one step ----
    auto py_cvode_timestep_context = py::class_(m, "PointSolverTimestepContext");
    py_cvode_timestep_context.def_readonly("t", &gridfire::solver::PointSolverTimestepContext::t);
    // Copies the contents of the SUNDIALS N_Vector `state` into a fresh std::vector, so
    // Python receives a snapshot rather than a live view of solver memory.
    // NOTE(review): the pointer from N_VGetArrayPointer is used without a null check;
    // confirm `state` is always a vector type that exposes a contiguous host array.
    py_cvode_timestep_context.def_property_readonly(
        "state",
        [](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector {
            const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
            const sunindextype length = N_VGetLength(self.state);
            return {nvec_data, nvec_data + length};
        }
    );
    py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::PointSolverTimestepContext::dt);
    py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::PointSolverTimestepContext::last_step_time);
    py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::PointSolverTimestepContext::T9);
    py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::PointSolverTimestepContext::rho);
    py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::PointSolverTimestepContext::num_steps);
    py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::PointSolverTimestepContext::currentConvergenceFailures);
    py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::PointSolverTimestepContext::currentNonlinearIterations);
    // Returns the engine by const reference. No explicit return_value_policy is passed,
    // so pybind11's default policy for properties applies.
    py_cvode_timestep_context.def_property_readonly(
        "engine",
        [](const gridfire::solver::PointSolverTimestepContext& self) -> const gridfire::engine::DynamicEngine& {
            return self.engine;
        }
    );
    // Returned by value: Python gets a copy of the species list.
    py_cvode_timestep_context.def_property_readonly(
        "networkSpecies",
        [](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector {
            return self.networkSpecies;
        }
    );
    // Returns a pointer into the timestep context. reference_internal keeps the owning
    // PointSolverTimestepContext Python object alive while the result is referenced.
    py_cvode_timestep_context.def_property_readonly(
        "state_ctx",
        [](const gridfire::solver::PointSolverTimestepContext& self) {
            return &(self.state_ctx);
        },
        py::return_value_policy::reference_internal
    );

    // ---- Solver contexts ----
    // Registered here with no members of its own.
    auto py_solver_context_base = py::class_(m, "SolverContextBase");

    auto py_point_solver_context = py::class_(m, "PointSolverContext");

    py_point_solver_context
        // Raw solver handles/workspaces, exposed read-only.
        .def_readonly(
            "sun_ctx",
            &gridfire::solver::PointSolverContext::sun_ctx
        )
        .def_readonly(
            "cvode_mem",
            &gridfire::solver::PointSolverContext::cvode_mem
        )
        .def_readonly(
            "Y",
            &gridfire::solver::PointSolverContext::Y
        )
        .def_readonly(
            "YErr",
            &gridfire::solver::PointSolverContext::YErr
        )
        .def_readonly(
            "J",
            &gridfire::solver::PointSolverContext::J
        )
        .def_readonly(
            "LS",
            &gridfire::solver::PointSolverContext::LS
        )
        // Dereferences `engine_ctx` and hands Python a non-owning reference.
        // NOTE(review): the dereference is not null-checked. Also, `reference` (unlike the
        // `reference_internal` used for state_ctx) does not keep this context alive.
        // Confirm that the StateBlob outlives any Python handle to it.
        .def_property_readonly(
            "engine_ctx",
            [](const gridfire::solver::PointSolverContext& self) -> gridfire::engine::scratch::StateBlob& {
                return *(self.engine_ctx);
            },
            py::return_value_policy::reference
        )
        .def_readonly(
            "num_steps",
            &gridfire::solver::PointSolverContext::num_steps
        )
        // Tolerances: the getter calls `.value()`, so the member is presumably a
        // std::optional<double>. Reading it while unset throws std::bad_optional_access,
        // which reaches Python as an exception rather than None — TODO confirm intended.
        .def_property(
            "abs_tol",
            [](const gridfire::solver::PointSolverContext& self) -> double {
                return self.abs_tol.value();
            },
            [](gridfire::solver::PointSolverContext& self, double abs_tol) -> void {
                self.abs_tol = abs_tol;
            }
        )
        .def_property(
            "rel_tol",
            [](const gridfire::solver::PointSolverContext& self) -> double {
                return self.rel_tol.value();
            },
            [](gridfire::solver::PointSolverContext& self, double rel_tol) -> void {
                self.rel_tol = rel_tol;
            }
        )
        .def_property(
            "stdout_logging",
            [](const gridfire::solver::PointSolverContext& self) -> bool {
                return self.stdout_logging;
            },
            [](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
                self.stdout_logging = enable;
            }
        )
        // Python name "detailed_logging" maps to the C++ field `detailed_step_logging`.
        .def_property(
            "detailed_logging",
            [](const gridfire::solver::PointSolverContext& self) -> bool {
                return self.detailed_step_logging;
            },
            [](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
                self.detailed_step_logging = enable;
            }
        )
        // Optional callback: both getter and setter copy the whole optional wrapper.
        .def_property(
            "callback",
            [](const gridfire::solver::PointSolverContext& self) -> std::optional> {
                return self.callback;
            },
            [](gridfire::solver::PointSolverContext& self, const std::optional>& cb) {
                self.callback = cb;
            }
        )
        // Lifecycle methods, forwarded directly to the C++ members.
        .def("reset_all", &gridfire::solver::PointSolverContext::reset_all)
        .def("reset_user", &gridfire::solver::PointSolverContext::reset_user)
        .def("reset_cvode", &gridfire::solver::PointSolverContext::reset_cvode)
        .def("clear_context", &gridfire::solver::PointSolverContext::clear_context)
        .def("init_context", &gridfire::solver::PointSolverContext::init_context)
        .def("has_context", &gridfire::solver::PointSolverContext::has_context)
        .def("init", &gridfire::solver::PointSolverContext::init)
        .def(py::init(), py::arg("engine_ctx"));

    // ---- Abstract-style solver interfaces: each exposes only `evaluate` ----
    auto py_single_zone_dynamic_network_solver = py::class_(m, "SingleZoneDynamicNetworkSolver");
    py_single_zone_dynamic_network_solver.def(
        "evaluate",
        &gridfire::solver::SingleZoneDynamicNetworkSolver::evaluate,
        py::arg("solver_ctx"),
        py::arg("netIn"),
        "evaluate the dynamic engine using the dynamic engine class for a single zone"
    );

    auto py_multi_zone_dynamic_network_solver = py::class_(m, "MultiZoneDynamicNetworkSolver");
    py_multi_zone_dynamic_network_solver.def(
        "evaluate",
        &gridfire::solver::MultiZoneDynamicNetworkSolver::evaluate,
        py::arg("solver_ctx"),
        py::arg("netIns"),
        "evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)"
    );

    // ---- PointSolver: concrete single-zone solver ----
    auto py_point_solver = py::class_(m, "PointSolver");

    py_point_solver.def(
        py::init(),
        py::arg("engine"),
        "Initialize the PointSolver object."
    );

    // overload_cast with py::const_ selects the const-qualified `evaluate` overload.
    // display_trigger and force_reinitialization default to false on the Python side.
    py_point_solver.def(
        "evaluate",
        py::overload_cast(&gridfire::solver::PointSolver::evaluate, py::const_),
        py::arg("solver_ctx"),
        py::arg("netIn"),
        py::arg("display_trigger") = false,
        py::arg("force_reinitialization") = false,
        "evaluate the dynamic engine using the dynamic engine class"
    );

    // ---- GridSolverContext: per-zone context container ----
    auto py_grid_solver_context = py::class_(m, "GridSolverContext");
    py_grid_solver_context.def(py::init(), py::arg("ctx_template"));
    py_grid_solver_context.def("init", &gridfire::solver::GridSolverContext::init);
    py_grid_solver_context.def("reset", &gridfire::solver::GridSolverContext::reset);
    // Two set_callback overloads: one without a zone index (presumably applies to every
    // zone — TODO confirm) and one that targets a single `zone_idx`.
    py_grid_solver_context.def("set_callback", py::overload_cast&>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"));
    py_grid_solver_context.def("set_callback", py::overload_cast&, size_t>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"), py::arg("zone_idx"));
    // Same pairing for clear_callback: no-argument form and per-zone form.
    py_grid_solver_context.def("clear_callback", py::overload_cast<>(&gridfire::solver::GridSolverContext::clear_callback));
    py_grid_solver_context.def("clear_callback", py::overload_cast(&gridfire::solver::GridSolverContext::clear_callback), py::arg("zone_idx"));
    // Logging flags: Python names "stdout_logging"/"detailed_logging" map to the C++
    // fields zone_stdout_logging/zone_detailed_logging.
    py_grid_solver_context.def_property(
        "stdout_logging",
        [](const gridfire::solver::GridSolverContext& self) -> bool {
            return self.zone_stdout_logging;
        },
        [](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
            self.zone_stdout_logging = enable;
        }
    )
    .def_property(
        "detailed_logging",
        [](const gridfire::solver::GridSolverContext& self) -> bool {
            return self.zone_detailed_logging;
        },
        [](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
            self.zone_detailed_logging = enable;
        }
    )
    .def_property(
        "zone_completion_logging",
        [](const gridfire::solver::GridSolverContext& self) -> bool {
            return self.zone_completion_logging;
        },
        [](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
            self.zone_completion_logging = enable;
        }
    );

    // ---- GridSolver: constructed from an engine plus a per-zone solver ----
    auto py_grid_solver = py::class_(m, "GridSolver");
    py_grid_solver.def(
        py::init(),
        py::arg("engine"),
        py::arg("solver"),
        "Initialize the GridSolver object."
    );

    py_grid_solver.def(
        "evaluate",
        &gridfire::solver::GridSolver::evaluate,
        py::arg("solver_ctx"),
        py::arg("netIns"),
        "evaluate the dynamic engine using the dynamic engine class"
    );
}