// stop_token standard header

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _STOP_TOKEN_
#define _STOP_TOKEN_
#include <yvals.h>
#if _STL_COMPILER_PREPROCESSOR

#if !_HAS_CXX20
_EMIT_STL_WARNING(STL4038, "The contents of <stop_token> are available only with C++20 or later.");
#else // ^^^ !_HAS_CXX20 / _HAS_CXX20 vvv

#include <atomic>
#include <xmemory>
#include <xthreads.h>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

_STD_BEGIN
// Tag type selecting the stop_source constructor that creates a source with no associated stop state
// (stop_source(nostopstate_t) below initializes its state pointer to null).
_EXPORT_STD struct nostopstate_t {
    // explicit, so the tag must be named rather than being created implicitly from `{}`
    explicit nostopstate_t() = default;
};

// The tag value callers pass to stop_source's nostopstate_t constructor.
_EXPORT_STD inline constexpr nostopstate_t nostopstate{};

struct _Stop_state;
_EXPORT_STD class stop_token;

// Type-erased base class of stop_callback. Each object is an intrusive, doubly linked list node that can be
// threaded onto a _Stop_state's _Callbacks list, together with the function pointer used to run the derived
// class's callable.
class _Stop_callback_base {
    friend _Stop_state;

private:
    // signature of the invoker; it receives a pointer to this node
    using _Callback_fn = void(__cdecl*)(_Stop_callback_base*) _NOEXCEPT_FNPTR;

public:
    // stores the invoker; the node starts out unattached (_Parent, _Next, and _Prev are null)
    explicit _Stop_callback_base(const _Callback_fn _Fn_) noexcept : _Fn{_Fn_} {}

    // not copyable: neighboring nodes and the list head refer to this node by address
    _Stop_callback_base(const _Stop_callback_base&)            = delete;
    _Stop_callback_base& operator=(const _Stop_callback_base&) = delete;

protected:
    // if _Token is _Stop_requested, calls the callback;
    // otherwise, inserts *this into the callback list if stop is possible
    // (the rvalue overload hands the token's reference on the stop state to *this instead of adding a new one)
    inline void _Attach(const stop_token& _Token) noexcept;
    inline void _Attach(stop_token&& _Token) noexcept;

    // if *this is in a callback list, removes it
    inline void _Detach() noexcept;

private:
    // shared implementation of both _Attach overloads; when _Transfer_ownership is true the parameter is a
    // reference to the token's state pointer, which is nulled out if *this is inserted into the list
    template <bool _Transfer_ownership>
    void _Do_attach(conditional_t<_Transfer_ownership, _Stop_state*&, _Stop_state* const> _State) noexcept;

protected:
    _Stop_state* _Parent       = nullptr; // the state whose list *this was inserted into; null if never inserted
    _Stop_callback_base* _Next = nullptr; // following node in the list (the previous list head at insertion time)
    _Stop_callback_base* _Prev = nullptr; // preceding node in the list; null for the head node
    _Callback_fn _Fn; // invoker, called with this node by _Stop_state::_Request_stop and by _Do_attach
};

// Heap-allocated state shared by stop_source, stop_token, and attached stop callbacks.
// It is deleted by whichever owner decrements _Stop_tokens to zero.
struct _Stop_state {
    // number of token-style references (stop_tokens and inserted callbacks),
    atomic<uint32_t> _Stop_tokens  = 1; // plus one shared by all stop_sources
    // each stop_source adds 2, so the count of sources is (_Stop_sources >> 1),
    atomic<uint32_t> _Stop_sources = 2; // plus the low order bit is the stop requested bit
    // head of the intrusive list of registered callbacks; _Lock_and_load / _Store_and_unlock bracket every
    // access to the list
    _Locked_pointer<_Stop_callback_base> _Callbacks;
    // the callback that _Request_stop is about to run (or null once the list is drained);
    // stored only with relaxed ordering while the _Callbacks lock is held
    // (atomic just to get wait/notify support)
    // NOTE(review): _Stop_callback_base::_Detach reads this with memory_order_acquire, not relaxed
    atomic<const _Stop_callback_base*> _Current_callback = nullptr;
    // id of the thread that won the race in _Request_stop; _Detach compares it with the current thread id
    _Thrd_id_t _Stopping_thread                          = 0;

    // true once the stop requested bit has been set
    _NODISCARD bool _Stop_requested() const noexcept {
        return (_Stop_sources.load() & uint32_t{1}) != 0;
    }

    // true while any stop_source remains or stop has already been requested (i.e. the word is nonzero)
    _NODISCARD bool _Stop_possible() const noexcept {
        return _Stop_sources.load() != 0;
    }

    _NODISCARD bool _Request_stop() noexcept {
        // Attempts to request stop and call callbacks, returns whether request was successful
        if ((_Stop_sources.fetch_or(uint32_t{1}) & uint32_t{1}) != 0) {
            // another thread already requested
            return false;
        }

        // only the single thread that set the bit gets here
        _Stopping_thread = _Thrd_id();
        for (;;) {
            // pop one callback per iteration while holding the lock
            auto _Head = _Callbacks._Lock_and_load();
            // publish which callback is about to run and wake any _Detach waiting on the previous one
            _Current_callback.store(_Head, memory_order_relaxed);
            _Current_callback.notify_all();
            if (_Head == nullptr) {
                // list drained
                _Callbacks._Store_and_unlock(nullptr);
                return true;
            }

            // unlink _Head; afterwards both of its links are null, which _Detach uses to tell that the node
            // has been removed from the list
            const auto _Next = _STD exchange(_Head->_Next, nullptr);
            _STL_INTERNAL_CHECK(_Head->_Prev == nullptr);
            if (_Next != nullptr) {
                _Next->_Prev = nullptr;
            }

            _Callbacks._Store_and_unlock(_Next); // unlock before running _Head so other registrations
                                                 // can detach without blocking on the callback

            _Head->_Fn(_Head); // might destroy *_Head
        }
    }
};

_EXPORT_STD class stop_source;

// Observer handle for a shared stop state. A stop_token whose _State is non-null owns one unit of
// _Stop_state::_Stop_tokens; the owner that brings that count down to zero deletes the state.
_EXPORT_STD class stop_token {
    friend stop_source;
    friend _Stop_callback_base;

public:
    // creates a token that is not associated with any stop state
    stop_token() noexcept : _State{nullptr} {}

    // shares _Other's stop state (if any) and takes an additional token reference on it
    stop_token(const stop_token& _Other) noexcept : _State{_Other._State} {
        if (_State != nullptr) {
            _State->_Stop_tokens.fetch_add(1, memory_order_relaxed);
        }
    }

    // takes over _Other's reference; _Other is left without a stop state
    stop_token(stop_token&& _Other) noexcept : _State{_STD exchange(_Other._State, nullptr)} {}

    stop_token& operator=(const stop_token& _Other) noexcept {
        // copy-and-swap: _Copy ends up with the old state and releases it on destruction
        stop_token _Copy{_Other};
        this->swap(_Copy);
        return *this;
    }

    stop_token& operator=(stop_token&& _Other) noexcept {
        // move-and-swap: _Other is emptied, and _Taken releases the state previously held by *this
        stop_token _Taken{_STD move(_Other)};
        this->swap(_Taken);
        return *this;
    }

    // releases this token's reference; deletes the stop state if it was the last one
    ~stop_token() {
        if (_State == nullptr) {
            return;
        }

        if (_State->_Stop_tokens.fetch_sub(1, memory_order_acq_rel) == 1) {
            delete _State;
        }
    }

    // exchanges the associated stop states; reference counts are unaffected
    void swap(stop_token& _Other) noexcept {
        _STD swap(_State, _Other._State);
    }

    // true if there is an associated stop state and stop has been requested on it
    _NODISCARD bool stop_requested() const noexcept {
        if (_State == nullptr) {
            return false;
        }

        return _State->_Stop_requested();
    }

    // true if there is an associated stop state and that state reports that stop is possible
    _NODISCARD bool stop_possible() const noexcept {
        if (_State == nullptr) {
            return false;
        }

        return _State->_Stop_possible();
    }

    // two tokens are equal when they refer to the same stop state (or both refer to none)
    _NODISCARD friend bool operator==(const stop_token& _Lhs, const stop_token& _Rhs) noexcept = default;

    friend void swap(stop_token& _Lhs, stop_token& _Rhs) noexcept {
        _Lhs.swap(_Rhs);
    }

private:
    // adopts an existing token reference on _State_ (does not increment the count)
    explicit stop_token(_Stop_state* const _State_) : _State{_State_} {}

    _Stop_state* _State;
};

// Owner handle used to request stop on a shared stop state. Each stop_source with a non-null _State contributes
// 2 to _Stop_state::_Stop_sources (bit 0 of that word is the stop requested flag); all stop_sources together hold
// a single unit of _Stop_state::_Stop_tokens.
_EXPORT_STD class stop_source {
public:
    // allocates a fresh stop state
    stop_source() : _State{new _Stop_state} {}

    // creates a source with no stop state
    explicit stop_source(nostopstate_t) noexcept : _State{nullptr} {}

    // shares _Other's stop state (if any) and registers one more source on it
    stop_source(const stop_source& _Other) noexcept : _State{_Other._State} {
        if (_State != nullptr) {
            _State->_Stop_sources.fetch_add(2, memory_order_relaxed);
        }
    }

    // takes over _Other's registration; _Other is left without a stop state
    stop_source(stop_source&& _Other) noexcept : _State{_STD exchange(_Other._State, nullptr)} {}

    stop_source& operator=(const stop_source& _Other) noexcept {
        // copy-and-swap: _Copy ends up with the old state and releases it on destruction
        stop_source _Copy{_Other};
        this->swap(_Copy);
        return *this;
    }

    stop_source& operator=(stop_source&& _Other) noexcept {
        // move-and-swap: _Other is emptied, and _Taken releases the state previously held by *this
        stop_source _Taken{_STD move(_Other)};
        this->swap(_Taken);
        return *this;
    }

    // unregisters this source; the last source also drops the token reference shared by all sources,
    // deleting the stop state if nothing else refers to it
    ~stop_source() {
        if (_State == nullptr) {
            return;
        }

        const auto _Sources_before = _State->_Stop_sources.fetch_sub(2, memory_order_acq_rel) >> 1;
        if (_Sources_before != 1) {
            return; // other stop_sources still share the state
        }

        if (_State->_Stop_tokens.fetch_sub(1, memory_order_acq_rel) == 1) {
            delete _State;
        }
    }

    // exchanges the associated stop states; reference counts are unaffected
    void swap(stop_source& _Other) noexcept {
        _STD swap(_State, _Other._State);
    }

    // returns a token observing this source's stop state (an empty token if there is none)
    _NODISCARD stop_token get_token() const noexcept {
        if (_State != nullptr) {
            // the returned token adopts this reference
            _State->_Stop_tokens.fetch_add(1, memory_order_relaxed);
        }

        return stop_token{_State};
    }

    // true if there is an associated stop state and stop has been requested on it
    _NODISCARD bool stop_requested() const noexcept {
        if (_State == nullptr) {
            return false;
        }

        return _State->_Stop_requested();
    }

    // true if there is an associated stop state
    _NODISCARD bool stop_possible() const noexcept {
        return _State != nullptr;
    }

    // requests stop; returns true only if there is a stop state and this call was the one that made the request
    bool request_stop() noexcept {
        if (_State == nullptr) {
            return false;
        }

        return _State->_Request_stop();
    }

    // two sources are equal when they refer to the same stop state (or both refer to none)
    _NODISCARD friend bool operator==(const stop_source& _Lhs, const stop_source& _Rhs) noexcept = default;

    friend void swap(stop_source& _Lhs, stop_source& _Rhs) noexcept {
        _Lhs.swap(_Rhs);
    }

private:
    _Stop_state* _State;
};

// Shared implementation of both _Attach overloads.
// If stop was already requested, runs the callback on the calling thread; otherwise, if stop is still possible,
// pushes *this onto the front of the state's callback list. Does nothing for a null state.
// _Transfer_ownership == true:  _State_raw refers to the token's own state pointer; on insertion it is set to null
//                               so that *this (through _Parent) takes over the token's reference.
// _Transfer_ownership == false: _State_raw is a copy of the pointer; on insertion _Stop_tokens is incremented so
//                               that *this holds its own reference.
template <bool _Transfer_ownership>
void _Stop_callback_base::_Do_attach(
    conditional_t<_Transfer_ownership, _Stop_state*&, _Stop_state* const> _State_raw) noexcept {
    const auto _State = _State_raw; // avoid an indirection in all of the below
    if (_State == nullptr) {
        return;
    }

    // fast path check if the state is already known
    auto _Local_sources = _State->_Stop_sources.load();
    if ((_Local_sources & uint32_t{1}) != 0) {
        // stop already requested
        _Fn(this);
        return;
    }

    if (_Local_sources == 0) {
        return; // stop not possible
    }

    // fast path doesn't know, so try to insert
    auto _Head = _State->_Callbacks._Lock_and_load();
    // recheck the state in case it changed while we were waiting to acquire the lock
    _Local_sources = _State->_Stop_sources.load();
    if ((_Local_sources & uint32_t{1}) != 0) {
        // stop already requested
        // (restore the head unchanged and release the lock before running the callback)
        _State->_Callbacks._Store_and_unlock(_Head);
        _Fn(this);
        return;
    }

    if (_Local_sources != 0) {
        // stop possible, do the insert
        _Parent = _State;
        _Next   = _Head;
        if constexpr (_Transfer_ownership) {
            // *this adopts the reference the token held
            _State_raw = nullptr;
        } else {
            // *this takes its own reference; _Detach releases it
            _State->_Stop_tokens.fetch_add(1, memory_order_relaxed);
        }

        if (_Head != nullptr) {
            _Head->_Prev = this;
        }

        _Head = this;
    }

    // stores either the new head (*this) or, if stop became impossible, the unchanged head
    _State->_Callbacks._Store_and_unlock(_Head);
}

// Attaches to the stop state observed by _Token without modifying the token; if *this is inserted into the
// callback list, _Do_attach<false> takes an additional reference on the state.
inline void _Stop_callback_base::_Attach(const stop_token& _Token) noexcept {
    this->_Do_attach<false>(_Token._State);
}

// Attaches to the stop state observed by _Token; if *this is inserted into the callback list,
// _Do_attach<true> nulls _Token._State so that *this takes over the token's reference instead of adding one.
inline void _Stop_callback_base::_Attach(stop_token&& _Token) noexcept {
    this->_Do_attach<true>(_Token._State);
}

// Unregisters *this from its stop state's callback list, if it was ever inserted.
// If the node has already been removed by _Stop_state::_Request_stop and its callback is currently being run by
// a different thread, blocks until _Current_callback no longer refers to *this.
// The reference on the state held through _Parent is released by _Token's destructor on every return path.
inline void _Stop_callback_base::_Detach() noexcept {
    stop_token _Token{_Parent}; // transfers ownership
    if (_Token._State == nullptr) {
        // callback was never inserted into the list
        return;
    }

    auto _Head = _Token._State->_Callbacks._Lock_and_load();
    if (this == _Head) {
        // we are still in the list, so the callback is not being request_stop'd
        // (*this is the head node: the following node, if any, becomes the new head)
        const auto _Local_next = _Next;
        if (_Local_next != nullptr) {
            _Local_next->_Prev = nullptr;
        }

        _STL_INTERNAL_CHECK(_Prev == nullptr);
        _Token._State->_Callbacks._Store_and_unlock(_Next);
        return;
    }

    const auto _Local_prev = _Prev;
    if (_Local_prev != nullptr) {
        // we are still in the list, so the callback is not being request_stop'd, and there is at least one other
        // callback still registered
        // (*this is an interior or tail node: link its neighbors to each other; the head is unchanged)
        const auto _Local_next = _Next;
        if (_Local_next != nullptr) {
            _Next->_Prev = _Local_prev;
        }

        _Prev->_Next = _Local_next;
        _Token._State->_Callbacks._Store_and_unlock(_Head);
        return;
    }

    // we aren't in the callback list even though we were added to it, so the stop requesting thread is attempting to
    // call the callback
    _STL_INTERNAL_CHECK((_Token._State->_Stop_sources.load() & uint32_t{1}) != 0);
    if (_Token._State->_Current_callback.load(memory_order_acquire) != this
        || _Token._State->_Stopping_thread == _Thrd_id()) {
        // the callback is done or the dtor is being recursively reentered, do not block
        _Token._State->_Callbacks._Store_and_unlock(_Head);
        return;
    }

    // the callback is being executed by another thread, block until it is complete
    // (the lock is released first; _Request_stop must take it before it can change _Current_callback)
    _Token._State->_Callbacks._Store_and_unlock(_Head);
    _Token._State->_Current_callback.wait(this, memory_order_acquire);
}

// Registers a callable of type _Callback against the stop state observed by a stop_token for the lifetime of
// the stop_callback object. The callable is run when stop is requested, or immediately in the constructor if
// stop had already been requested; the destructor unregisters it.
_EXPORT_STD template <class _Callback>
class stop_callback : public _Stop_callback_base {
public:
    using callback_type = _Callback;

    // constructs the stored callable from _Cb_, then attaches to _Token's stop state (the token is left untouched)
    template <class _CbInitTy>
        requires is_constructible_v<_Callback, _CbInitTy>
    explicit stop_callback(const stop_token& _Token, _CbInitTy&& _Cb_)
        noexcept(is_nothrow_constructible_v<_Callback, _CbInitTy>)
        : _Stop_callback_base{_Invoke_by_stop}, _Cb(_STD forward<_CbInitTy>(_Cb_)) {
        this->_Attach(_Token);
    }

    // as above, but may take over the reference held by the expiring token instead of acquiring a new one
    template <class _CbInitTy>
        requires is_constructible_v<_Callback, _CbInitTy>
    explicit stop_callback(stop_token&& _Token, _CbInitTy&& _Cb_)
        noexcept(is_nothrow_constructible_v<_Callback, _CbInitTy>)
        : _Stop_callback_base{_Invoke_by_stop}, _Cb(_STD forward<_CbInitTy>(_Cb_)) {
        this->_Attach(_STD move(_Token));
    }

    // unregisters the callback (see _Stop_callback_base::_Detach)
    ~stop_callback() {
        this->_Detach();
    }

private:
    // invoker stored in the base class: recovers the derived object and calls its callable;
    // noexcept, so an exception escaping the callable
    static void __cdecl _Invoke_by_stop(_Stop_callback_base* const _This) noexcept // terminates
    {
        auto& _Self = *static_cast<stop_callback*>(_This);
        _STD forward<_Callback>(_Self._Cb)();
    }

    _Callback _Cb;
};

// deduction guide: stop_callback{token, callable} deduces _Callback as the (decayed) type of the callable argument
template <class _Callback>
stop_callback(stop_token, _Callback) -> stop_callback<_Callback>;

_STD_END
#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)
#endif // ^^^ _HAS_CXX20 ^^^
#endif // _STL_COMPILER_PREPROCESSOR
#endif // _STOP_TOKEN_
