Files
SERiF/src/python/eos/bindings.cpp

160 lines
8.4 KiB
C++
Raw Normal View History

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/numpy.h>
#include <string>
#include "helm.h"
#include "bindings.h"
#include "EOSio.h"
#include "helm.h"
// serif::eos::EOSio is fully defined by "EOSio.h" (included above), so no
// forward declaration is needed here.

// Conventional short alias for the pybind11 namespace used throughout this file.
namespace py = pybind11;
/**
 * @brief Register the equation-of-state (EOS) Python bindings on a submodule.
 *
 * Adds the following to @p eos_submodule:
 *  - `EOSio`        : serif::eos::EOSio, constructed from a filename.
 *  - `EOSTable`     : serif::eos::EOSTable, registered with no members.
 *  - `HELMTable`    : read-only scalar fields plus a NumPy view (`f`) of the table data.
 *  - `EOS`          : serif::eos::helmholtz::HELMEOSOutput, all fields read-only.
 *  - `HELMEOSInput` : serif::eos::helmholtz::HELMEOSInput, fields read/write.
 *  - `get_helm_eos` : wraps serif::eos::helmholtz::get_helm_EOS(q, table).
 *
 * @param eos_submodule Module that receives the classes and functions.
 */
void register_eos_bindings(pybind11::module &eos_submodule) {
    // Local aliases keep the (long) fully-qualified names out of every line below.
    using EOSio = serif::eos::EOSio;
    using HELMTable = serif::eos::helmholtz::HELMTable;
    using HELMEOSOutput = serif::eos::helmholtz::HELMEOSOutput;
    using HELMEOSInput = serif::eos::helmholtz::HELMEOSInput;

    // ---- EOSio -------------------------------------------------------------
    py::class_<EOSio>(eos_submodule, "EOSio")
        .def(py::init<std::string>(), py::arg("filename"))
        .def("getFormat", &EOSio::getFormat, "Get the format of the EOS table.")
        .def("getTable",
             [](EOSio &self) -> HELMTable * {
                 // getTable() yields a variant. Only the HELMTable alternative is
                 // exposed; any other alternative is returned to Python as None.
                 auto &table_variant = self.getTable();
                 if (auto *owner = std::get_if<std::unique_ptr<HELMTable>>(&table_variant)) {
                     return owner->get();
                 }
                 return nullptr;
             },
             // The HELMTable is owned (via unique_ptr) by data reached through
             // `self`: hand out a non-owning reference and keep the parent EOSio
             // alive for as long as the returned table object exists.
             py::return_value_policy::reference_internal,
             "Get the EOS table data.")
        .def("__repr__", [](const EOSio &eos) {
            return "<EOSio(filename='" + eos.getFilename() + "', format='" + eos.getFormatName() + "')>";
        });

    // ---- EOSTable ----------------------------------------------------------
    py::class_<serif::eos::EOSTable>(eos_submodule, "EOSTable");

    // ---- HELMTable ---------------------------------------------------------
    py::class_<HELMTable>(eos_submodule, "HELMTable")
        .def_readonly("loaded", &HELMTable::loaded)
        .def_readonly("imax", &HELMTable::imax)
        .def_readonly("jmax", &HELMTable::jmax)
        .def_readonly("t", &HELMTable::t)
        .def_readonly("d", &HELMTable::d)
        .def("__repr__", [](const HELMTable &table) {
            return "<HELMTable(loaded=" + std::to_string(table.loaded) + ", imax=" + std::to_string(table.imax) +
                   ", jmax=" + std::to_string(table.jmax) + ")>";
        })
        // Zero-copy (imax, jmax) NumPy view of the `f` data.
        // The getter takes the Python `self` rather than a C++ reference so that
        // exactly that object can be installed as the array's base: NumPy then
        // holds a reference to it and the HELMTable outlives the view. (The
        // previous `py::cast(table)` resolves to return_value_policy::copy for an
        // lvalue and only avoided copying the table because a registered
        // instance happened to exist.)
        // NOTE(review): assumes the rows of `f` are one contiguous row-major
        // block starting at f[0] (i.e. f[i] == f[0] + i * jmax) - confirm
        // against the allocator in helm.
        // NOTE(review): pybind11 marks such views writable by default, so
        // Python-side writes modify the table in place - confirm intended.
        .def_property_readonly("f", [](py::object self) -> py::array_t<double> {
            HELMTable &table = self.cast<HELMTable &>();

            if (table.imax <= 0 || table.jmax <= 0) {
                throw std::runtime_error("HELMTable dimensions (imax, jmax) are non-positive.");
            }
            // f may be unset if the table has not been loaded/initialized.
            if (!table.f || !table.f[0]) {
                throw std::runtime_error("HELMTable data buffer 'f' is null or not initialized.");
            }

            const auto rows = static_cast<py::ssize_t>(table.imax);
            const auto cols = static_cast<py::ssize_t>(table.jmax);
            const std::vector<py::ssize_t> shape{rows, cols};
            const std::vector<py::ssize_t> strides{
                cols * static_cast<py::ssize_t>(sizeof(double)),  // bytes to the next row
                static_cast<py::ssize_t>(sizeof(double))          // bytes to the next column
            };
            return py::array_t<double>(shape, strides, table.f[0], self);
        });

    // ---- EOS (Helmholtz output) --------------------------------------------
    py::class_<HELMEOSOutput>(eos_submodule, "EOS")
        .def(py::init<>())
        .def_readonly("ye", &HELMEOSOutput::ye)
        .def_readonly("etaele", &HELMEOSOutput::etaele)
        .def_readonly("xnefer", &HELMEOSOutput::xnefer)
        .def_readonly("ptot", &HELMEOSOutput::ptot)
        .def_readonly("pgas", &HELMEOSOutput::pgas)
        .def_readonly("prad", &HELMEOSOutput::prad)
        .def_readonly("etot", &HELMEOSOutput::etot)
        .def_readonly("egas", &HELMEOSOutput::egas)
        .def_readonly("erad", &HELMEOSOutput::erad)
        .def_readonly("stot", &HELMEOSOutput::stot)
        .def_readonly("sgas", &HELMEOSOutput::sgas)
        .def_readonly("srad", &HELMEOSOutput::srad)
        .def_readonly("dpresdd", &HELMEOSOutput::dpresdd)
        .def_readonly("dpresdt", &HELMEOSOutput::dpresdt)
        .def_readonly("dpresda", &HELMEOSOutput::dpresda)
        .def_readonly("dpresdz", &HELMEOSOutput::dpresdz)
        .def_readonly("dentrdd", &HELMEOSOutput::dentrdd)
        .def_readonly("dentrdt", &HELMEOSOutput::dentrdt)
        .def_readonly("dentrda", &HELMEOSOutput::dentrda)
        .def_readonly("dentrdz", &HELMEOSOutput::dentrdz)
        .def_readonly("denerdd", &HELMEOSOutput::denerdd)
        .def_readonly("denerdt", &HELMEOSOutput::denerdt)
        .def_readonly("denerda", &HELMEOSOutput::denerda)
        .def_readonly("denerdz", &HELMEOSOutput::denerdz)
        .def_readonly("chiT", &HELMEOSOutput::chiT)
        .def_readonly("chiRho", &HELMEOSOutput::chiRho)
        .def_readonly("csound", &HELMEOSOutput::csound)
        .def_readonly("grad_ad", &HELMEOSOutput::grad_ad)
        .def_readonly("gamma1", &HELMEOSOutput::gamma1)
        .def_readonly("gamma2", &HELMEOSOutput::gamma2)
        .def_readonly("gamma3", &HELMEOSOutput::gamma3)
        .def_readonly("cV", &HELMEOSOutput::cV)
        .def_readonly("cP", &HELMEOSOutput::cP)
        .def_readonly("dse", &HELMEOSOutput::dse)
        .def_readonly("dpe", &HELMEOSOutput::dpe)
        .def_readonly("dsp", &HELMEOSOutput::dsp)
        .def("__repr__", [](const HELMEOSOutput &) {
            return "<EOS (output from helmholtz eos)>";
        });

    // ---- HELMEOSInput ------------------------------------------------------
    py::class_<HELMEOSInput>(eos_submodule, "HELMEOSInput")
        .def(py::init<>())
        .def_readwrite("T", &HELMEOSInput::T)
        .def_readwrite("rho", &HELMEOSInput::rho)
        .def_readwrite("abar", &HELMEOSInput::abar)
        .def_readwrite("zbar", &HELMEOSInput::zbar)
        .def("__repr__", [](const HELMEOSInput &input) {
            return "<HELMEOSInput(T=" + std::to_string(input.T) +
                   ", rho=" + std::to_string(input.rho) +
                   ", abar=" + std::to_string(input.abar) +
                   ", zbar=" + std::to_string(input.zbar) + ")>";
        });

    // ---- Free functions ----------------------------------------------------
    eos_submodule.def("get_helm_eos",
                      &serif::eos::helmholtz::get_helm_EOS,
                      py::arg("q"), py::arg("table"),
                      "Calculate the Helmholtz EOS components based on input parameters and table data.");
}