LCOV - code coverage report
Current view: top level - boost/capy - when_all.hpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 98.0 % 100 98
Test Date: 2026-01-15 18:27:21 Functions: 89.0 % 382 340

            Line data    Source code
       1              : //
       2              : // Copyright (c) 2026 Steve Gerbino
       3              : //
       4              : // Distributed under the Boost Software License, Version 1.0. (See accompanying
       5              : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
       6              : //
       7              : // Official repository: https://github.com/cppalliance/capy
       8              : //
       9              : 
      10              : #ifndef BOOST_CAPY_WHEN_ALL_HPP
      11              : #define BOOST_CAPY_WHEN_ALL_HPP
      12              : 
      13              : #include <boost/capy/detail/config.hpp>
      14              : #include <boost/capy/concept/affine_awaitable.hpp>
      15              : #include <boost/capy/ex/any_coro.hpp>
      16              : #include <boost/capy/ex/frame_allocator.hpp>
      17              : #include <boost/capy/task.hpp>
      18              : 
      19              : #include <array>
      20              : #include <atomic>
      21              : #include <exception>
      22              : #include <optional>
      23              : #if BOOST_CAPY_HAS_STOP_TOKEN
      24              : #include <stop_token>
      25              : #endif
      26              : #include <tuple>
      27              : #include <type_traits>
      28              : #include <utility>
      29              : 
      30              : namespace boost {
      31              : namespace capy {
      32              : 
      33              : namespace detail {
      34              : 
      35              : /** Type trait to filter void types from a tuple.
      36              : 
      37              :     Void-returning tasks do not contribute a value to the result tuple.
      38              :     This trait computes the filtered result type.
      39              : 
      40              :     Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
      41              : */
      42              : template<typename T>
      43              : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
      44              : 
      45              : template<typename... Ts>
      46              : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
      47              : 
      48              : /** Holds the result of a single task within when_all.
      49              : */
      50              : template<typename T>
      51              : struct result_holder
      52              : {
      53              :     std::optional<T> value_;
      54              : 
      55           45 :     void set(T v)
      56              :     {
      57           45 :         value_ = std::move(v);
      58           45 :     }
      59              : 
      60           38 :     T get() &&
      61              :     {
      62           38 :         return std::move(*value_);
      63              :     }
      64              : };
      65              : 
      66              : /** Specialization for void tasks - no value storage needed.
      67              : */
      68              : template<>
      69              : struct result_holder<void>
      70              : {
      71              : };
      72              : 
      73              : /** Shared state for when_all operation.
      74              : 
      75              :     @tparam Ts The result types of the tasks.
      76              : */
      77              : template<typename... Ts>
      78              : struct when_all_state
      79              : {
      80              :     static constexpr std::size_t task_count = sizeof...(Ts);
      81              : 
      82              :     // Completion tracking - when_all waits for all children
      83              :     std::atomic<std::size_t> remaining_count_;
      84              : 
      85              :     // Result storage in input order
      86              :     std::tuple<result_holder<Ts>...> results_;
      87              : 
      88              :     // Runner handles - destroyed in await_resume while allocator is valid
      89              :     std::array<any_coro, task_count> runner_handles_{};
      90              : 
      91              :     // Exception storage - first error wins, others discarded
      92              :     std::atomic<bool> has_exception_{false};
      93              :     std::exception_ptr first_exception_;
      94              : 
      95              : #if BOOST_CAPY_HAS_STOP_TOKEN
      96              :     // Stop propagation - on error, request stop for siblings
      97              :     std::stop_source stop_source_;
      98              : 
      99              :     // Connects parent's stop_token to our stop_source
     100              :     struct stop_callback_fn
     101              :     {
     102              :         std::stop_source* source_;
     103            1 :         void operator()() const { source_->request_stop(); }
     104              :     };
     105              :     using stop_callback_t = std::stop_callback<stop_callback_fn>;
     106              :     std::optional<stop_callback_t> parent_stop_callback_;
     107              : #endif
     108              : 
     109              :     // Parent resumption
     110              :     any_coro continuation_;
     111              :     any_dispatcher caller_dispatcher_;
     112              : 
     113           24 :     when_all_state()
     114           24 :         : remaining_count_(task_count)
     115              :     {
     116           24 :     }
     117              : 
     118           24 :     ~when_all_state()
     119              :     {
     120           85 :         for(auto h : runner_handles_)
     121           61 :             if(h)
     122           61 :                 h.destroy();
     123           24 :     }
     124              : 
     125              :     /** Capture an exception (first one wins).
     126              :     */
     127           11 :     void capture_exception(std::exception_ptr ep)
     128              :     {
     129           11 :         bool expected = false;
     130           11 :         if(has_exception_.compare_exchange_strong(
     131              :             expected, true, std::memory_order_relaxed))
     132            8 :             first_exception_ = ep;
     133           11 :     }
     134              : 
     135              :     /** Signal that a task has completed.
     136              : 
     137              :         The last child to complete triggers resumption of the parent.
     138              :     */
     139           61 :     any_coro signal_completion()
     140              :     {
     141           61 :         auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
     142           61 :         if(remaining == 1)
     143           24 :             return caller_dispatcher_(continuation_);
     144           37 :         return std::noop_coroutine();
     145              :     }
     146              : 
     147              : };
     148              : 
/** Wrapper coroutine that intercepts task completion.

    This runner awaits its assigned task and stores the result in
    the shared state, or captures the exception and requests stop.

    Lifecycle: the runner starts suspended (initial_suspend) and is started
    by when_all_launcher::launch_one() after its promise has been wired to
    the shared state, the dispatcher and (when enabled) the stop token. It
    always suspends at final_suspend; its frame is destroyed later by
    ~when_all_state() through the handle stored in runner_handles_.

    @tparam T  Result type of the wrapped task.
    @tparam Ts Result types of all tasks in the when_all call.
*/
template<typename T, typename... Ts>
struct when_all_runner
{
    // NOTE(review): frame_allocating_base is declared elsewhere; presumably
    // it supplies the frame allocation hooks for this coroutine.
    struct promise_type : frame_allocating_base
    {
        // Assigned by launch_one() before the runner is first resumed.
        when_all_state<Ts...>* state_ = nullptr;
        // Dispatcher forwarded to every affine awaitable this runner awaits.
        any_dispatcher ex_;
#if BOOST_CAPY_HAS_STOP_TOKEN
        // Token of the shared stop_source, forwarded to stoppable awaitables.
        std::stop_token stop_token_;
#endif

        when_all_runner get_return_object()
        {
            return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
        }

        // Start suspended so launch_one() can wire up the promise first.
        std::suspend_always initial_suspend() noexcept
        {
            return {};
        }

        auto final_suspend() noexcept
        {
            struct awaiter
            {
                promise_type* p_;

                // Always suspend: the frame is destroyed by ~when_all_state().
                bool await_ready() const noexcept
                {
                    return false;
                }

                any_coro await_suspend(any_coro) noexcept
                {
                    // Signal completion; last task resumes parent
                    // (symmetric transfer to the returned handle).
                    return p_->state_->signal_completion();
                }

                // Never reached: a coroutine suspended at its final
                // suspend point cannot be resumed.
                void await_resume() const noexcept
                {
                }
            };
            return awaiter{this};
        }

        void return_void()
        {
        }

        // The exception is recorded in the shared state rather than
        // propagated; when_all rethrows it after all children complete.
        void unhandled_exception()
        {
            state_->capture_exception(std::current_exception());
#if BOOST_CAPY_HAS_STOP_TOKEN
            // Request stop for sibling tasks
            state_->stop_source_.request_stop();
#endif
        }

        /** Adapter for affine awaitables.

            Forwards await_suspend with this runner's dispatcher, plus its
            stop token when stop tokens are enabled and the awaitable is
            stoppable.
        */
        template<class Awaitable>
        struct transform_awaiter
        {
            std::decay_t<Awaitable> a_;
            promise_type* p_;

            bool await_ready()
            {
                return a_.await_ready();
            }

            auto await_resume()
            {
                return a_.await_resume();
            }

            template<class Promise>
            auto await_suspend(std::coroutine_handle<Promise> h)
            {
#if BOOST_CAPY_HAS_STOP_TOKEN
                using A = std::decay_t<Awaitable>;
                // Propagate stop_token to nested awaitables
                if constexpr (stoppable_awaitable<A, any_dispatcher>)
                    return a_.await_suspend(h, p_->ex_, p_->stop_token_);
                else
#endif
                    return a_.await_suspend(h, p_->ex_);
            }
        };

        // Affine awaitables are wrapped in transform_awaiter; anything else
        // goes through make_affine() (declared elsewhere) with ex_.
        template<class Awaitable>
        auto await_transform(Awaitable&& a)
        {
            using A = std::decay_t<Awaitable>;
            if constexpr (affine_awaitable<A, any_dispatcher>)
            {
                return transform_awaiter<Awaitable>{
                    std::forward<Awaitable>(a), this};
            }
            else
            {
                return make_affine(std::forward<Awaitable>(a), ex_);
            }
        }
    };

    // Coroutine handle; ownership is handed off via release().
    std::coroutine_handle<promise_type> h_;

    explicit when_all_runner(std::coroutine_handle<promise_type> h)
        : h_(h)
    {
    }

#if defined(__clang__) && __clang_major__ == 14 && !defined(__apple_build_version__)
    // Clang 14 has a bug where it calls the move constructor for coroutine
    // return objects even though they should be constructed in-place via RVO.
    // This happens when returning a non-movable type from a coroutine.
    when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
#endif

    // Non-copyable, non-movable - release() is always called immediately
    when_all_runner(when_all_runner const&) = delete;
    when_all_runner& operator=(when_all_runner const&) = delete;

#if !defined(__clang__) || __clang_major__ != 14 || defined(__apple_build_version__)
    when_all_runner(when_all_runner&&) = delete;
#endif

    when_all_runner& operator=(when_all_runner&&) = delete;

    // Give up the handle; this object holds nullptr afterwards.
    auto release() noexcept
    {
        return std::exchange(h_, nullptr);
    }
};
     287              : 
     288              : /** Create a runner coroutine for a single task.
     289              : 
     290              :     Task is passed directly to ensure proper coroutine frame storage.
     291              : */
     292              : template<std::size_t Index, typename T, typename... Ts>
     293              : when_all_runner<T, Ts...>
     294           61 : make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
     295              : {
     296              :     if constexpr (std::is_void_v<T>)
     297              :         co_await std::move(inner);
     298              :     else
     299              :         std::get<Index>(state->results_).set(co_await std::move(inner));
     300          122 : }
     301              : 
/** Internal awaitable that launches all runner coroutines and waits.

    This awaitable is used inside the when_all coroutine to handle
    the concurrent execution of child tasks.

    Lifetime note: launch_one() resumes each runner inline, so runners may
    run to completion before await_suspend() returns. Once the last runner
    has signalled completion, the parent (which owns both this object and
    the state) may already have been resumed, possibly on another thread.
    Nothing after the launch loop may therefore touch `this` or `state_`.
*/
template<typename... Ts>
class when_all_launcher
{
    // Non-owning; both point into the when_all coroutine frame.
    std::tuple<task<Ts>...>* tasks_;
    when_all_state<Ts...>* state_;

public:
    when_all_launcher(
        std::tuple<task<Ts>...>* tasks,
        when_all_state<Ts...>* state)
        : tasks_(tasks)
        , state_(state)
    {
    }

    // With no tasks there is nothing to launch or wait for, so await_suspend
    // is skipped entirely.
    bool await_ready() const noexcept
    {
        return sizeof...(Ts) == 0;
    }

#if BOOST_CAPY_HAS_STOP_TOKEN
    template<dispatcher D>
    any_coro await_suspend(any_coro continuation, D const& caller_ex, std::stop_token parent_token = {})
    {
        // Must be stored before any runner starts: the last runner to finish
        // reads these in signal_completion().
        state_->continuation_ = continuation;
        state_->caller_dispatcher_ = caller_ex;

        // Forward parent's stop requests to children
        if(parent_token.stop_possible())
        {
            state_->parent_stop_callback_.emplace(
                parent_token,
                typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});

            // NOTE(review): std::stop_callback already invokes its callback
            // during construction when stop was requested, so this check is
            // redundant but harmless (request_stop() is idempotent).
            if(parent_token.stop_requested())
                state_->stop_source_.request_stop();
        }

        // Launch all tasks concurrently (in index order; each runner is
        // resumed through the caller's dispatcher).
        auto token = state_->stop_source_.get_token();
        [&]<std::size_t... Is>(std::index_sequence<Is...>) {
            (..., launch_one<Is>(caller_ex, token));
        }(std::index_sequence_for<Ts...>{});

        // Let signal_completion() handle resumption
        return std::noop_coroutine();
    }
#else
    template<dispatcher D>
    any_coro await_suspend(any_coro continuation, D const& caller_ex)
    {
        // Must be stored before any runner starts: the last runner to finish
        // reads these in signal_completion().
        state_->continuation_ = continuation;
        state_->caller_dispatcher_ = caller_ex;

        // Launch all tasks concurrently (in index order; each runner is
        // resumed through the caller's dispatcher).
        [&]<std::size_t... Is>(std::index_sequence<Is...>) {
            (..., launch_one<Is>(caller_ex));
        }(std::index_sequence_for<Ts...>{});

        // Let signal_completion() handle resumption
        return std::noop_coroutine();
    }
#endif

    void await_resume() const noexcept
    {
        // Results are extracted by the when_all coroutine from state
    }

private:
#if BOOST_CAPY_HAS_STOP_TOKEN
    /** Create, wire up and start the runner for task I.

        The runner's handle is recorded in state_->runner_handles_ so that
        ~when_all_state() can destroy its frame.
    */
    template<std::size_t I, dispatcher D>
    void launch_one(D const& caller_ex, std::stop_token token)
    {
        auto runner = make_when_all_runner<I>(
            std::move(std::get<I>(*tasks_)), state_);

        // The runner is suspended at initial_suspend, so its promise can be
        // configured before it first runs.
        auto h = runner.release();
        h.promise().state_ = state_;
        h.promise().ex_ = caller_ex;
        h.promise().stop_token_ = token;

        any_coro ch{h};
        state_->runner_handles_[I] = ch;
        caller_ex(ch).resume();
    }
#else
    /** Create, wire up and start the runner for task I.

        The runner's handle is recorded in state_->runner_handles_ so that
        ~when_all_state() can destroy its frame.
    */
    template<std::size_t I, dispatcher D>
    void launch_one(D const& caller_ex)
    {
        auto runner = make_when_all_runner<I>(
            std::move(std::get<I>(*tasks_)), state_);

        // The runner is suspended at initial_suspend, so its promise can be
        // configured before it first runs.
        auto h = runner.release();
        h.promise().state_ = state_;
        h.promise().ex_ = caller_ex;

        any_coro ch{h};
        state_->runner_handles_[I] = ch;
        caller_ex(ch).resume();
    }
#endif
};
     410              : 
     411              : /** Compute the result type for when_all.
     412              : 
     413              :     Returns void when all tasks are void (P2300 aligned),
     414              :     otherwise returns a tuple with void types filtered out.
     415              : */
     416              : template<typename... Ts>
     417              : using when_all_result_t = std::conditional_t<
     418              :     std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
     419              :     void,
     420              :     filter_void_tuple_t<Ts...>>;
     421              : 
     422              : /** Helper to extract a single result, returning empty tuple for void.
     423              :     This is a separate function to work around a GCC-11 ICE that occurs
     424              :     when using nested immediately-invoked lambdas with pack expansion.
     425              : */
     426              : template<std::size_t I, typename... Ts>
     427           40 : auto extract_single_result(when_all_state<Ts...>& state)
     428              : {
     429              :     using T = std::tuple_element_t<I, std::tuple<Ts...>>;
     430              :     if constexpr (std::is_void_v<T>)
     431            2 :         return std::tuple<>();
     432              :     else
     433           38 :         return std::make_tuple(std::move(std::get<I>(state.results_)).get());
     434              : }
     435              : 
     436              : /** Extract results from state, filtering void types.
     437              : */
     438              : template<typename... Ts>
     439           15 : auto extract_results(when_all_state<Ts...>& state)
     440              : {
     441           30 :     return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     442           15 :         return std::tuple_cat(extract_single_result<Is>(state)...);
     443           30 :     }(std::index_sequence_for<Ts...>{});
     444              : }
     445              : 
     446              : } // namespace detail
     447              : 
/** Wait for all tasks to complete concurrently.

    @par Example
    @code
    task<void> example() {
        auto [a, b] = co_await when_all(
            fetch_int(),     // task<int>
            fetch_string()   // task<std::string>
        );
    }
    @endcode

    @param tasks The tasks to execute concurrently. Each is moved into this
    coroutine's frame and then into its own runner coroutine.
    @return A task yielding a tuple of the non-void results in input order,
    or task<void> when every task is void (see when_all_result_t).

    @throws Rethrows the first exception captured from any child, but only
    after every child has completed; all other results are then discarded.

    Key features:
    @li All child tasks are launched concurrently
    @li Results are collected in input order
    @li First error is captured; subsequent errors are discarded
    @li On error, stop is requested for all siblings, and the parent's stop
        token is forwarded to the children (only when
        BOOST_CAPY_HAS_STOP_TOKEN is enabled)
    @li Completes only after all children have completed
    @li Void tasks do not contribute to the result tuple
    @li Properly propagates frame allocators to all child coroutines
*/
template<typename... Ts>
[[nodiscard]] task<detail::when_all_result_t<Ts...>>
when_all(task<Ts>... tasks)
{
    using result_type = detail::when_all_result_t<Ts...>;

    // State is stored in the coroutine frame, using the frame allocator.
    // Declared before task_tuple, so it is destroyed after it; its destructor
    // destroys the (suspended) runner frames.
    detail::when_all_state<Ts...> state;

    // Store tasks in the frame
    std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);

    // Launch all tasks and wait for completion: this coroutine is resumed by
    // whichever runner completes last (skipped when there are no tasks).
    co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);

    // Propagate first exception if any.
    // Safe without explicit acquire: capture_exception() is sequenced-before
    // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
    // last task's decrement that resumes this coroutine.
    if(state.first_exception_)
        std::rethrow_exception(state.first_exception_);

    // Extract and return results
    if constexpr (std::is_void_v<result_type>)
        co_return;
    else
        co_return detail::extract_results(state);
}
     500              : 
     501              : // For backwards compatibility and type queries, expose result type computation
     502              : template<typename... Ts>
     503              : using when_all_result_type = detail::when_all_result_t<Ts...>;
     504              : 
     505              : } // namespace capy
     506              : } // namespace boost
     507              : 
     508              : #endif
        

Generated by: LCOV version 2.3