//
// Modifications, Copyright (C) 2023 Intel Corporation
//
// This software and the related documents are Intel copyrighted materials, and
// your use of them is governed by the express license under which they were
// provided to you ("License"). Unless the License provides otherwise, you may not
// use, modify, copy, publish, distribute, disclose or transmit this software or
// the related documents without Intel's prior written permission.
//
// This software and the related documents are provided as is, with no express
// or implied warranties, other than those that are expressly stated in the
// License.
//
//==------------------------ cmath -----------------------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

// Windows device compilation for the NVPTX and AMDGCN targets only: the
// device-side _dsign/_fdsign functions are forward declared ahead of the real
// <cmath> include so that the math functions defined in the UCRT headers
// compile.
#if defined(__SYCL_DEVICE_ONLY__) && defined(_WIN32) &&                        \
    (defined(__NVPTX__) || defined(__AMDGCN__))
extern "C" __attribute__((sycl_device_only, always_inline)) int
_dsign(double x);
extern "C" __attribute__((sycl_device_only, always_inline)) int
_fdsign(float x);
#endif // __SYCL_DEVICE_ONLY__ && _WIN32 && (__NVPTX__ || __AMDGCN__)

// Include real STL <cmath> header - the next one from the include search
// directories.
#if defined(__has_include_next)
// GCC/clang support go through this path.
#include_next <cmath>
#else
// MSVC doesn't support "#include_next", so we have to be creative.
// Our header is located in "stl_wrappers/cmath" so it won't be picked by the
// following include. MSVC's installation, on the other hand, has the layout
// where the following would result in the <cmath> we want. This is obviously
// hacky, but the best we can do...
#include <../include/cmath>
#endif

// *** <sycl/builtins.hpp> ***

#if defined(__NVPTX__) || defined(__AMDGCN__)
#include "__sycl_cmath_wrapper_impl.hpp"
#else

#include <sycl/detail/defines_elementary.hpp>

#ifdef __SYCL_DEVICE_ONLY__
extern "C" {
extern __DPCPP_SYCL_EXTERNAL_LIBC float scalbnf(float x, int n);
extern __DPCPP_SYCL_EXTERNAL_LIBC double scalbn(double x, int n);
extern __DPCPP_SYCL_EXTERNAL_LIBC float logf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double log(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float expf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double exp(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float log10f(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double log10(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float modff(float x, float *intpart);
extern __DPCPP_SYCL_EXTERNAL_LIBC double modf(double x, double *intpart);
extern __DPCPP_SYCL_EXTERNAL_LIBC float exp2f(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double exp2(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float expm1f(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double expm1(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC int ilogbf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC int ilogb(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float log1pf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double log1p(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float log2f(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double log2(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float logbf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double logb(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float sqrtf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double sqrt(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float cbrtf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double cbrt(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float erff(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double erf(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float erfcf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double erfc(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float tgammaf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double tgamma(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float lgammaf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double lgamma(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float fmodf(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double fmod(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float remainderf(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double remainder(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float remquof(float x, float y, int *q);
extern __DPCPP_SYCL_EXTERNAL_LIBC double remquo(double x, double y, int *q);
extern __DPCPP_SYCL_EXTERNAL_LIBC float nextafterf(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double nextafter(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float fdimf(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double fdim(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float fmaf(float x, float y, float z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double fma(double x, double y, double z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float sinf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double sin(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float cosf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double cos(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float tanf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double tan(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float asinf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double asin(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float acosf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double acos(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float atanf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double atan(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float powf(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double pow(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float atan2f(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double atan2(double x, double y);

extern __DPCPP_SYCL_EXTERNAL_LIBC float sinhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double sinh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float coshf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double cosh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float tanhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double tanh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float asinhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double asinh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float acoshf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double acosh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float atanhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double atanh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double frexp(double x, int *exp);
extern __DPCPP_SYCL_EXTERNAL_LIBC double ldexp(double x, int exp);
extern __DPCPP_SYCL_EXTERNAL_LIBC double hypot(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float rintf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double rint(double x);

extern __DPCPP_SYCL_EXTERNAL_LIBC float sinhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double sinh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float coshf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double cosh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float tanhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double tanh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float asinhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double asinh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float acoshf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double acosh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC float atanhf(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double atanh(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC double frexp(double x, int *exp);
extern __DPCPP_SYCL_EXTERNAL_LIBC double ldexp(double x, int exp);
extern __DPCPP_SYCL_EXTERNAL_LIBC double hypot(double x, double y);
}
#ifdef __GLIBC__
// glibc-only declarations: the float variants of frexp/ldexp/hypot, the
// <complex.h> functions, and the complex multiply/divide helper routines
// (__mulsc3, __muldc3, __divsc3, __divdc3).
extern "C" {
extern __DPCPP_SYCL_EXTERNAL_LIBC float frexpf(float x, int *exp);
extern __DPCPP_SYCL_EXTERNAL_LIBC float ldexpf(float x, int exp);
extern __DPCPP_SYCL_EXTERNAL_LIBC float hypotf(float x, float y);

// MS UCRT supports most of the C standard library but <complex.h> is
// an exception.
extern __DPCPP_SYCL_EXTERNAL_LIBC float cimagf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double cimag(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float crealf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double creal(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float cargf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double carg(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float cabsf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double cabs(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ cprojf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ cproj(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ cexpf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ cexp(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ clogf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ clog(double __complex__ z);
// cpow/cpowf take two complex operands (base x, exponent y), matching the
// C standard library prototypes.
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ cpowf(float __complex__ x,
                                                          float __complex__ y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ cpow(double __complex__ x,
                                                          double __complex__ y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ csqrtf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ csqrt(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ csinhf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ csinh(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ ccoshf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ ccosh(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ ctanhf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ ctanh(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ csinf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ csin(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ ccosf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ ccos(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ ctanf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ ctan(double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ cacosf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ cacos(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ cacoshf(
    float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ cacosh(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ casinf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ casin(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ casinhf(
    float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ casinh(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ catanf(float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ catan(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ catanhf(
    float __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ catanh(
    double __complex__ z);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ cpolarf(float rho,
                                                            float theta);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ cpolar(double rho,
                                                            double theta);
// Complex multiply/divide helper routines. The *sc3 variants operate on
// float operands and the *dc3 variants on double operands.
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ __mulsc3(float a, float b,
                                                             float c, float d);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ __muldc3(double a,
                                                              double b,
                                                              double c,
                                                              double d);
extern __DPCPP_SYCL_EXTERNAL_LIBC float __complex__ __divsc3(float a, float b,
                                                             float c, float d);
extern __DPCPP_SYCL_EXTERNAL_LIBC double __complex__ __divdc3(double a,
                                                              double b,
                                                              double c,
                                                              double d);
}
#elif defined(_WIN32)
// Windows-only declarations of C runtime entry points that the MSVC math
// headers and STL call into.
extern "C" {
// TODO: documented C runtime library APIs must be recognized as
//       builtins by FE. This includes _dpcomp, _dsign, _dtest,
//       _fdpcomp, _fdsign, _fdtest, _hypotf, _wassert.
//       APIs used by STL, such as _Cosh, are undocumented, even though
//       they are open-sourced. Recognizing them as builtins is not
//       straightforward currently.

// Entry points that the TODO above lists as documented APIs.
extern __DPCPP_SYCL_EXTERNAL_LIBC int _dpcomp(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC int _dsign(double x);
extern __DPCPP_SYCL_EXTERNAL_LIBC short _dtest(double *px);
extern __DPCPP_SYCL_EXTERNAL_LIBC int _fdpcomp(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC int _fdsign(float x);
extern __DPCPP_SYCL_EXTERNAL_LIBC short _fdtest(float *px);
extern __DPCPP_SYCL_EXTERNAL_LIBC float _hypotf(float x, float y);

// Remaining entry points (not in the TODO's documented list), double
// precision first, then their float counterparts.
extern __DPCPP_SYCL_EXTERNAL_LIBC double _Cosh(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC double _Sinh(double x, double y);
extern __DPCPP_SYCL_EXTERNAL_LIBC short _Dtest(double *px);
extern __DPCPP_SYCL_EXTERNAL_LIBC short _Exp(double *px, double y, short eoff);
extern __DPCPP_SYCL_EXTERNAL_LIBC float _FCosh(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC float _FSinh(float x, float y);
extern __DPCPP_SYCL_EXTERNAL_LIBC short _FDtest(float *px);
extern __DPCPP_SYCL_EXTERNAL_LIBC short _FExp(float *px, float y, short eoff);

// The MSVC math header implements 'hypotf' inline on top of the library
// function '_hypotf'. With icx the Intel math header replaces the MSVC one and
// carries no such inline implementation; it treats 'hypotf' as a library
// function instead. Declare 'hypotf' here for that case; its implementation is
// provided by SYCL libdevice.
#if defined(__INTEL_LLVM_COMPILER)
extern __DPCPP_SYCL_EXTERNAL_LIBC float hypotf(float x, float y);
#endif
}
#endif
#endif // __SYCL_DEVICE_ONLY__
#endif
