// condition_variable standard header

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _CONDITION_VARIABLE_
#define _CONDITION_VARIABLE_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR

#ifdef _M_CEE_PURE
#error <condition_variable> is not supported when compiling with /clr:pure.
#endif // defined(_M_CEE_PURE)

#include <__msvc_chrono.hpp>
#include <memory>
#include <mutex>
#include <xthreads.h>
#if _HAS_CXX20
#include <stop_token>
#endif // _HAS_CXX20

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

_STD_BEGIN
template <class _Lock>
struct _NODISCARD _Unlock_guard {
    explicit _Unlock_guard(_Lock& _Mtx_) : _Mtx(_Mtx_) {
        _Mtx.unlock();
    }

    ~_Unlock_guard() noexcept /* terminates */ {
        // relock mutex or terminate()
        // condition_variable_any wait functions are required to terminate if
        // the mutex cannot be relocked;
        // we slam into noexcept here for easier user debugging.
        _Mtx.lock();
    }

    _Unlock_guard(const _Unlock_guard&)            = delete;
    _Unlock_guard& operator=(const _Unlock_guard&) = delete;

private:
    _Lock& _Mtx;
};

_EXPORT_STD class condition_variable_any { // class for waiting for conditions with any kind of mutex
    // Implementation overview (as visible in this class):
    // * _Myptr owns an internal std::mutex through a shared_ptr; _Cnd_storage holds the condition
    //   variable state in-object.
    // * Every wait locks the internal mutex FIRST, then releases the caller's lock, then blocks on the
    //   internal condition variable with the internal mutex. notify_one()/notify_all() also take the
    //   internal mutex, so a notification cannot be issued between "caller's lock released" and
    //   "waiter blocked" (presumably _Cnd_wait/_Cnd_timedwait_for_unchecked release the internal mutex
    //   atomically while blocking, as condition variable primitives usually do -- confirm in xthreads).
    // * After waking, the internal mutex is released BEFORE the caller's lock is reacquired, so the
    //   internal mutex is never held while blocking on the caller's lock.
public:
    // Allocates the internal mutex via make_shared. The condition variable storage is
    // value-initialized by its default member initializer; _Cnd_init_in_situ is only called when
    // _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR is defined.
    condition_variable_any() : _Myptr{_STD make_shared<mutex>()} {
#ifdef _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR
        _Cnd_init_in_situ(_Mycnd());
#endif // ^^^ defined(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR) ^^^
    }

    ~condition_variable_any() noexcept = default;

    condition_variable_any(const condition_variable_any&)            = delete;
    condition_variable_any& operator=(const condition_variable_any&) = delete;

    void notify_one() noexcept { // wake up one waiter
        // hold the internal mutex while signaling; see the class overview for why
        lock_guard<mutex> _Guard{*_Myptr};
        _Cnd_signal(_Mycnd());
    }

    void notify_all() noexcept { // wake up all waiters
        // hold the internal mutex while broadcasting; see the class overview for why
        lock_guard<mutex> _Guard{*_Myptr};
        _Cnd_broadcast(_Mycnd());
    }

    // Blocks until woken. _Lck is unlocked while blocked and is locked again on return.
    // noexcept: a failure to unlock or relock _Lck terminates.
    template <class _Lock>
    void wait(_Lock& _Lck) noexcept /* terminates */ { // wait for signal
        const shared_ptr<mutex> _Ptr = _Myptr; // for immunity to *this destruction
        unique_lock<mutex> _Guard{*_Ptr};
        _Unlock_guard<_Lock> _Unlock_outer{_Lck};
        _Cnd_wait(_Mycnd(), _Ptr->_Mymtx());
        // Release the internal mutex explicitly: _Unlock_outer is destroyed before _Guard (reverse
        // declaration order), so without this the internal mutex would still be held while relocking _Lck.
        _Guard.unlock();
    } // relock _Lck

    // Repeats wait(_Lck) until _Pred() converts to true. _Pred is evaluated with _Lck held.
    template <class _Lock, class _Predicate>
    void wait(_Lock& _Lck, _Predicate _Pred) noexcept(noexcept(static_cast<bool>(_Pred()))) /* strengthened */ {
        // wait for signal and check predicate
        while (!static_cast<bool>(_Pred())) {
            wait(_Lck);
        }
    }

    // Converts the deadline into a duration measured against _Clock::now() (zero if the deadline has
    // already passed) and forwards to wait_for; returns whatever wait_for returns.
    template <class _Lock, class _Clock, class _Duration>
    cv_status wait_until(_Lock& _Lck, const chrono::time_point<_Clock, _Duration>& _Abs_time) {
        // wait until time point
        static_assert(chrono::_Is_clock_v<_Clock>, "Clock type required");
        const auto _Now                  = _Clock::now();
        using _Common_duration           = decltype(_Abs_time - _Now);
        const _Common_duration _Rel_time = (_Abs_time <= _Now) ? _Common_duration::zero() : _Abs_time - _Now;
        return wait_for(_Lck, _Rel_time);
    }

    // Waits until _Pred() is true or the deadline passes.
    // Returns true if the predicate was satisfied; on timeout returns one final evaluation of _Pred().
    // NOTE(review): the clock check here uses chrono::is_clock_v and only under _HAS_CXX20, while the
    // non-predicate overload checks chrono::_Is_clock_v unconditionally -- confirm whether that
    // difference is intentional.
    template <class _Lock, class _Clock, class _Duration, class _Predicate>
    bool wait_until(_Lock& _Lck, const chrono::time_point<_Clock, _Duration>& _Abs_time, _Predicate _Pred) {
        // wait for signal with timeout and check predicate
#if _HAS_CXX20
        static_assert(chrono::is_clock_v<_Clock>, "Clock type required");
#endif // _HAS_CXX20
        while (!_Pred()) {
            if (wait_until(_Lck, _Abs_time) == cv_status::timeout) {
                return _Pred();
            }
        }

        return true;
    }

    // Waits for at most _Rel_time.
    // Returns cv_status::timeout only when the underlying wait did not report success, the requested
    // duration was not clamped, and steady_clock has reached the target time; every other wake-up is
    // reported as cv_status::no_timeout (callers must tolerate spurious wake-ups).
    template <class _Lock, class _Rep, class _Period>
    cv_status wait_for(_Lock& _Lck, const chrono::duration<_Rep, _Period>& _Rel_time) { // wait for duration
        if (_Rel_time <= chrono::duration<_Rep, _Period>::zero()) {
            // Nothing to wait for: the guard's constructor and destructor still unlock and then
            // relock _Lck once before reporting a timeout.
            _Unlock_guard<_Lock> _Unlock_outer{_Lck};
            (void) _Unlock_outer;
            return cv_status::timeout;
        }

        // The underlying wait takes a count (milliseconds, judging by the helper's name);
        // _Clamped records whether the requested duration had to be shortened to fit.
        const auto _Rel_time_ms = _Clamped_rel_time_ms_count(_Rel_time);

        // wait for signal with timeout
        const shared_ptr<mutex> _Ptr = _Myptr; // for immunity to *this destruction
        const auto _Start            = _CHRONO steady_clock::now();
        const auto _Target           = _Start + _Rel_time;

        // Same locking protocol as wait(_Lck): internal mutex first, then release the caller's lock.
        unique_lock<mutex> _Guard{*_Ptr};
        _Unlock_guard<_Lock> _Unlock_outer{_Lck};
        const _Thrd_result _Res = _Cnd_timedwait_for_unchecked(_Mycnd(), _Ptr->_Mymtx(), _Rel_time_ms._Count);
        _Guard.unlock(); // drop the internal mutex before _Unlock_outer relocks _Lck on return

        if (_Res == _Thrd_result::_Success) {
            return cv_status::no_timeout; // Real or WinAPI-internal spurious wake
        } else if (_Rel_time_ms._Clamped) {
            return cv_status::no_timeout; // Spuriously wake due to clamped timeout
        } else if (_CHRONO steady_clock::now() < _Target) {
            return cv_status::no_timeout; // Spuriously wake due to premature timeout
        } else {
            return cv_status::timeout;
        }
    }

    // Converts _Rel_time to an absolute time with _To_absolute_time and forwards to the
    // predicate overload of wait_until; returns its result.
    template <class _Lock, class _Rep, class _Period, class _Predicate>
    bool wait_for(_Lock& _Lck, const chrono::duration<_Rep, _Period>& _Rel_time, _Predicate _Pred) {
        // wait for signal with timeout and check predicate
        return wait_until(_Lck, _To_absolute_time(_Rel_time), _STD move(_Pred));
    }

#if _HAS_CXX20
private:
    // Callable stored in a stop_callback by the interruptible waits below:
    // invoking it calls notify_all() on the condition_variable_any it points to.
    struct _Cv_any_notify_all {
        condition_variable_any* _This;

        explicit _Cv_any_notify_all(condition_variable_any* _This_) noexcept : _This{_This_} {}

        _Cv_any_notify_all(const _Cv_any_notify_all&)            = delete;
        _Cv_any_notify_all& operator=(const _Cv_any_notify_all&) = delete;

        void operator()() const noexcept {
            _This->notify_all();
        }
    };

public:
    // Interruptible wait: blocks until _Pred() is true or a stop is requested on _Stoken.
    // Returns true as soon as _Pred() is true; once a stop request is observed, returns the value of
    // one more evaluation of _Pred().
    template <class _Lock, class _Predicate>
    bool wait(_Lock& _Lck, stop_token _Stoken, _Predicate _Pred)
        noexcept(noexcept(static_cast<bool>(_Pred()))) /* strengthened */ {
        // TRANSITION, ABI: Due to the unsynchronized delivery of notify_all by _Stoken,
        // this implementation cannot tolerate *this destruction while an interruptible wait
        // is outstanding. A future ABI should store both the internal CV and internal mutex
        // in the reference counted block to allow this.
        stop_callback<_Cv_any_notify_all> _Cb{_Stoken, this};
        for (;;) {
            if (_Pred()) {
                return true;
            }

            // stop_requested() is checked with the internal mutex held; the stop callback's
            // notify_all() needs that same mutex, so it cannot run between this check and the wait below.
            unique_lock<mutex> _Guard{*_Myptr};
            if (_Stoken.stop_requested()) {
                _Guard.unlock();
                return _Pred();
            }

            _Unlock_guard<_Lock> _Unlock_outer{_Lck};
            _Cnd_wait(_Mycnd(), _Myptr->_Mymtx());
            _Guard.unlock(); // internal mutex released before _Unlock_outer relocks _Lck at the end of the iteration
        } // relock
    }

    // Interruptible timed wait: blocks until _Pred() is true, a stop is requested on _Stoken,
    // or _Clock reaches _Abs_time. Returns true as soon as _Pred() is true; otherwise returns the value
    // of a final evaluation of _Pred() after the loop ends.
    template <class _Lock, class _Clock, class _Duration, class _Predicate>
    bool wait_until(
        _Lock& _Lck, stop_token _Stoken, const chrono::time_point<_Clock, _Duration>& _Abs_time, _Predicate _Pred) {
        static_assert(chrono::is_clock_v<_Clock>, "Clock type required");
        stop_callback<_Cv_any_notify_all> _Cb{_Stoken, this};
        for (;;) {
            if (_Pred()) {
                return true;
            }

            unique_lock<mutex> _Guard{*_Myptr};
            if (_Stoken.stop_requested()) {
                break; // _Lck has not been released yet; _Guard's destructor releases the internal mutex
            }

            // Ownership of the internal mutex is moved into an object declared AFTER _Unlock_outer.
            // Locals are destroyed in reverse declaration order, so on every exit from this scope
            // (including the break below) the internal mutex is released before _Lck is relocked.
            _Unlock_guard<_Lock> _Unlock_outer{_Lck};
            unique_lock<mutex> _Guard_unlocks_before_locking_outer{_STD move(_Guard)};

            const auto _Now = _Clock::now();
            if (_Now >= _Abs_time) {
                break;
            }

            // The result of the timed wait is deliberately discarded: the next iteration re-evaluates
            // the predicate, the stop request, and the deadline.
            const unsigned long _Rel_ms_count = _Clamped_rel_time_ms_count(_Abs_time - _Now)._Count;
            (void) _Cnd_timedwait_for_unchecked(_Mycnd(), _Myptr->_Mymtx(), _Rel_ms_count);
            _Guard_unlocks_before_locking_outer.unlock();
        } // relock

        return _Pred();
    }

    // Converts _Rel_time to an absolute time with _To_absolute_time and forwards to the
    // interruptible wait_until; returns its result.
    template <class _Lock, class _Rep, class _Period, class _Predicate>
    bool wait_for(_Lock& _Lck, stop_token _Stoken, const chrono::duration<_Rep, _Period>& _Rel_time, _Predicate _Pred) {
        return wait_until(_Lck, _STD move(_Stoken), _To_absolute_time(_Rel_time), _STD move(_Pred));
    }
#endif // _HAS_CXX20

private:
    // internal mutex, shared so that a waiter's local copy keeps it alive
    shared_ptr<mutex> _Myptr;

    // in-object storage for the internal condition variable
    _Cnd_internal_imp_t _Cnd_storage{};

    // returns the handle passed to the _Cnd_* functions: the address of _Cnd_storage
    _NODISCARD _Cnd_t _Mycnd() noexcept {
        return &_Cnd_storage;
    }
};

// Arranges for waiters on _Cnd to be notified when the calling thread exits.
// _Cnd: the condition variable to register.
// _Lck: taken by value; the C++ standard requires the caller to pass a lock it currently owns.
// The work is delegated entirely to condition_variable::_Register, which is declared outside this
// header section. NOTE(review): the meaning of the nullptr second argument is not visible here --
// confirm against _Register's declaration.
_EXPORT_STD inline void notify_all_at_thread_exit(condition_variable& _Cnd, unique_lock<mutex> _Lck) noexcept
/* strengthened */ {
    // register _Cnd for release at thread exit
    _Cnd._Register(_Lck, nullptr);
}
_STD_END
#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)
#endif // _STL_COMPILER_PREPROCESSOR
#endif // _CONDITION_VARIABLE_
