22#ifndef COMMUNICATION_MPI_CALLBACKS
23#define COMMUNICATION_MPI_CALLBACKS
42#include <boost/mpi/collectives/broadcast.hpp>
43#include <boost/mpi/communicator.hpp>
44#include <boost/mpi/environment.hpp>
45#include <boost/mpi/packed_iarchive.hpp>
46#include <boost/range/algorithm/remove_if.hpp>
68using is_allowed_argument =
69 std::integral_constant<bool,
70 not(std::is_pointer_v<T> ||
71 (!std::is_const_v<std::remove_reference_t<T>> &&
72 std::is_lvalue_reference_v<T>))>;
74template <
class... Args>
75using are_allowed_arguments =
89template <
class F,
class... Args>
90auto invoke(F
f, boost::mpi::packed_iarchive &ia) {
91 static_assert(are_allowed_arguments<Args...>::value,
92 "Pointers and non-const references are not allowed as "
93 "arguments for callbacks.");
97 std::tuple<std::remove_const_t<std::remove_reference_t<Args>>...>
params;
104 return std::apply(
f, std::as_const(
params));
114struct callback_concept_t {
120 virtual void operator()(boost::mpi::communicator
const &,
121 boost::mpi::packed_iarchive &)
const = 0;
122 virtual ~callback_concept_t() =
default;
131template <
class F,
class... Args>
132struct callback_void_t final :
public callback_concept_t {
135 callback_void_t(callback_void_t
const &) =
delete;
136 callback_void_t(callback_void_t &&) =
delete;
138 template <
class FRef>
139 explicit callback_void_t(FRef &&
f) : m_f(std::forward<FRef>(
f)) {}
140 void operator()(boost::mpi::communicator
const &,
141 boost::mpi::packed_iarchive &ia)
const override {
142 detail::invoke<F, Args...>(m_f, ia);
146template <
class F,
class R,
class... Args>
struct FunctorTypes {
147 using functor_type = F;
148 using return_type = R;
149 using argument_types = std::tuple<Args...>;
152template <
class C,
class R,
class... Args>
153auto functor_types_impl(R (C::*)(Args...) const) {
154 return FunctorTypes<C, R, Args...>{};
159 decltype(functor_types_impl(&std::remove_reference_t<F>::operator()));
161template <
class CRef,
class C,
class R,
class... Args>
162auto make_model_impl(CRef &&c, FunctorTypes<C, R, Args...>) {
163 return std::make_unique<callback_void_t<C, Args...>>(std::forward<CRef>(c));
172template <
typename F>
auto make_model(F &&
f) {
173 return make_model_impl(std::forward<F>(
f), functor_types<F>{});
179template <
class... Args>
auto make_model(
void (*f_ptr)(Args...)) {
180 return std::make_unique<callback_void_t<void (*)(Args...), Args...>>(f_ptr);
202 template <
typename F,
class = std::enable_if_t<std::is_same_v<
203 typename detail::functor_types<F>::argument_types,
204 std::tuple<Args...>>>>
206 : m_id(cb->add(std::forward<F>(
f))), m_cb(std::move(cb)) {}
215 std::shared_ptr<MpiCallbacks> m_cb;
225 template <
class... ArgRef>
230 std::is_void_v<
decltype(std::declval<void (*)(Args...)>()(
231 std::forward<ArgRef>(args)...))>> {
233 m_cb->call(m_id, std::forward<ArgRef>(args)...);
241 int id()
const {
return m_id; }
249 static auto &static_callbacks() {
251 std::pair<void (*)(), std::unique_ptr<detail::callback_concept_t>>>
259 std::shared_ptr<boost::mpi::environment> mpi_env)
260 : m_comm(std::move(
comm)), m_mpi_env(std::move(mpi_env)) {
262 m_callback_map.add(
nullptr);
264 for (
auto &kv : static_callbacks()) {
265 m_func_ptr_to_id[kv.first] = m_callback_map.add(kv.second.get());
271 if (m_comm.rank() == 0) {
291 template <
typename F>
auto add(F &&
f) {
292 m_callbacks.emplace_back(detail::make_model(std::forward<F>(
f)));
293 return m_callback_map.add(m_callbacks.back().get());
305 template <
class... Args>
void add(
void (*fp)(Args...)) {
306 m_callbacks.emplace_back(detail::make_model(fp));
307 const int id = m_callback_map.add(m_callbacks.back().get());
308 m_func_ptr_to_id[
reinterpret_cast<void (*)()
>(fp)] =
id;
319 template <
class... Args>
static void add_static(
void (*fp)(Args...)) {
320 static_callbacks().emplace_back(
reinterpret_cast<void (*)()
>(fp),
321 detail::make_model(fp));
333 void remove(
int id) {
335 boost::remove_if(m_callbacks,
336 [ptr = m_callback_map[
id]](
auto const &e) {
337 return e.get() == ptr;
340 m_callback_map.remove(
id);
355 template <
class... Args>
void call(
int id, Args &&...args)
const {
356 if (m_comm.rank() != 0) {
357 throw std::logic_error(
"Callbacks can only be invoked on rank 0.");
360 assert(m_callback_map.find(
id) != m_callback_map.end() &&
361 "m_callback_map and m_func_ptr_to_id disagree");
364 boost::mpi::packed_oarchive oa(m_comm);
369 std::forward_as_tuple(std::forward<Args>(args)...));
371 boost::mpi::broadcast(m_comm, oa, 0);
385 template <
class... Args,
class... ArgRef>
386 auto call(
void (*fp)(Args...), ArgRef &&...args) const ->
388 std::enable_if_t<std::is_void_v<decltype(fp(args...))>> {
389 const int id = m_func_ptr_to_id.at(
reinterpret_cast<void (*)()
>(fp));
391 call(
id, std::forward<ArgRef>(args)...);
404 template <
class... Args,
class... ArgRef>
405 auto call_all(
void (*fp)(Args...), ArgRef &&...args) const ->
407 std::enable_if_t<std::is_void_v<decltype(fp(args...))>> {
424 boost::mpi::packed_iarchive ia(m_comm);
425 boost::mpi::broadcast(m_comm, ia, 0);
430 if (request == LOOP_ABORT) {
434 m_callback_map[request]->operator()(m_comm, ia);
448 boost::mpi::communicator
const &
comm()
const {
return m_comm; }
458 static constexpr int LOOP_ABORT = 0;
463 boost::mpi::communicator m_comm;
468 std::shared_ptr<boost::mpi::environment> m_mpi_env;
473 std::vector<std::unique_ptr<detail::callback_concept_t>> m_callbacks;
484 std::unordered_map<void (*)(),
int> m_func_ptr_to_id;
487template <
class... Args>
514#define REGISTER_CALLBACK(cb) \
515 namespace Communication { \
516 static ::Communication::RegisterCallback register_##cb(&(cb)); \
Keep an enumerated list of T objects, managed by the class.
RAII handle for a callback.
auto operator()(ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(std::declval< void(*)(Args...)>()(std::forward< ArgRef >(args)...))> >
Call the callback managed by this handle.
CallbackHandle(CallbackHandle &&rhs) noexcept=default
CallbackHandle(CallbackHandle const &)=delete
CallbackHandle & operator=(CallbackHandle &&rhs) noexcept=default
CallbackHandle(std::shared_ptr< MpiCallbacks > cb, F &&f)
CallbackHandle & operator=(CallbackHandle const &)=delete
The interface of the MPI callback mechanism.
auto call_all(void(*fp)(Args...), ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(fp(args...))> >
Call a callback on all nodes.
auto call(void(*fp)(Args...), ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(fp(args...))> >
Call a callback on worker nodes.
MpiCallbacks(boost::mpi::communicator comm, std::shared_ptr< boost::mpi::environment > mpi_env)
void add(void(*fp)(Args...))
Add a new callback.
std::shared_ptr< boost::mpi::environment > share_mpi_env() const
boost::mpi::communicator const & comm() const
The boost mpi communicator used by this instance.
void abort_loop()
Abort the MPI loop.
static void add_static(void(*fp)(Args...))
Add a new callback.
MpiCallbacks & operator=(MpiCallbacks const &)=delete
MpiCallbacks(MpiCallbacks const &)=delete
void loop() const
Start the MPI loop.
Helper class to add callbacks before main.
RegisterCallback()=delete
RegisterCallback(void(*cb)(Args...))
Container for objects that are identified by a numeric id.
void for_each(F &&f, Tuple &t)
static SteepestDescentParameters params
Currently active steepest descent instance.
Algorithms for tuple-like inhomogeneous containers.