// barrier standard header

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _BARRIER_
#define _BARRIER_
#include <yvals.h>
#if _STL_COMPILER_PREPROCESSOR

#ifdef _M_CEE_PURE
#error <barrier> is not supported when compiling with /clr:pure.
#endif // defined(_M_CEE_PURE)

#if !_HAS_CXX20
_EMIT_STL_WARNING(STL4038, "The contents of <barrier> are available only with C++20 or later.");
#else // ^^^ !_HAS_CXX20 / _HAS_CXX20 vvv

#include <atomic>
#include <climits>
#include <type_traits>
#include <xmemory>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

_STD_BEGIN

// Default completion function type for barrier: a callable whose call operator takes no arguments,
// cannot throw, and performs no work.
struct _No_completion_function {
    void operator()() noexcept {
        // deliberately empty
    }
};

// Forward declaration of barrier, carrying the default template argument (_No_completion_function).
// It precedes _Arrival_token so that the token class can name barrier<_Completion_function> as a friend.
_EXPORT_STD template <class _Completion_function = _No_completion_function>
class barrier;

// Bit layout shared by barrier's counter word and by arrival tokens: the bits below _Barrier_value_shift carry
// the phase indicator, the bits at and above it carry the count.
inline constexpr ptrdiff_t _Barrier_value_shift = 1;
// Amount by which the counter word changes for a single arrival.
inline constexpr ptrdiff_t _Barrier_value_step = ptrdiff_t{1} << _Barrier_value_shift;
// Selects the phase indicator bit.
inline constexpr ptrdiff_t _Barrier_arrival_token_mask = _Barrier_value_step - 1;
// Selects everything except the phase indicator bit.
inline constexpr ptrdiff_t _Barrier_value_mask = ~_Barrier_arrival_token_mask;
// Value stored in an arrival token that has been moved from or already consumed.
inline constexpr ptrdiff_t _Barrier_invalid_token = 0;
// Largest expected count that still fits once shifted into the counter word.
inline constexpr ptrdiff_t _Barrier_max = PTRDIFF_MAX >> _Barrier_value_shift;

// Token type exposed as barrier::arrival_token. It wraps a single ptrdiff_t and can only be created from a raw
// value by barrier<_Completion_function> (the converting constructor is private and barrier is a friend).
// Declaring the move operations suppresses the implicit copy operations, so tokens can be moved but not copied.
template <class _Completion_function>
class _Arrival_token {
public:
    // Takes over the value held by _Other and leaves _Other holding _Barrier_invalid_token.
    _Arrival_token(_Arrival_token&& _Other) noexcept : _Value(_Other._Value) {
        _Other._Value = _Barrier_invalid_token;
    }

    // Replaces this token's value with the one held by _Other, then marks _Other as invalid.
    // Note that assigning a token to itself therefore leaves it holding _Barrier_invalid_token.
    _Arrival_token& operator=(_Arrival_token&& _Other) noexcept {
        const ptrdiff_t _Taken = _Other._Value;
        _Value                 = _Taken;
        _Other._Value          = _Barrier_invalid_token;
        return *this;
    }

private:
    friend barrier<_Completion_function>;

    // Only the owning barrier specialization may mint a token from a raw value.
    explicit _Arrival_token(const ptrdiff_t _Value_) noexcept : _Value(_Value_) {}

    ptrdiff_t _Value;
};

// Implementation of std::barrier.
//
// State is kept in _Counter_t, stored next to the completion function in a _Compressed_pair:
//   * _Current: the count still outstanding for the current phase, shifted left by _Barrier_value_shift, with the
//     low bit (_Barrier_arrival_token_mask) alternating between phases.
//   * _Total: the shifted count from which _Current is reloaded when a phase completes; arrive_and_drop() lowers it.
// Both start at (_Expected << _Barrier_value_shift). The thread whose arrival brings the count bits of _Current to
// zero runs _Completion(), which invokes the completion function, starts the next phase, and wakes all waiters.
_EXPORT_STD template <class _Completion_function>
class barrier {
public:
    // The completion function must be invocable without throwing. When noexcept is not part of the function type
    // (__cpp_noexcept_function_type is not defined), function and pointer-to-function types are accepted as well.
    static_assert(
#ifndef __cpp_noexcept_function_type
        is_function_v<remove_pointer_t<_Completion_function>> ||
#endif // !defined(__cpp_noexcept_function_type)
            is_nothrow_invocable_v<_Completion_function&>,
        "N4950 [thread.barrier.class]/5: is_nothrow_invocable_v<CompletionFunction&> shall be true");

    using arrival_token = _Arrival_token<_Completion_function>;

    // Constructs a barrier expecting _Expected arrivals per phase and stores _Fn as the completion function.
    // The precondition 0 <= _Expected <= max() is verified in the constructor body, i.e. after the members have
    // already been initialized from the shifted value.
    constexpr explicit barrier(
        const ptrdiff_t _Expected, _Completion_function _Fn = _Completion_function()) noexcept /* strengthened */
        : _Val(_One_then_variadic_args_t{}, _STD move(_Fn), _Expected << _Barrier_value_shift) {
        _STL_VERIFY(_Expected >= 0 && _Expected <= (max) (),
            "Precondition: expected >= 0 and expected <= max() (N4950 [thread.barrier.class]/9)");
    }

    barrier(const barrier&)            = delete;
    barrier& operator=(const barrier&) = delete;

    // Returns the largest supported expected count (_Barrier_max). The parentheses around the name keep a
    // function-like max macro from being expanded here.
    _NODISCARD static constexpr ptrdiff_t(max)() noexcept {
        return _Barrier_max;
    }

    // Registers _Update arrivals for the current phase without blocking and returns a token for that phase.
    // _Update must satisfy 0 < _Update <= max(), and must not exceed the count still outstanding (both verified).
    _NODISCARD_BARRIER_TOKEN arrival_token arrive(ptrdiff_t _Update = 1) noexcept /* strengthened */ {
        _STL_VERIFY(_Update > 0 && _Update <= (max) (), "Precondition: update > 0 (N4950 [thread.barrier.class]/12)");
        // Convert the arrival count to counter-word units so that the phase bit is left untouched.
        _Update <<= _Barrier_value_shift;
        // fetch_sub returns the value held before the subtraction; subtracting _Update again yields the new value.
        // TRANSITION, GH-1133: should be memory_order_release
        ptrdiff_t _Current = _Val._Myval2._Current.fetch_sub(_Update) - _Update;
        _STL_VERIFY(_Current >= 0, "Precondition: update is less than or equal to the expected count "
                                   "for the current barrier phase (N4950 [thread.barrier.class]/12)");
        // Count bits reached zero: this call made the final arrival of the phase, so it completes the phase.
        if ((_Current & _Barrier_value_mask) == 0) {
            // TRANSITION, GH-1133: should have this fence:
            // atomic_thread_fence(memory_order_acquire);
            _Completion(_Current);
        }
        // The phase bit below comes from the local _Current, i.e. the phase this call arrived at, even if
        // _Completion() has already switched the barrier to the next phase.
        // Embedding this into the token to provide an additional correctness check that the token is from the same
        // barrier and wasn't used. All bits of this fit, as barrier should be aligned to at least the size of an
        // atomic counter.
        return arrival_token{(_Current & _Barrier_arrival_token_mask) | reinterpret_cast<intptr_t>(this)};
    }

    // Blocks until the phase identified by _Arrival has completed. The token is consumed: its value is reset to
    // _Barrier_invalid_token, so passing it again fails the verification below.
    void wait(arrival_token&& _Arrival) const noexcept /* strengthened */ {
        // Everything in the token other than the phase bit must be this barrier's address.
        _STL_VERIFY((_Arrival._Value & _Barrier_value_mask) == reinterpret_cast<intptr_t>(this),
            "Preconditions: arrival is associated with the phase synchronization point for the current phase "
            "or the immediately preceding phase of the same barrier object (N4950 [thread.barrier.class]/19)");
        const ptrdiff_t _Arrival_value = _Arrival._Value & _Barrier_arrival_token_mask;
        _Arrival._Value                = _Barrier_invalid_token;
        for (;;) {
            // TRANSITION, GH-1133: should be memory_order_acquire
            const ptrdiff_t _Current = _Val._Myval2._Current.load();
            _STL_VERIFY(_Current >= 0, "Invariant counter >= 0, possibly caused by preconditions violation "
                                       "(N4950 [thread.barrier.class]/12)");
            // A phase bit different from the token's means the token's phase has completed.
            if ((_Current & _Barrier_arrival_token_mask) != _Arrival_value) {
                break;
            }
            // Block until the counter no longer holds the value just observed, then re-check. Other arrivals in
            // the same phase also change the counter, hence the loop.
            _Val._Myval2._Current.wait(_Current, memory_order_relaxed);
        }
    }

    // Registers one arrival and blocks until the current phase completes. No token object is created; the phase
    // bit is kept in a local instead.
    void arrive_and_wait() noexcept /* strengthened */ {
        // TRANSITION, GH-1133: should be memory_order_acq_rel
        ptrdiff_t _Current       = _Val._Myval2._Current.fetch_sub(_Barrier_value_step) - _Barrier_value_step;
        const ptrdiff_t _Arrival = _Current & _Barrier_arrival_token_mask;
        _STL_VERIFY(_Current >= 0, "Precondition: update is less than or equal to the expected count "
                                   "for the current barrier phase (N4950 [thread.barrier.class]/12)");
        // Final arrival of the phase: complete it here and return without waiting.
        if ((_Current & _Barrier_value_mask) == 0) {
            _Completion(_Current);
            return;
        }

        for (;;) {
            // Block while the counter still equals the last value seen, then reload and compare phase bits.
            _Val._Myval2._Current.wait(_Current, memory_order_relaxed);
            // TRANSITION, GH-1133: should be memory_order_acquire
            _Current = _Val._Myval2._Current.load();
            _STL_VERIFY(_Current >= 0, "Invariant counter >= 0, possibly caused by preconditions violation "
                                       "(N4950 [thread.barrier.class]/12)");
            if ((_Current & _Barrier_arrival_token_mask) != _Arrival) {
                break;
            }
        }
    }

    // Lowers the count used for subsequent phases by one and registers one arrival for the current phase.
    // _Total is decremented first; the reduced value takes effect when _Completion() next reloads _Current from it.
    void arrive_and_drop() noexcept /* strengthened */ {
        const ptrdiff_t _Rem_count =
            _Val._Myval2._Total.fetch_sub(_Barrier_value_step, memory_order_relaxed) - _Barrier_value_step;
        _STL_VERIFY(_Rem_count >= 0, "Precondition: The expected count for the current barrier phase "
                                     "is greater than zero (N4950 [thread.barrier.class]/24) "
                                     "(checked initial expected count, which is not less than the current)");
        // The returned token is intentionally discarded.
        (void) arrive(1);
    }

private:
    // Ends the current phase. Called by the thread whose arrival brought the count bits of _Current to zero;
    // _Current is the counter value that thread computed (count bits zero, plus the finishing phase's bit).
    void _Completion(const ptrdiff_t _Current) noexcept {
        const ptrdiff_t _Rem_count = _Val._Myval2._Total.load(memory_order_relaxed);
        _STL_VERIFY(_Rem_count >= 0, "Invariant: initial expected count less than zero, "
                                     "possibly caused by preconditions violation "
                                     "(N4950 [thread.barrier.class]/24)");
        // Invoke the completion function before the new phase is published and waiters are woken.
        _Val._Get_first()();
        // New counter word: count bits reloaded from _Total, phase bit flipped relative to the finished phase.
        const ptrdiff_t _New_phase_count = _Rem_count | ((_Current + 1) & _Barrier_arrival_token_mask);
        // TRANSITION, GH-1133: should be memory_order_release
        _Val._Myval2._Current.store(_New_phase_count);
        _Val._Myval2._Current.notify_all();
    }

    // The two atomic counters described at the top of the class; both start from the same shifted initial value.
    struct _Counter_t {
        constexpr explicit _Counter_t(ptrdiff_t _Initial) : _Current(_Initial), _Total(_Initial) {}
        // wait(arrival_token&&) accepts a token from the current phase or the immediately preceding phase; this means
        // we can track which phase is the current phase using 1 bit which alternates between each phase. For this
        // purpose we use the low order bit of _Current.
        atomic<ptrdiff_t> _Current;
        atomic<ptrdiff_t> _Total;
    };

    // First element: the completion function; second element (_Myval2): the counters.
    _Compressed_pair<_Completion_function, _Counter_t> _Val;
};

_STD_END

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)
#endif // ^^^ _HAS_CXX20 ^^^

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _BARRIER_
