// semaphore standard header

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _SEMAPHORE_
#define _SEMAPHORE_
#include <yvals.h>
#if _STL_COMPILER_PREPROCESSOR

#ifdef _M_CEE_PURE
#error <semaphore> is not supported when compiling with /clr:pure.
#endif // defined(_M_CEE_PURE)

#if !_HAS_CXX20
_EMIT_STL_WARNING(STL4038, "The contents of <semaphore> are available only with C++20 or later.");
#else // ^^^ !_HAS_CXX20 / _HAS_CXX20 vvv

#include <__msvc_chrono.hpp>
#include <atomic>
#include <climits>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

_STD_BEGIN

// Converts a relative timeout into the deadline value that the try_acquire_for() loops later pass to
// __std_atomic_wait_get_remaining_timeout().
//
// _Rel_time: how long the caller is willing to wait; any rep/period is accepted.
// Returns: the result of __std_atomic_wait_get_deadline() for the timeout expressed in whole milliseconds.
template <class _Rep, class _Period>
_NODISCARD unsigned long long _Semaphore_deadline(const chrono::duration<_Rep, _Period>& _Rel_time) {
    using _Unsigned_ms = chrono::duration<unsigned long long, milli>;

    // A zero or negative timeout has already expired. It must not reach the unsigned conversion below:
    // a negative count converted to unsigned long long becomes a huge positive number of milliseconds,
    // i.e. a deadline far in the future instead of one in the past. Writing the test as !(x > 0) also
    // routes a NaN floating-point duration here rather than into the conversion.
    if (!(_Rel_time > chrono::duration<_Rep, _Period>::zero())) {
        return __std_atomic_wait_get_deadline(0ULL);
    }

    // Round up, as _Semaphore_remaining_timeout() does, so that a positive sub-millisecond timeout is not
    // truncated to 0 ms and treated as expired before it has actually elapsed.
    return __std_atomic_wait_get_deadline(chrono::ceil<_Unsigned_ms>(_Rel_time).count());
}

// Computes how many milliseconds remain until _Abs_time according to _Clock.
//
// Returns 0 when _Abs_time has been reached or passed. Otherwise returns the remaining time rounded up to
// whole milliseconds, limited to ten days so that the value always fits in an unsigned long.
template <class _Clock, class _Duration>
_NODISCARD unsigned long _Semaphore_remaining_timeout(const chrono::time_point<_Clock, _Duration>& _Abs_time) {
    // Upper limit for a single returned timeout; checked at compile time to be representable.
    constexpr chrono::milliseconds _Longest_wait{chrono::hours{24 * 10}};
    _STL_INTERNAL_STATIC_ASSERT(_STD in_range<unsigned long>(_Longest_wait.count()));

    const auto _Current_time = _Clock::now();
    if (_Abs_time <= _Current_time) {
        return 0; // deadline already reached
    }

    // Round up so the caller never waits for less than the time that is actually left.
    const auto _Left    = chrono::ceil<chrono::milliseconds>(_Abs_time - _Current_time);
    const auto _Clamped = _Left < _Longest_wait ? _Left : _Longest_wait;
    return static_cast<unsigned long>(_Clamped.count());
}

// Default argument for counting_semaphore's _Least_max_value template parameter: the largest ptrdiff_t value.
inline constexpr ptrdiff_t _Semaphore_max = PTRDIFF_MAX;

// Counting semaphore built on two atomics:
//   _Counter - the semaphore value; acquire operations decrement it with a CAS, release() adds to it.
//   _Waiting - incremented around each blocking wait in _Wait(); release() reads it as an upper bound on the
//              number of blocked threads in order to choose between notify_all() and repeated notify_one().
// _Least_max_value is what max() reports and is the upper limit used by the _STL_VERIFY checks.
_EXPORT_STD template <ptrdiff_t _Least_max_value = _Semaphore_max>
class counting_semaphore {
public:
    static_assert(_Least_max_value >= 0,
        "The least maximum value for a counting_semaphore must be nonnegative (N4950 [thread.sema.cnt]/2).");

    // Returns _Least_max_value. The name is parenthesized so a function-like macro named max cannot expand here.
    _NODISCARD static constexpr ptrdiff_t(max)() noexcept {
        return _Least_max_value;
    }

    // Initializes the counter to _Desired, which is verified to lie in [0, _Least_max_value].
    // _Waiting has no explicit initializer; it relies on atomic's default constructor.
    constexpr explicit counting_semaphore(const ptrdiff_t _Desired) noexcept /* strengthened */
        : _Counter(_Desired) {
        _STL_VERIFY(_Desired >= 0 && _Desired <= _Least_max_value,
            "Precondition: desired >= 0, and desired <= max() (N4950 [thread.sema.cnt]/5)");
    }

    // Not copyable.
    counting_semaphore(const counting_semaphore&)            = delete;
    counting_semaphore& operator=(const counting_semaphore&) = delete;

    // Adds _Update to the counter and notifies waiting threads. A zero _Update returns without touching
    // any state. _Update is taken by non-const value because the notify_one() loop below counts it down.
    void release(ptrdiff_t _Update = 1) noexcept /* strengthened */ {
        if (_Update == 0) {
            return;
        }
        _STL_VERIFY(_Update > 0 && _Update <= _Least_max_value,
            "Precondition: update >= 0, and update <= max() - counter (N4950 [thread.sema.cnt]/8)");

        // We need to notify (wake) at least _Update waiting threads.
        // Errors towards waking more cannot be always avoided, but they are performance issues.
        // Errors towards waking fewer must be avoided, as they are correctness issues.

        // release thread: Increment semaphore counter, then load waiting counter;
        // acquire thread: Increment waiting counter, then load semaphore counter;

        // memory_order_seq_cst for all four operations guarantees that the release thread loads
        // the incremented value, or the acquire thread loads the incremented value, or both, but not neither.
        // memory_order_seq_cst might be superfluous for some hardware mappings of the C++ memory model,
        // but from the point of view of the C++ memory model itself it is needed; weaker orders don't work.

        const ptrdiff_t _Prev = _Counter.fetch_add(_Update);
        // Checked after the fact against the value that was actually replaced; the subtraction form avoids
        // computing _Prev + _Update, which could overflow.
        _STL_VERIFY(_Prev >= 0 && _Update <= _Least_max_value - _Prev,
            "Precondition: update <= max() - counter (N4950 [thread.sema.cnt]/8)");

        const ptrdiff_t _Waiting_upper_bound = _Waiting.load();

        if (_Waiting_upper_bound == 0) {
            // Definitely no one is waiting
        } else if (_Waiting_upper_bound <= _Update) {
            // No more waiting threads than update, can wake everyone.
            _Counter.notify_all();
        } else {
            // Wake at most _Update. Though repeated notify_one() is somewhat less efficient than single notify_all(),
            // the amount of OS calls is still the same; the benefit from trying not to wake unnecessary threads
            // is expected to be greater than the loss on extra calls and atomic operations.
            for (; _Update != 0; --_Update) {
                _Counter.notify_one();
            }
        }
    }

    // Blocks until the counter is observed nonzero and this thread succeeds in decrementing it.
    void acquire() noexcept /* strengthened */ {
        ptrdiff_t _Current = _Counter.load(memory_order_relaxed);
        for (;;) {
            // Wait (without timeout) for as long as the last observed value is 0.
            while (_Current == 0) {
                _Wait(__std_atomic_wait_no_timeout);
                _Current = _Counter.load(memory_order_relaxed);
            }
            _STL_VERIFY(_Current > 0 && _Current <= _Least_max_value,
                "Invariant: counter >= 0, and counter <= max() "
                "possibly caused by preconditions violation (N4950 [thread.sema.cnt]/8)");

            // "happens after release" ordering is provided by this CAS, so loads and waits can be relaxed
            // On failure the CAS stores the value it found into _Current, so the loop re-evaluates it.
            if (_Counter.compare_exchange_weak(_Current, _Current - 1)) {
                return;
            }
        }
    }

    // Makes a single non-blocking attempt to decrement the counter.
    // Returns false if the counter was observed as 0, or if the one weak CAS does not succeed
    // (another thread changed the counter in between, or the weak CAS failed spuriously).
    _NODISCARD_TRY_CHANGE_STATE bool try_acquire() noexcept {
        ptrdiff_t _Current = _Counter.load();
        if (_Current == 0) {
            return false;
        }
        _STL_VERIFY(_Current > 0 && _Current <= _Least_max_value,
            "Invariant: counter >= 0, and counter <= max() "
            "possibly caused by preconditions violation (N4950 [thread.sema.cnt]/8)");

        return _Counter.compare_exchange_weak(_Current, _Current - 1);
    }

    // Like acquire(), but gives up once _Rel_time has elapsed.
    // The deadline is computed once up front; before every wait the remaining timeout is re-derived from it.
    // Returns true if the counter was decremented, false if the remaining timeout reached 0 while the
    // counter was observed as 0.
    template <class _Rep, class _Period>
    _NODISCARD_TRY_CHANGE_STATE bool try_acquire_for(const chrono::duration<_Rep, _Period>& _Rel_time) {
        auto _Deadline     = _Semaphore_deadline(_Rel_time);
        ptrdiff_t _Current = _Counter.load(memory_order_relaxed);
        for (;;) {
            while (_Current == 0) {
                const auto _Remaining_timeout = __std_atomic_wait_get_remaining_timeout(_Deadline);
                if (_Remaining_timeout == 0) {
                    return false;
                }
                _Wait(_Remaining_timeout);
                _Current = _Counter.load(memory_order_relaxed);
            }
            _STL_VERIFY(_Current > 0 && _Current <= _Least_max_value,
                "Invariant: counter >= 0, and counter <= max() "
                "possibly caused by preconditions violation (N4950 [thread.sema.cnt]/8)");

            // "happens after release" ordering is provided by this CAS, so loads and waits can be relaxed
            if (_Counter.compare_exchange_weak(_Current, _Current - 1)) {
                return true;
            }
        }
    }

    // Like acquire(), but gives up once _Clock reports a time at or after _Abs_time.
    // The remaining timeout is recomputed from _Clock::now() (via _Semaphore_remaining_timeout) before
    // every wait. Returns true if the counter was decremented, false on timeout.
    template <class _Clock, class _Duration>
    _NODISCARD_TRY_CHANGE_STATE bool try_acquire_until(const chrono::time_point<_Clock, _Duration>& _Abs_time) {
        static_assert(chrono::is_clock_v<_Clock>, "Clock type required");
        ptrdiff_t _Current = _Counter.load(memory_order_relaxed);
        for (;;) {
            while (_Current == 0) {
                const unsigned long _Remaining_timeout = _Semaphore_remaining_timeout(_Abs_time);
                if (_Remaining_timeout == 0) {
                    return false;
                }
                _Wait(_Remaining_timeout);
                _Current = _Counter.load(memory_order_relaxed);
            }
            _STL_VERIFY(_Current > 0 && _Current <= _Least_max_value,
                "Invariant: counter >= 0, and counter <= max() "
                "possibly caused by preconditions violation (N4950 [thread.sema.cnt]/8)");

            // "happens after release" ordering is provided by this CAS, so loads and waits can be relaxed
            if (_Counter.compare_exchange_weak(_Current, _Current - 1)) {
                return true;
            }
        }
    }

private:
    // Registers the calling thread in _Waiting, re-checks the counter, and only if it is still 0 calls
    // __std_atomic_wait_direct() with the given timeout; then unregisters. The function reports nothing to
    // the caller, so callers reload _Counter afterwards and decide for themselves.
    // NOTE(review): __std_atomic_wait_direct is defined elsewhere; presumably it blocks until _Counter no
    // longer equals _Current, a notification arrives, or the timeout expires - confirm against its definition.
    void _Wait(const unsigned long _Remaining_timeout) noexcept {
        // See the comment in release()
        _Waiting.fetch_add(1);
        ptrdiff_t _Current = _Counter.load();
        if (_Current == 0) {
            __std_atomic_wait_direct(&_Counter, &_Current, sizeof(_Current), _Remaining_timeout);
        }
        _Waiting.fetch_sub(1, memory_order_relaxed);
    }

    atomic<ptrdiff_t> _Counter; // current semaphore value
    atomic<ptrdiff_t> _Waiting; // number of threads currently between the fetch_add and fetch_sub in _Wait()
};

// Specialization for a maximum value of 1. The counter only ever holds 0 or 1, so it is stored in an
// atomic<unsigned char> and taken with an unconditional exchange(0) instead of a CAS loop; no separate
// count of waiting threads is kept.
template <>
class counting_semaphore<1> {
public:
    // Always 1 for this specialization. The name is parenthesized so a function-like macro named max
    // cannot expand here.
    _NODISCARD static constexpr ptrdiff_t(max)() noexcept {
        return 1;
    }

    // Initializes the counter to _Desired, which is verified to be 0 or 1 (no bits other than bit 0 set).
    constexpr explicit counting_semaphore(const ptrdiff_t _Desired) noexcept /* strengthened */
        : _Counter(static_cast<unsigned char>(_Desired)) {
        _STL_VERIFY((_Desired & ~1) == 0, "Precondition: desired >= 0, and desired <= max() "
                                          "(N4950 [thread.sema.cnt]/5)");
    }

    // Not copyable.
    counting_semaphore(const counting_semaphore&)            = delete;
    counting_semaphore& operator=(const counting_semaphore&) = delete;

    // Sets the counter to 1 and notifies one waiter. A zero _Update is a no-op; any other value than 1
    // fails verification.
    void release(const ptrdiff_t _Update = 1) noexcept /* strengthened */ {
        if (_Update == 0) {
            return;
        }
        _STL_VERIFY(_Update == 1, "Precondition: update >= 0, "
                                  "and update <= max() - counter (N4950 [thread.sema.cnt]/8)");
        // TRANSITION, GH-1133: should be memory_order_release
        _Counter.store(1);
        _Counter.notify_one();
    }

    // Blocks until this thread is the one that swaps the counter from 1 to 0.
    void acquire() noexcept /* strengthened */ {
        for (;;) {
            // "happens after release" ordering is provided by this exchange, so loads and waits can be relaxed
            // TRANSITION, GH-1133: should be memory_order_acquire
            unsigned char _Prev = _Counter.exchange(0);
            if (_Prev == 1) {
                break;
            }
            _STL_VERIFY(_Prev == 0, "Invariant: semaphore counter is non-negative and doesn't exceed max(), "
                                    "possibly caused by memory corruption");
            // The counter was 0; wait until it is observed to be something other than 0, then retry.
            _Counter.wait(0, memory_order_relaxed);
        }
    }

    // Makes a single non-blocking attempt: swaps the counter with 0 and reports whether it held a
    // nonzero value (i.e. whether the semaphore was acquired).
    _NODISCARD_TRY_CHANGE_STATE bool try_acquire() noexcept {
        // TRANSITION, GH-1133: should be memory_order_acquire
        const unsigned char _Prev = _Counter.exchange(0);
        _STL_VERIFY((_Prev & ~1) == 0, "Invariant: semaphore counter is non-negative and doesn't exceed max(), "
                                       "possibly caused by memory corruption");
        // Compare rather than reinterpret the unsigned char object as a bool: reading it through a bool
        // lvalue is not a permitted alias. For the valid values 0 and 1 the result is the same.
        return _Prev != 0;
    }

    // Like acquire(), but gives up once _Rel_time has elapsed.
    // The exchange is attempted before the timeout is examined, so even an already-expired timeout gets
    // one acquisition attempt. Returns true if acquired, false once the remaining timeout reaches 0.
    template <class _Rep, class _Period>
    _NODISCARD_TRY_CHANGE_STATE bool try_acquire_for(const chrono::duration<_Rep, _Period>& _Rel_time) {
        auto _Deadline = _Semaphore_deadline(_Rel_time);
        for (;;) {
            // "happens after release" ordering is provided by this exchange, so loads and waits can be relaxed
            // TRANSITION, GH-1133: should be memory_order_acquire
            unsigned char _Prev = _Counter.exchange(0);
            if (_Prev == 1) {
                return true;
            }
            _STL_VERIFY(_Prev == 0, "Invariant: semaphore counter is non-negative and doesn't exceed max(), "
                                    "possibly caused by memory corruption");
            const auto _Remaining_timeout = __std_atomic_wait_get_remaining_timeout(_Deadline);
            if (_Remaining_timeout == 0) {
                return false;
            }
            // _Prev is 0 here and is passed as the value to compare the counter against while waiting.
            __std_atomic_wait_direct(&_Counter, &_Prev, sizeof(_Prev), _Remaining_timeout);
        }
    }

    // Like acquire(), but gives up once _Clock reports a time at or after _Abs_time.
    // The exchange is attempted before the timeout is examined; the remaining timeout is recomputed from
    // _Clock::now() (via _Semaphore_remaining_timeout) on every iteration.
    template <class _Clock, class _Duration>
    _NODISCARD_TRY_CHANGE_STATE bool try_acquire_until(const chrono::time_point<_Clock, _Duration>& _Abs_time) {
        static_assert(chrono::is_clock_v<_Clock>, "Clock type required");
        for (;;) {
            // "happens after release" ordering is provided by this exchange, so loads and waits can be relaxed
            // TRANSITION, GH-1133: should be memory_order_acquire
            unsigned char _Prev = _Counter.exchange(0);
            if (_Prev == 1) {
                return true;
            }
            _STL_VERIFY(_Prev == 0, "Invariant: semaphore counter is non-negative and doesn't exceed max(), "
                                    "possibly caused by memory corruption");

            const unsigned long _Remaining_timeout = _Semaphore_remaining_timeout(_Abs_time);
            if (_Remaining_timeout == 0) {
                return false;
            }

            // _Prev is 0 here and is passed as the value to compare the counter against while waiting.
            __std_atomic_wait_direct(&_Counter, &_Prev, sizeof(_Prev), _Remaining_timeout);
        }
    }

private:
    atomic<unsigned char> _Counter; // 1 when the semaphore is available, 0 when it has been taken
};

// binary_semaphore is an alias for counting_semaphore<1>.
_EXPORT_STD using binary_semaphore = counting_semaphore<1>;

_STD_END

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)
#endif // ^^^ _HAS_CXX20 ^^^

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _SEMAPHORE_
