// thread standard header

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef _THREAD_
#define _THREAD_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR

#ifdef _M_CEE_PURE
#error <thread> is not supported when compiling with /clr:pure.
#endif // defined(_M_CEE_PURE)

#include <__msvc_chrono.hpp>
#include <memory>
#include <process.h>
#include <tuple>
#include <xthreads.h>

#if _HAS_CXX20
#include <compare>
#include <stop_token>
#endif // _HAS_CXX20

#if _HAS_CXX23
#include <__msvc_formatter.hpp>
#endif // _HAS_CXX23

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

_STD_BEGIN
#if _HAS_CXX20
_EXPORT_STD class jthread;
#endif // _HAS_CXX20

// Move-only owner of at most one thread of execution, tracked in _Thr (handle _Hnd plus id _Id).
// A non-zero _Thr._Id means the object is joinable. Destroying a joinable thread, or move-assigning
// over one, calls terminate(); the owner must join() or detach() first.
_EXPORT_STD class thread { // class for observing and managing threads
public:
    class id;

    using native_handle_type = void*;

    // Constructs an object that owns no thread (_Thr is empty-brace-initialized, so joinable() is false).
    thread() noexcept : _Thr{} {}

private:
#if _HAS_CXX20
    friend jthread; // jthread starts its thread through the private _Start()
#endif // _HAS_CXX20

    // Thread procedure passed to _beginthreadex. _RawVals is the heap-allocated _Tuple created by _Start();
    // this function adopts it into a unique_ptr, invokes the stored callable with the stored arguments
    // (all moved out of the tuple), and lets the unique_ptr delete the tuple on exit.
    // The function is noexcept, so an exception escaping the callable terminates the program.
    template <class _Tuple, size_t... _Indices>
    static unsigned int __stdcall _Invoke(void* _RawVals) noexcept /* terminates */ {
        // adapt invoke of user's callable object to _beginthreadex's thread procedure
        const unique_ptr<_Tuple> _FnVals(static_cast<_Tuple*>(_RawVals));
        _Tuple& _Tup = *_FnVals.get(); // avoid ADL, handle incomplete types
        _STD invoke(_STD move(_STD get<_Indices>(_Tup))...);
        _Cnd_do_broadcast_at_thread_exit(); // TRANSITION, ABI
        return 0;
    }

    // Turns an index_sequence into the address of the matching _Invoke specialization, so that _Start()
    // can obtain the thread procedure as a constant expression.
    template <class _Tuple, size_t... _Indices>
    _NODISCARD static constexpr auto _Get_invoke(index_sequence<_Indices...>) noexcept {
        return &_Invoke<_Tuple, _Indices...>;
    }

#pragma warning(push) // pointer or reference to potentially throwing function passed to 'extern "C"' function under
#pragma warning(disable : 5039) // -EHc. Undefined behavior may occur if this function throws an exception. (/Wall)
    // Launches a new thread running _Fx(_Ax...).
    // The callable and arguments are decay-copied into a heap-allocated tuple on the calling thread.
    // On success the new thread owns that tuple (see _Invoke) and _Thr holds the handle and id.
    // On failure the tuple is freed by _Decay_copied, _Thr._Id is reset to 0, and
    // _Throw_Cpp_error(_RESOURCE_UNAVAILABLE_TRY_AGAIN) is called.
    template <class _Fn, class... _Args>
    void _Start(_Fn&& _Fx, _Args&&... _Ax) {
        using _Tuple                 = tuple<decay_t<_Fn>, decay_t<_Args>...>;
        auto _Decay_copied           = _STD make_unique<_Tuple>(_STD forward<_Fn>(_Fx), _STD forward<_Args>(_Ax)...);
        constexpr auto _Invoker_proc = _Get_invoke<_Tuple>(make_index_sequence<1 + sizeof...(_Args)>{});

        // The raw pointer is passed while _Decay_copied still owns the tuple; ownership is only released below,
        // once the thread is known to have been created.
        _Thr._Hnd =
            reinterpret_cast<void*>(_CSTD _beginthreadex(nullptr, 0, _Invoker_proc, _Decay_copied.get(), 0, &_Thr._Id));

        if (_Thr._Hnd) { // ownership transferred to the thread
            (void) _Decay_copied.release();
        } else { // failed to start thread
            _Thr._Id = 0;
            _Throw_Cpp_error(_RESOURCE_UNAVAILABLE_TRY_AGAIN);
        }
    }
#pragma warning(pop)

public:
    // Starts a thread running _Fx(_Ax...). The constraint keeps this template from being selected
    // when the sole argument is itself a thread (copy/move construction).
    template <class _Fn, class... _Args, enable_if_t<!is_same_v<_Remove_cvref_t<_Fn>, thread>, int> = 0>
    _NODISCARD_CTOR_THREAD explicit thread(_Fn&& _Fx, _Args&&... _Ax) {
        _Start(_STD forward<_Fn>(_Fx), _STD forward<_Args>(_Ax)...);
    }

    // Terminates the program if the thread is still joinable.
    ~thread() noexcept {
        if (joinable()) {
            _STD terminate(); // per N4950 [thread.thread.destr]/1
        }
    }

    // Takes over _Other's thread and leaves _Other owning nothing.
    thread(thread&& _Other) noexcept : _Thr(_STD exchange(_Other._Thr, {})) {}

    // Terminates if *this is joinable; otherwise takes over _Other's thread and leaves _Other owning nothing.
    thread& operator=(thread&& _Other) noexcept {
        if (joinable()) {
            _STD terminate(); // per N4950 [thread.thread.assign]/1
        }

        _Thr = _STD exchange(_Other._Thr, {});
        return *this;
    }

    thread(const thread&)            = delete;
    thread& operator=(const thread&) = delete;

    // Exchanges the owned threads of *this and _Other.
    void swap(thread& _Other) noexcept {
        _STD swap(_Thr, _Other._Thr);
    }

    // True when this object owns a thread, i.e. the stored id is non-zero.
    _NODISCARD bool joinable() const noexcept {
        return _Thr._Id != 0;
    }

    // Waits for the owned thread to finish, then resets *this to own nothing.
    // Calls _Throw_Cpp_error with:
    //   _INVALID_ARGUMENT              if *this is not joinable,
    //   _RESOURCE_DEADLOCK_WOULD_OCCUR if the owned thread is the calling thread,
    //   _NO_SUCH_PROCESS               if _Thrd_join does not report success.
    void join() {
        if (!joinable()) {
            _Throw_Cpp_error(_INVALID_ARGUMENT);
        }

        if (_Thr._Id == _Thrd_id()) {
            _Throw_Cpp_error(_RESOURCE_DEADLOCK_WOULD_OCCUR);
        }

        if (_Thrd_join(_Thr, nullptr) != _Thrd_result::_Success) {
            _Throw_Cpp_error(_NO_SUCH_PROCESS);
        }

        _Thr = {};
    }

    // Gives up ownership of the thread without waiting for it, then resets *this to own nothing.
    // Calls _Throw_Cpp_error(_INVALID_ARGUMENT) if *this is not joinable or _Thrd_detach does not report success.
    void detach() {
        if (!joinable()) {
            _Throw_Cpp_error(_INVALID_ARGUMENT);
        }

        if (_Thrd_detach(_Thr) != _Thrd_result::_Success) {
            _Throw_Cpp_error(_INVALID_ARGUMENT);
        }

        _Thr = {};
    }

    // Returns the id of the owned thread (defined after thread::id is complete).
    _NODISCARD id get_id() const noexcept;

    _NODISCARD native_handle_type native_handle() noexcept /* strengthened */ { // return Win32 HANDLE as void *
        return _Thr._Hnd;
    }

    // Forwards to _Thrd_hardware_concurrency().
    _NODISCARD static unsigned int hardware_concurrency() noexcept {
        return _Thrd_hardware_concurrency();
    }

private:
    _Thrd_t _Thr; // handle (_Hnd) and id (_Id) of the owned thread; _Id == 0 means no thread
};

// Converts a relative timeout into a steady_clock deadline, saturating instead of overflowing.
// A timeout that is not greater than zero yields the current time. A timeout too large to be added to
// the current time yields the largest representable time_point.
// The result uses the common type of steady_clock::time_point and _Rel_time.
template <class _Rep, class _Period>
_NODISCARD auto _To_absolute_time(const chrono::duration<_Rep, _Period>& _Rel_time) noexcept {
    using _Rel_type             = chrono::duration<_Rep, _Period>;
    constexpr auto _No_delay    = _Rel_type::zero();
    const auto _Start_time      = chrono::steady_clock::now();
    using _Deadline_type        = decltype(_Start_time + _Rel_time); // return common type
    _Deadline_type _Deadline    = _Start_time;

    if (!(_Rel_time > _No_delay)) { // nothing to add
        return _Deadline;
    }

    constexpr auto _Latest = (_Deadline_type::max)();
    if (_Deadline < _Latest - _Rel_time) { // the addition cannot overflow
        _Deadline += _Rel_time;
    } else { // saturate
        _Deadline = _Latest;
    }

    return _Deadline;
}

// Result of _Clamped_rel_time_ms_count().
struct _Clamped_rel_time_ms_count_result {
    unsigned long _Count; // milliseconds to wait; never more than the clamp (24 hours)
    bool _Clamped; // true when the requested time exceeded the clamp and _Count was capped
};

// Converts a relative time into a millisecond count suitable for a 32-bit timeout argument.
//   _Rel <= 0       -> {0, false}
//   _Rel >  clamp   -> {clamp in ms, true}; the caller is expected to wait again for the remainder
//   otherwise       -> {_Rel rounded up to whole milliseconds, false}
template <class _Duration>
_NODISCARD _Clamped_rel_time_ms_count_result _Clamped_rel_time_ms_count(const _Duration& _Rel) {
    // _Clamp must be less than 2^32 - 1 (INFINITE) milliseconds, but is otherwise arbitrary.
    constexpr chrono::milliseconds _Clamp{chrono::hours{24}};

    if (_Rel <= _Duration::zero()) {
        // Nothing left to wait for. Without this check a negative count would be converted to
        // unsigned long below and wrap to a huge value, far beyond _Clamp.
        return {0UL, false};
    }

    if (_Rel > _Clamp) {
        return {static_cast<unsigned long>(_Clamp.count()), true};
    }

    // Round up so that a sub-millisecond remainder still waits instead of spinning.
    const auto _Rel_ms = chrono::ceil<chrono::milliseconds>(_Rel);
    return {static_cast<unsigned long>(_Rel_ms.count()), false};
}

namespace this_thread {
    // Returns the id of the calling thread (defined after thread::id is complete).
    _EXPORT_STD _NODISCARD thread::id get_id() noexcept;

    // Forwards to _Thrd_yield().
    _EXPORT_STD inline void yield() noexcept {
        _Thrd_yield();
    }

    // Blocks the calling thread until _Clock::now() is no earlier than _Abs_time.
    // The wait is performed in slices: each iteration re-reads the clock, computes the remaining time,
    // and sleeps for that time capped by _Clamped_rel_time_ms_count. Re-reading the clock also means that
    // adjustments to _Clock are observed between slices.
    _EXPORT_STD template <class _Clock, class _Duration>
    void sleep_until(const chrono::time_point<_Clock, _Duration>& _Abs_time) {
        static_assert(chrono::_Is_clock_v<_Clock>, "Clock type required");
        for (;;) {
            const auto _Now = _Clock::now();
            if (_Abs_time <= _Now) {
                return;
            }

            // qualified to avoid ADL through user-provided _Clock / _Duration types
            const unsigned long _Rel_ms_count = _STD _Clamped_rel_time_ms_count(_Abs_time - _Now)._Count;
            _Thrd_sleep_for(_Rel_ms_count);
        }
    }

    // Blocks the calling thread for at least _Rel_time, measured against steady_clock.
    // A duration that is not positive returns without sleeping.
    _EXPORT_STD template <class _Rep, class _Period>
    void sleep_for(const chrono::duration<_Rep, _Period>& _Rel_time) {
        // Both calls are qualified: an unqualified call would be subject to ADL in the namespaces
        // associated with a user-defined _Rep, which library code must not rely on.
        _STD this_thread::sleep_until(_STD _To_absolute_time(_Rel_time));
    }
} // namespace this_thread

// Value type identifying a thread by its raw _Thrd_id_t. A default-constructed id stores 0,
// the same value thread uses for "owns no thread". Only the friends listed below can create
// an id from a raw value or read the stored value directly.
class thread::id { // thread id
public:
    id() noexcept = default; // id for no thread

#if _HAS_CXX23
    // Read-only access to the stored raw id; used by formatter<thread::id>, which is not a friend.
    _NODISCARD _Thrd_id_t _Get_underlying_id() const noexcept {
        return _Id;
    }
#endif // _HAS_CXX23

private:
    // Wraps a raw id; private so that only the friends below can construct non-default ids.
    explicit id(_Thrd_id_t _Other_id) noexcept : _Id(_Other_id) {}

    _Thrd_id_t _Id = 0; // raw thread id; 0 means "no thread"

    // Factories that need the private constructor.
    friend thread::id thread::get_id() const noexcept;
    friend thread::id this_thread::get_id() noexcept;
    // Comparisons that read _Id directly; the remaining relational operators are built on these.
    friend bool operator==(thread::id _Left, thread::id _Right) noexcept;
#if _HAS_CXX20
    friend strong_ordering operator<=>(thread::id _Left, thread::id _Right) noexcept;
#else // ^^^ _HAS_CXX20 / !_HAS_CXX20 vvv
    friend bool operator<(thread::id _Left, thread::id _Right) noexcept;
#endif // ^^^ !_HAS_CXX20 ^^^
    // Stream insertion and hashing also read _Id directly.
    template <class _Ch, class _Tr>
    friend basic_ostream<_Ch, _Tr>& operator<<(basic_ostream<_Ch, _Tr>& _Str, thread::id _Id);
    friend hash<thread::id>;
};

// Returns the id of the owned thread; a thread that owns nothing yields the default ("no thread") id,
// because its stored raw id is 0.
_NODISCARD inline thread::id thread::get_id() const noexcept {
    const auto _Owned_id = _Thr._Id;
    return thread::id{_Owned_id};
}

// Returns the id of the calling thread, as reported by _Thrd_id().
_EXPORT_STD _NODISCARD inline thread::id this_thread::get_id() noexcept {
    const auto _Caller_id = _Thrd_id();
    return thread::id{_Caller_id};
}

// Non-member swap: exchanges the threads owned by _Left and _Right via thread::swap.
_EXPORT_STD inline void swap(thread& _Left, thread& _Right) noexcept {
    _Right.swap(_Left); // member swap is symmetric
}

// Two ids are equal when their stored raw ids are equal.
_EXPORT_STD _NODISCARD inline bool operator==(thread::id _Left, thread::id _Right) noexcept {
    const auto _Lhs_raw = _Left._Id;
    const auto _Rhs_raw = _Right._Id;
    return _Lhs_raw == _Rhs_raw;
}

#if _HAS_CXX20
// Orders ids by their stored raw ids.
_EXPORT_STD _NODISCARD inline strong_ordering operator<=>(thread::id _Left, thread::id _Right) noexcept {
    const auto _Lhs_raw = _Left._Id;
    const auto _Rhs_raw = _Right._Id;
    return _Lhs_raw <=> _Rhs_raw;
}
#else // ^^^ _HAS_CXX20 / !_HAS_CXX20 vvv
// Negation of operator== (pre-C++20 only).
_NODISCARD inline bool operator!=(thread::id _Left, thread::id _Right) noexcept {
    const bool _Same = _Left == _Right;
    return !_Same;
}

// Orders ids by their stored raw ids (pre-C++20 only); the other relational operators build on this one.
_NODISCARD inline bool operator<(thread::id _Left, thread::id _Right) noexcept {
    const auto _Lhs_raw = _Left._Id;
    const auto _Rhs_raw = _Right._Id;
    return _Lhs_raw < _Rhs_raw;
}

// _Left <= _Right exactly when _Right is not less than _Left (pre-C++20 only).
_NODISCARD inline bool operator<=(thread::id _Left, thread::id _Right) noexcept {
    const bool _Right_is_smaller = _Right < _Left;
    return !_Right_is_smaller;
}

// _Left > _Right exactly when _Right < _Left (pre-C++20 only).
_NODISCARD inline bool operator>(thread::id _Left, thread::id _Right) noexcept {
    const bool _Right_is_smaller = _Right < _Left;
    return _Right_is_smaller;
}

// _Left >= _Right exactly when _Left is not less than _Right (pre-C++20 only).
_NODISCARD inline bool operator>=(thread::id _Left, thread::id _Right) noexcept {
    const bool _Left_is_smaller = _Left < _Right;
    return !_Left_is_smaller;
}
#endif // ^^^ !_HAS_CXX20 ^^^

// Writes the raw thread id to _Str as an unsigned decimal number.
// The digits are produced right-to-left into a local buffer that ends with a null terminator, and the
// resulting null-terminated string is then inserted into the stream.
_EXPORT_STD template <class _Ch, class _Tr>
basic_ostream<_Ch, _Tr>& operator<<(basic_ostream<_Ch, _Tr>& _Str, thread::id _Id) {
    _STL_INTERNAL_STATIC_ASSERT(sizeof(_Thrd_id_t) == 4);
    constexpr size_t _Max_digits = 10; // digits needed for 2^32 - 1
    _Ch _Text[_Max_digits + 1]; // plus terminating null
    _Ch* const _Terminator = _Text + _Max_digits;
    *_Terminator           = static_cast<_Ch>('\0');
    _Ch* const _First_digit = _STD _UIntegral_to_buff(_Terminator, _Id._Id);
    return _Str << _First_digit;
}

#if _HAS_CXX23
// Per LWG-3997, `_CharT` in library-provided `formatter` specializations is
// constrained to character types supported by `format`.
//
// Formats a thread::id as its raw id in unsigned decimal. Parsing and padding are delegated to
// _Fill_align_and_width_formatter; the default alignment requested here is right alignment.
template <_Format_supported_charT _CharT>
struct formatter<thread::id, _CharT> {
    // Delegates parsing of the format specification to _Impl.
    template <class _Pc = basic_format_parse_context<_CharT>> // improves throughput, see GH-5003
    constexpr _Pc::iterator parse(type_identity_t<_Pc&> _Parse_ctx) {
        return _Impl._Parse(_Parse_ctx);
    }

    // Renders the digits right-to-left into a local buffer, then hands _Impl the digit count and a
    // callback that copies the digits to the output iterator.
    template <class _FormatContext>
    _FormatContext::iterator format(thread::id _Val, _FormatContext& _Format_ctx) const {
        _STL_INTERNAL_STATIC_ASSERT(sizeof(_Thrd_id_t) == 4);
        constexpr size_t _Max_digits = 10; // digits needed for 2^32 - 1
        _CharT _Digits[_Max_digits];
        _CharT* const _Digits_end         = _Digits + _Max_digits;
        const _CharT* const _Digits_begin = _STD _UIntegral_to_buff(_Digits_end, _Val._Get_underlying_id());
        return _Impl._Format(_Format_ctx, static_cast<int>(_Digits_end - _Digits_begin), _Fmt_align::_Right,
            [&](_FormatContext::iterator _Out) {
                return _RANGES copy(_Digits_begin, _Digits_end, _STD move(_Out)).out;
            });
    }

private:
    _Fill_align_and_width_formatter<_CharT> _Impl;
};

// Sets enable_nonlocking_formatter_optimization to true for thread::id.
// NOTE(review): presumably safe because formatter<thread::id>::format only writes digits from a local
// buffer and runs no user code - confirm against the trait's requirements.
template <>
inline constexpr bool enable_nonlocking_formatter_optimization<thread::id> = true;
#endif // _HAS_CXX23

// Hash support for thread::id: hashes the object representation of the stored raw id.
template <>
struct hash<thread::id> {
    using _ARGUMENT_TYPE_NAME _CXX17_DEPRECATE_ADAPTOR_TYPEDEFS = thread::id;
    using _RESULT_TYPE_NAME _CXX17_DEPRECATE_ADAPTOR_TYPEDEFS   = size_t;

    _NODISCARD _STATIC_CALL_OPERATOR size_t operator()(const thread::id _Keyval) _CONST_CALL_OPERATOR noexcept {
        const _Thrd_id_t _Raw_id = _Keyval._Id; // private member, reachable because hash is a friend
        return _Hash_representation(_Raw_id);
    }
};

#if _HAS_CXX20
// A thread wrapper that pairs a thread with a stop_source. When a joinable jthread is destroyed or
// move-assigned over, it requests stop on its stop_source and then joins, instead of terminating.
_EXPORT_STD class jthread {
public:
    using id                 = thread::id;
    using native_handle_type = thread::native_handle_type;

    // Owns no thread and has no stop state.
    jthread() noexcept : _Impl{}, _Ssource{nostopstate} {}

    // Starts a thread running _Fx. If the decayed callable can be invoked with a stop_token followed by
    // the decayed arguments, this object's stop token is passed as the first argument; otherwise only
    // the given arguments are passed.
    template <class _Fn, class... _Args>
        requires (!is_same_v<remove_cvref_t<_Fn>, jthread>)
    _NODISCARD_CTOR_JTHREAD explicit jthread(_Fn&& _Fx, _Args&&... _Ax) {
        constexpr bool _Accepts_token = is_invocable_v<decay_t<_Fn>, stop_token, decay_t<_Args>...>;
        if constexpr (!_Accepts_token) {
            _Impl._Start(_STD forward<_Fn>(_Fx), _STD forward<_Args>(_Ax)...);
        } else {
            _Impl._Start(_STD forward<_Fn>(_Fx), _Ssource.get_token(), _STD forward<_Args>(_Ax)...);
        }
    }

    // Requests stop and joins if still joinable.
    ~jthread() {
        _Try_cancel_and_join();
    }

    jthread(const jthread&)            = delete;
    jthread(jthread&&) noexcept        = default;
    jthread& operator=(const jthread&) = delete;

    // Self-assignment has no effect. Otherwise the current thread (if any) is stopped and joined first,
    // then the thread and stop source of _Other are moved into *this.
    jthread& operator=(jthread&& _Other) noexcept {
        if (this != _STD addressof(_Other)) {
            _Try_cancel_and_join();
            _Impl    = _STD move(_Other._Impl);
            _Ssource = _STD move(_Other._Ssource);
        }

        return *this;
    }

    // Exchanges both the owned thread and the stop source with _Other.
    void swap(jthread& _Other) noexcept {
        _Impl.swap(_Other._Impl);
        _Ssource.swap(_Other._Ssource);
    }

    // The operations below forward to the wrapped thread.
    _NODISCARD bool joinable() const noexcept {
        return _Impl.joinable();
    }

    void join() {
        _Impl.join();
    }

    void detach() {
        _Impl.detach();
    }

    _NODISCARD id get_id() const noexcept {
        return _Impl.get_id();
    }

    _NODISCARD native_handle_type native_handle() noexcept /* strengthened */ {
        return _Impl.native_handle();
    }

    // The operations below forward to the stop source.
    _NODISCARD stop_source get_stop_source() noexcept {
        return _Ssource; // returns a copy
    }

    _NODISCARD stop_token get_stop_token() const noexcept {
        return _Ssource.get_token();
    }

    bool request_stop() noexcept {
        const bool _Result = _Ssource.request_stop();
        return _Result;
    }

    friend void swap(jthread& _Lhs, jthread& _Rhs) noexcept {
        _Lhs.swap(_Rhs);
    }

    _NODISCARD static unsigned int hardware_concurrency() noexcept {
        return thread::hardware_concurrency();
    }

private:
    // Shared by the destructor and move assignment: does nothing for a non-joinable thread,
    // otherwise requests stop and then waits for the thread to finish.
    void _Try_cancel_and_join() noexcept {
        if (!_Impl.joinable()) {
            return;
        }

        _Ssource.request_stop();
        _Impl.join();
    }

    thread _Impl;
    stop_source _Ssource;
};
#endif // _HAS_CXX20
_STD_END

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)
#endif // _STL_COMPILER_PREPROCESSOR
#endif // _THREAD_
