Line data Source code
1 : //
2 : // Copyright (c) 2026 Steve Gerbino
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 : #define BOOST_CAPY_WHEN_ALL_HPP
12 :
13 : #include <boost/capy/detail/config.hpp>
14 : #include <boost/capy/concept/affine_awaitable.hpp>
15 : #include <boost/capy/ex/any_coro.hpp>
16 : #include <boost/capy/ex/frame_allocator.hpp>
17 : #include <boost/capy/task.hpp>
18 :
19 : #include <array>
20 : #include <atomic>
21 : #include <exception>
22 : #include <optional>
23 : #if BOOST_CAPY_HAS_STOP_TOKEN
24 : #include <stop_token>
25 : #endif
26 : #include <tuple>
27 : #include <type_traits>
28 : #include <utility>
29 :
30 : namespace boost {
31 : namespace capy {
32 :
33 : namespace detail {
34 :
35 : /** Type trait to filter void types from a tuple.
36 :
37 : Void-returning tasks do not contribute a value to the result tuple.
38 : This trait computes the filtered result type.
39 :
40 : Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
41 : */
42 : template<typename T>
43 : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
44 :
45 : template<typename... Ts>
46 : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47 :
48 : /** Holds the result of a single task within when_all.
49 : */
50 : template<typename T>
51 : struct result_holder
52 : {
53 : std::optional<T> value_;
54 :
55 45 : void set(T v)
56 : {
57 45 : value_ = std::move(v);
58 45 : }
59 :
60 38 : T get() &&
61 : {
62 38 : return std::move(*value_);
63 : }
64 : };
65 :
66 : /** Specialization for void tasks - no value storage needed.
67 : */
68 : template<>
69 : struct result_holder<void>
70 : {
71 : };
72 :
73 : /** Shared state for when_all operation.
74 :
75 : @tparam Ts The result types of the tasks.
76 : */
77 : template<typename... Ts>
78 : struct when_all_state
79 : {
80 : static constexpr std::size_t task_count = sizeof...(Ts);
81 :
82 : // Completion tracking - when_all waits for all children
83 : std::atomic<std::size_t> remaining_count_;
84 :
85 : // Result storage in input order
86 : std::tuple<result_holder<Ts>...> results_;
87 :
88 : // Runner handles - destroyed in await_resume while allocator is valid
89 : std::array<any_coro, task_count> runner_handles_{};
90 :
91 : // Exception storage - first error wins, others discarded
92 : std::atomic<bool> has_exception_{false};
93 : std::exception_ptr first_exception_;
94 :
95 : #if BOOST_CAPY_HAS_STOP_TOKEN
96 : // Stop propagation - on error, request stop for siblings
97 : std::stop_source stop_source_;
98 :
99 : // Connects parent's stop_token to our stop_source
100 : struct stop_callback_fn
101 : {
102 : std::stop_source* source_;
103 1 : void operator()() const { source_->request_stop(); }
104 : };
105 : using stop_callback_t = std::stop_callback<stop_callback_fn>;
106 : std::optional<stop_callback_t> parent_stop_callback_;
107 : #endif
108 :
109 : // Parent resumption
110 : any_coro continuation_;
111 : any_dispatcher caller_dispatcher_;
112 :
113 24 : when_all_state()
114 24 : : remaining_count_(task_count)
115 : {
116 24 : }
117 :
118 24 : ~when_all_state()
119 : {
120 85 : for(auto h : runner_handles_)
121 61 : if(h)
122 61 : h.destroy();
123 24 : }
124 :
125 : /** Capture an exception (first one wins).
126 : */
127 11 : void capture_exception(std::exception_ptr ep)
128 : {
129 11 : bool expected = false;
130 11 : if(has_exception_.compare_exchange_strong(
131 : expected, true, std::memory_order_relaxed))
132 8 : first_exception_ = ep;
133 11 : }
134 :
135 : /** Signal that a task has completed.
136 :
137 : The last child to complete triggers resumption of the parent.
138 : */
139 61 : any_coro signal_completion()
140 : {
141 61 : auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
142 61 : if(remaining == 1)
143 24 : return caller_dispatcher_(continuation_);
144 37 : return std::noop_coroutine();
145 : }
146 :
147 : };
148 :
149 : /** Wrapper coroutine that intercepts task completion.
150 :
151 : This runner awaits its assigned task and stores the result in
152 : the shared state, or captures the exception and requests stop.
153 : */
154 : template<typename T, typename... Ts>
155 : struct when_all_runner
156 : {
157 : struct promise_type : frame_allocating_base
158 : {
159 : when_all_state<Ts...>* state_ = nullptr;
160 : any_dispatcher ex_;
161 : #if BOOST_CAPY_HAS_STOP_TOKEN
162 : std::stop_token stop_token_;
163 : #endif
164 :
165 61 : when_all_runner get_return_object()
166 : {
167 61 : return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
168 : }
169 :
170 61 : std::suspend_always initial_suspend() noexcept
171 : {
172 61 : return {};
173 : }
174 :
175 61 : auto final_suspend() noexcept
176 : {
177 : struct awaiter
178 : {
179 : promise_type* p_;
180 :
181 61 : bool await_ready() const noexcept
182 : {
183 61 : return false;
184 : }
185 :
186 61 : any_coro await_suspend(any_coro) noexcept
187 : {
188 : // Signal completion; last task resumes parent
189 61 : return p_->state_->signal_completion();
190 : }
191 :
192 0 : void await_resume() const noexcept
193 : {
194 0 : }
195 : };
196 61 : return awaiter{this};
197 : }
198 :
199 50 : void return_void()
200 : {
201 50 : }
202 :
203 11 : void unhandled_exception()
204 : {
205 11 : state_->capture_exception(std::current_exception());
206 : #if BOOST_CAPY_HAS_STOP_TOKEN
207 : // Request stop for sibling tasks
208 11 : state_->stop_source_.request_stop();
209 : #endif
210 11 : }
211 :
212 : template<class Awaitable>
213 : struct transform_awaiter
214 : {
215 : std::decay_t<Awaitable> a_;
216 : promise_type* p_;
217 :
218 61 : bool await_ready()
219 : {
220 61 : return a_.await_ready();
221 : }
222 :
223 61 : auto await_resume()
224 : {
225 61 : return a_.await_resume();
226 : }
227 :
228 : template<class Promise>
229 61 : auto await_suspend(std::coroutine_handle<Promise> h)
230 : {
231 : #if BOOST_CAPY_HAS_STOP_TOKEN
232 : using A = std::decay_t<Awaitable>;
233 : // Propagate stop_token to nested awaitables
234 : if constexpr (stoppable_awaitable<A, any_dispatcher>)
235 61 : return a_.await_suspend(h, p_->ex_, p_->stop_token_);
236 : else
237 : #endif
238 : return a_.await_suspend(h, p_->ex_);
239 : }
240 : };
241 :
242 : template<class Awaitable>
243 61 : auto await_transform(Awaitable&& a)
244 : {
245 : using A = std::decay_t<Awaitable>;
246 : if constexpr (affine_awaitable<A, any_dispatcher>)
247 : {
248 : return transform_awaiter<Awaitable>{
249 122 : std::forward<Awaitable>(a), this};
250 : }
251 : else
252 : {
253 : return make_affine(std::forward<Awaitable>(a), ex_);
254 : }
255 61 : }
256 : };
257 :
258 : std::coroutine_handle<promise_type> h_;
259 :
260 61 : explicit when_all_runner(std::coroutine_handle<promise_type> h)
261 61 : : h_(h)
262 : {
263 61 : }
264 :
265 : #if defined(__clang__) && __clang_major__ == 14 && !defined(__apple_build_version__)
266 : // Clang 14 has a bug where it calls the move constructor for coroutine
267 : // return objects even though they should be constructed in-place via RVO.
268 : // This happens when returning a non-movable type from a coroutine.
269 : when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
270 : #endif
271 :
272 : // Non-copyable, non-movable - release() is always called immediately
273 : when_all_runner(when_all_runner const&) = delete;
274 : when_all_runner& operator=(when_all_runner const&) = delete;
275 :
276 : #if !defined(__clang__) || __clang_major__ != 14 || defined(__apple_build_version__)
277 : when_all_runner(when_all_runner&&) = delete;
278 : #endif
279 :
280 : when_all_runner& operator=(when_all_runner&&) = delete;
281 :
282 61 : auto release() noexcept
283 : {
284 61 : return std::exchange(h_, nullptr);
285 : }
286 : };
287 :
288 : /** Create a runner coroutine for a single task.
289 :
290 : Task is passed directly to ensure proper coroutine frame storage.
291 : */
292 : template<std::size_t Index, typename T, typename... Ts>
293 : when_all_runner<T, Ts...>
294 61 : make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
295 : {
296 : if constexpr (std::is_void_v<T>)
297 : co_await std::move(inner);
298 : else
299 : std::get<Index>(state->results_).set(co_await std::move(inner));
300 122 : }
301 :
302 : /** Internal awaitable that launches all runner coroutines and waits.
303 :
304 : This awaitable is used inside the when_all coroutine to handle
305 : the concurrent execution of child tasks.
306 : */
307 : template<typename... Ts>
308 : class when_all_launcher
309 : {
310 : std::tuple<task<Ts>...>* tasks_;
311 : when_all_state<Ts...>* state_;
312 :
313 : public:
314 24 : when_all_launcher(
315 : std::tuple<task<Ts>...>* tasks,
316 : when_all_state<Ts...>* state)
317 24 : : tasks_(tasks)
318 24 : , state_(state)
319 : {
320 24 : }
321 :
322 24 : bool await_ready() const noexcept
323 : {
324 24 : return sizeof...(Ts) == 0;
325 : }
326 :
327 : #if BOOST_CAPY_HAS_STOP_TOKEN
328 : template<dispatcher D>
329 24 : any_coro await_suspend(any_coro continuation, D const& caller_ex, std::stop_token parent_token = {})
330 : {
331 24 : state_->continuation_ = continuation;
332 24 : state_->caller_dispatcher_ = caller_ex;
333 :
334 : // Forward parent's stop requests to children
335 24 : if(parent_token.stop_possible())
336 : {
337 8 : state_->parent_stop_callback_.emplace(
338 : parent_token,
339 4 : typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
340 :
341 4 : if(parent_token.stop_requested())
342 1 : state_->stop_source_.request_stop();
343 : }
344 :
345 : // Launch all tasks concurrently
346 24 : auto token = state_->stop_source_.get_token();
347 48 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
348 24 : (..., launch_one<Is>(caller_ex, token));
349 24 : }(std::index_sequence_for<Ts...>{});
350 :
351 : // Let signal_completion() handle resumption
352 48 : return std::noop_coroutine();
353 24 : }
354 : #else
355 : template<dispatcher D>
356 : any_coro await_suspend(any_coro continuation, D const& caller_ex)
357 : {
358 : state_->continuation_ = continuation;
359 : state_->caller_dispatcher_ = caller_ex;
360 :
361 : // Launch all tasks concurrently
362 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
363 : (..., launch_one<Is>(caller_ex));
364 : }(std::index_sequence_for<Ts...>{});
365 :
366 : // Let signal_completion() handle resumption
367 : return std::noop_coroutine();
368 : }
369 : #endif
370 :
371 24 : void await_resume() const noexcept
372 : {
373 : // Results are extracted by the when_all coroutine from state
374 24 : }
375 :
376 : private:
377 : #if BOOST_CAPY_HAS_STOP_TOKEN
378 : template<std::size_t I, dispatcher D>
379 61 : void launch_one(D const& caller_ex, std::stop_token token)
380 : {
381 61 : auto runner = make_when_all_runner<I>(
382 61 : std::move(std::get<I>(*tasks_)), state_);
383 :
384 61 : auto h = runner.release();
385 61 : h.promise().state_ = state_;
386 61 : h.promise().ex_ = caller_ex;
387 61 : h.promise().stop_token_ = token;
388 :
389 61 : any_coro ch{h};
390 61 : state_->runner_handles_[I] = ch;
391 61 : caller_ex(ch).resume();
392 61 : }
393 : #else
394 : template<std::size_t I, dispatcher D>
395 : void launch_one(D const& caller_ex)
396 : {
397 : auto runner = make_when_all_runner<I>(
398 : std::move(std::get<I>(*tasks_)), state_);
399 :
400 : auto h = runner.release();
401 : h.promise().state_ = state_;
402 : h.promise().ex_ = caller_ex;
403 :
404 : any_coro ch{h};
405 : state_->runner_handles_[I] = ch;
406 : caller_ex(ch).resume();
407 : }
408 : #endif
409 : };
410 :
411 : /** Compute the result type for when_all.
412 :
413 : Returns void when all tasks are void (P2300 aligned),
414 : otherwise returns a tuple with void types filtered out.
415 : */
416 : template<typename... Ts>
417 : using when_all_result_t = std::conditional_t<
418 : std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
419 : void,
420 : filter_void_tuple_t<Ts...>>;
421 :
422 : /** Helper to extract a single result, returning empty tuple for void.
423 : This is a separate function to work around a GCC-11 ICE that occurs
424 : when using nested immediately-invoked lambdas with pack expansion.
425 : */
426 : template<std::size_t I, typename... Ts>
427 40 : auto extract_single_result(when_all_state<Ts...>& state)
428 : {
429 : using T = std::tuple_element_t<I, std::tuple<Ts...>>;
430 : if constexpr (std::is_void_v<T>)
431 2 : return std::tuple<>();
432 : else
433 38 : return std::make_tuple(std::move(std::get<I>(state.results_)).get());
434 : }
435 :
436 : /** Extract results from state, filtering void types.
437 : */
438 : template<typename... Ts>
439 15 : auto extract_results(when_all_state<Ts...>& state)
440 : {
441 30 : return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
442 15 : return std::tuple_cat(extract_single_result<Is>(state)...);
443 30 : }(std::index_sequence_for<Ts...>{});
444 : }
445 :
446 : } // namespace detail
447 :
448 : /** Wait for all tasks to complete concurrently.
449 :
450 : @par Example
451 : @code
452 : task<void> example() {
453 : auto [a, b] = co_await when_all(
454 : fetch_int(), // task<int>
455 : fetch_string() // task<std::string>
456 : );
457 : }
458 : @endcode
459 :
460 : @param tasks The tasks to execute concurrently.
461 : @return A task yielding a tuple of results (void types filtered out).
462 :
463 : Key features:
464 : @li All child tasks are launched concurrently
465 : @li Results are collected in input order
466 : @li First error is captured; subsequent errors are discarded
467 : @li On error, stop is requested for all siblings
468 : @li Completes only after all children have completed
469 : @li Void tasks do not contribute to the result tuple
470 : @li Properly propagates frame allocators to all child coroutines
471 : */
472 : template<typename... Ts>
473 : [[nodiscard]] task<detail::when_all_result_t<Ts...>>
474 24 : when_all(task<Ts>... tasks)
475 : {
476 : using result_type = detail::when_all_result_t<Ts...>;
477 :
478 : // State is stored in the coroutine frame, using the frame allocator
479 : detail::when_all_state<Ts...> state;
480 :
481 : // Store tasks in the frame
482 : std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
483 :
484 : // Launch all tasks and wait for completion
485 : co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
486 :
487 : // Propagate first exception if any.
488 : // Safe without explicit acquire: capture_exception() is sequenced-before
489 : // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
490 : // last task's decrement that resumes this coroutine.
491 : if(state.first_exception_)
492 : std::rethrow_exception(state.first_exception_);
493 :
494 : // Extract and return results
495 : if constexpr (std::is_void_v<result_type>)
496 : co_return;
497 : else
498 : co_return detail::extract_results(state);
499 48 : }
500 :
501 : // For backwards compatibility and type queries, expose result type computation
502 : template<typename... Ts>
503 : using when_all_result_type = detail::when_all_result_t<Ts...>;
504 :
505 : } // namespace capy
506 : } // namespace boost
507 :
508 : #endif
|