// coroutine experimental header

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Library support for the Coroutines TS, https://wg21.link/p0057

#ifndef _EXPERIMENTAL_COROUTINE_
#define _EXPERIMENTAL_COROUTINE_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR

#include <memory>
#include <new>
#if _HAS_EXCEPTIONS
#include <exception>
#endif
#include <cstdint>
#include <type_traits>

// Save the packing, warning, and `new`-macro state; each push below has a matching pop/restore
// at the end of this header.
#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
// This header uses placement ::new (see _ConstructPromise), so any macro named `new` is removed
// until the matching pop_macro.
#undef new

// Reject Clang builds unless the user opts out by defining _SILENCE_CLANG_COROUTINE_MESSAGE.
#if defined(__clang__) && !defined(_SILENCE_CLANG_COROUTINE_MESSAGE)
#error The <experimental/coroutine>, <experimental/generator>, and <experimental/resumable> \
headers do not support Clang, but the C++20 <coroutine> header does.
#endif // defined(__clang__) && !defined(_SILENCE_CLANG_COROUTINE_MESSAGE)

// Reject builds in which __cpp_impl_coroutine is defined: per the message below, this header
// only implements the pre-C++20 (/await) coroutine model.
#ifdef __cpp_impl_coroutine
#error The <experimental/coroutine> and <experimental/resumable> headers are only supported with \
/await and implement pre-C++20 coroutine support. Use <coroutine> for standard C++20 coroutines.
#endif // defined(__cpp_impl_coroutine)

// intrinsics used in implementation of coroutine_handle
// _coro_resume and _coro_destroy are called from coroutine_handle<void>::resume() and
// coroutine_handle<void>::destroy(). _coro_done is declared but not called in this header
// (see the REVISIT note in coroutine_handle<void>::done()).
// Each declaration must precede its #pragma intrinsic line.
extern "C" size_t _coro_resume(void*) noexcept;
extern "C" void _coro_destroy(void*) noexcept;
extern "C" size_t _coro_done(void*) noexcept;
#pragma intrinsic(_coro_resume)
#pragma intrinsic(_coro_destroy)
#pragma intrinsic(_coro_done)

// Unless the user defines _ALLOW_COROUTINE_ABI_MISMATCH, record the key "_COROUTINE_ABI" with
// value "1" via the detect_mismatch pragma (presumably so that objects built against a different
// coroutine ABI value fail to link together -- confirm against the MSVC pragma documentation).
#ifndef _ALLOW_COROUTINE_ABI_MISMATCH
#pragma detect_mismatch("_COROUTINE_ABI", "1")
#endif // !defined(_ALLOW_COROUTINE_ABI_MISMATCH)

_STD_BEGIN

namespace experimental {

    // Primary template: chosen when _Ty has no nested promise_type. It is deliberately empty, so
    // no promise_type member is exposed.
    template <class _Ty, class _Enable = void>
    struct _Coroutine_traits_sfinae {};

    // Partial specialization: chosen when _Ty::promise_type names a type, which it re-exports.
    template <class _Ty>
    struct _Coroutine_traits_sfinae<_Ty, void_t<typename _Ty::promise_type>> {
        using promise_type = typename _Ty::promise_type;
    };

    // Traits for a coroutine returning _Ret with parameter types _ArgTs.
    // The default implementation ignores the parameter types and inherits promise_type (when
    // _Ret::promise_type exists) from _Coroutine_traits_sfinae<_Ret>.
    template <class _Ret, class... _ArgTs>
    struct coroutine_traits : _Coroutine_traits_sfinae<_Ret> {
        // intentionally empty: everything comes from the base
    };

    // Forward declaration. The promise type defaults to void, so coroutine_handle<> names the
    // coroutine_handle<void> specialization (no promise access); the general form derives from it.
    template <class _PromiseT = void>
    struct coroutine_handle;

    // Type-erased handle to a coroutine frame. It stores only a pointer to the frame prefix and
    // does not own the frame: copying the handle copies the pointer, and no destructor is declared.
    template <>
    struct coroutine_handle<void> { // no promise access
        // Constructs a null handle; _Ptr keeps its default member initializer (nullptr).
        coroutine_handle() noexcept {}

        // Implicit conversion from nullptr; also yields a null handle.
        coroutine_handle(nullptr_t) noexcept {}

        // Resets this handle to null. Nothing is done to the frame previously referred to.
        coroutine_handle& operator=(nullptr_t) noexcept {
            _Ptr = nullptr;
            return *this;
        }

        // Rebuilds a handle from a raw frame address (the value address() returns).
        // The address is stored as-is; no validation is performed.
        static coroutine_handle from_address(void* _Addr) noexcept {
            coroutine_handle _Result;
            _Result._Ptr = static_cast<_Resumable_frame_prefix*>(_Addr);
            return _Result;
        }

        // Returns the stored frame address (nullptr for a null handle).
        void* address() const noexcept {
            return _Ptr;
        }

        // Deprecated spelling of address(); returns the same pointer.
        [[deprecated("coroutine_handle::to_address() is deprecated. "
                     "Use coroutine_handle::address() instead.")]] void*
            to_address() const noexcept {
            return _Ptr;
        }

        // Call syntax for resume().
        void operator()() const {
            resume();
        }

        // True when the handle is non-null.
        explicit operator bool() const noexcept {
            return _Ptr != nullptr;
        }

        // Passes the frame pointer to the _coro_resume intrinsic. No null check is made here.
        void resume() const {
            _coro_resume(_Ptr);
        }

        // Passes the frame pointer to the _coro_destroy intrinsic. No null check is made here.
        // Const-qualified like resume(): the call only reads _Ptr and never modifies the handle
        // itself, so it is usable through a const handle as well.
        void destroy() const {
            _coro_destroy(_Ptr);
        }

        // Reports whether the frame's _Index field is zero. Dereferences _Ptr without a null
        // check, so the handle must be non-null.
        bool done() const {
            // REVISIT: should return _coro_done() == 0; when intrinsic is
            // hooked up
            return _Ptr->_Index == 0;
        }

        // Leading fields of a coroutine frame as this header accesses them. The layout is also
        // written field-by-offset in _Resumable_helper_traits::_ConstructPromise, so it must not
        // be reordered.
        struct _Resumable_frame_prefix {
            using _Resume_fn = void(__cdecl*)(void*);

            _Resume_fn _Fn;
            uint16_t _Index; // done() treats 0 as "done"
            uint16_t _Flags;
        };

    protected:
        _Resumable_frame_prefix* _Ptr = nullptr; // frame address; nullptr for a null handle
    };

    // Handle that additionally knows the promise type, which allows converting between a promise
    // object and its frame address. The promise is located _ALIGNED_SIZE bytes below the frame
    // address held by the handle.
    template <class _PromiseT>
    struct coroutine_handle : coroutine_handle<> { // general form
        using coroutine_handle<>::coroutine_handle;

        // Granularity used to round the promise size: two pointers.
        static const size_t _ALIGN_REQ = sizeof(void*) * 2;

        // Distance in bytes from the promise to the frame address: zero for an empty promise type,
        // otherwise sizeof(_PromiseT) rounded up to a multiple of _ALIGN_REQ.
        static const size_t _ALIGNED_SIZE =
            is_empty_v<_PromiseT> ? 0 : ((sizeof(_PromiseT) + _ALIGN_REQ - 1) & ~(_ALIGN_REQ - 1));

        // Resets this handle to null.
        coroutine_handle& operator=(nullptr_t) noexcept {
            coroutine_handle<>::operator=(nullptr);
            return *this;
        }

        // Rebuilds a typed handle from a raw frame address; the address is stored unchecked.
        static coroutine_handle from_address(void* _Addr) noexcept {
            coroutine_handle _Handle;
            _Handle._Ptr = static_cast<_Resumable_frame_prefix*>(_Addr);
            return _Handle;
        }

        // Computes the handle for the frame whose promise object is _Prom: the frame address is
        // the promise address advanced by _ALIGNED_SIZE bytes.
        static coroutine_handle from_promise(_PromiseT& _Prom) noexcept {
            char* const _Promise_bytes = reinterpret_cast<char*>(_STD addressof(_Prom));
            return from_address(_Promise_bytes + _ALIGNED_SIZE);
        }

        // Deprecated pointer overload; forwards to the reference overload, so _Prom must be
        // dereferenceable.
        [[deprecated("coroutine_handle::from_promise(T*) with pointer parameter is deprecated. "
                     "Use coroutine_handle::from_promise(T&) instead.")]] static coroutine_handle
            from_promise(_PromiseT* _Prom) noexcept {
            return from_promise(*_Prom);
        }

        // Returns the promise object, found _ALIGNED_SIZE bytes below the stored frame address.
        // No null check is performed.
        _PromiseT& promise() const noexcept {
            const char* const _Frame_bytes = reinterpret_cast<const char*>(_Ptr);
            const _PromiseT* const _Prom   = reinterpret_cast<const _PromiseT*>(_Frame_bytes - _ALIGNED_SIZE);
            return *const_cast<_PromiseT*>(_Prom);
        }
    };

    // Handles are equal when they hold the same frame address.
    _NODISCARD inline bool operator==(coroutine_handle<> _Left, coroutine_handle<> _Right) noexcept {
        void* const _Left_addr  = _Left.address();
        void* const _Right_addr = _Right.address();
        return _Left_addr == _Right_addr;
    }

    // Orders handles by frame address, using less<void*> rather than the built-in pointer <.
    _NODISCARD inline bool operator<(coroutine_handle<> _Left, coroutine_handle<> _Right) noexcept {
        const less<void*> _Addr_less{};
        return _Addr_less(_Left.address(), _Right.address());
    }

    // Handles differ when their frame addresses differ (the negation of operator==).
    _NODISCARD inline bool operator!=(coroutine_handle<> _Left, coroutine_handle<> _Right) noexcept {
        return _Left.address() != _Right.address();
    }

    // _Left > _Right exactly when _Right's frame address orders before _Left's under less<void*>.
    _NODISCARD inline bool operator>(coroutine_handle<> _Left, coroutine_handle<> _Right) noexcept {
        return less<void*>{}(_Right.address(), _Left.address());
    }

    // _Left <= _Right exactly when _Right does not order before _Left.
    _NODISCARD inline bool operator<=(coroutine_handle<> _Left, coroutine_handle<> _Right) noexcept {
        const bool _Right_is_less = _Right < _Left;
        return !_Right_is_less;
    }

    // _Left >= _Right exactly when _Left does not order before _Right.
    _NODISCARD inline bool operator>=(coroutine_handle<> _Left, coroutine_handle<> _Right) noexcept {
        const bool _Left_is_less = _Left < _Right;
        return !_Left_is_less;
    }

    // Awaitable whose readiness is fixed at construction: await_ready() reports false (not ready)
    // exactly when the constructor argument was true.
    struct suspend_if {
        bool _Ready; // value returned by await_ready(); the negation of the constructor argument

        explicit suspend_if(bool _Condition) noexcept : _Ready{!_Condition} {}

        // Reports the readiness captured by the constructor.
        bool await_ready() noexcept {
            return _Ready;
        }

        // Nothing to arrange on suspension; the handle is ignored.
        void await_suspend(coroutine_handle<> /* unused */) noexcept {}

        // Produces no value on resumption.
        void await_resume() noexcept {}
    };

    // Awaitable that is never ready: await_ready() always reports false.
    struct suspend_always {
        bool await_ready() noexcept {
            const bool _Is_ready = false;
            return _Is_ready;
        }

        // Nothing to arrange on suspension; the handle is ignored.
        void await_suspend(coroutine_handle<> /* unused */) noexcept {}

        // Produces no value on resumption.
        void await_resume() noexcept {}
    };

    // Awaitable that is always ready: await_ready() always reports true.
    struct suspend_never {
        bool await_ready() noexcept {
            const bool _Is_ready = true;
            return _Is_ready;
        }

        // Nothing to arrange on suspension; the handle is ignored.
        void await_suspend(coroutine_handle<> /* unused */) noexcept {}

        // Produces no value on resumption.
        void await_resume() noexcept {}
    };

    // _Resumable_helper_traits isolates the compiler front-end from naming changes
    // in the public surface

    template <class _Ret, class... _Ts>
    struct _Resumable_helper_traits {
        using _Traits      = coroutine_traits<_Ret, _Ts...>;
        using _PromiseT    = typename _Traits::promise_type;
        using _Handle_type = coroutine_handle<_PromiseT>;

        static _PromiseT* _Promise_from_frame(void* _Addr) noexcept {
            return reinterpret_cast<_PromiseT*>(reinterpret_cast<char*>(_Addr) - _Handle_type::_ALIGNED_SIZE);
        }

        static _Handle_type _Handle_from_frame(void* _Addr) noexcept {
            return _Handle_type::from_promise(*_Promise_from_frame(_Addr));
        }

        static void _Set_exception(void* _Addr) {
            _Promise_from_frame(_Addr)->set_exception(_STD current_exception());
        }

        static void _ConstructPromise(void* _Addr, void* _Resume_addr, int _HeapElision) {
            *reinterpret_cast<void**>(_Addr) = _Resume_addr;
            *reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(_Addr) + sizeof(void*)) =
                2u + (_HeapElision ? 0u : 0x10000u);
            auto _Prom = _Promise_from_frame(_Addr);
            ::new (static_cast<void*>(_Prom)) _PromiseT();
        }

        static void _DestructPromise(void* _Addr) {
            _Promise_from_frame(_Addr)->~_PromiseT();
        }
    };
} // namespace experimental

_STD_END

// resumable functions support intrinsics
// None of these are called from this header; presumably they are referenced by
// compiler-generated coroutine code or by <experimental/resumable> -- TODO confirm.
// Each declaration must precede its #pragma intrinsic line.

extern "C" size_t _coro_frame_size() noexcept;
extern "C" void* _coro_frame_ptr() noexcept;
extern "C" void _coro_init_block() noexcept;
extern "C" void* _coro_resume_addr() noexcept;
extern "C" void _coro_init_frame(void*) noexcept;
extern "C" void _coro_save(size_t) noexcept;
extern "C" void _coro_suspend(size_t) noexcept;
extern "C" void _coro_cancel() noexcept;
extern "C" void _coro_resume_block() noexcept;

#pragma intrinsic(_coro_frame_size)
#pragma intrinsic(_coro_frame_ptr)
#pragma intrinsic(_coro_init_block)
#pragma intrinsic(_coro_resume_addr)
#pragma intrinsic(_coro_init_frame)
#pragma intrinsic(_coro_save)
#pragma intrinsic(_coro_suspend)
#pragma intrinsic(_coro_cancel)
#pragma intrinsic(_coro_resume_block)

// Restore the `new` macro, warning state, and packing saved at the top of this header, in the
// reverse of the order in which they were pushed.
#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _EXPERIMENTAL_COROUTINE_
