/* -*- c++ -*- */
/*
 * Copyright 2015 Free Software Foundation, Inc.
 * Copyright 2023-2026 Magnus Lundmark <magnuslundmark@gmail.com>
 *
 * This file is part of VOLK
 *
 * SPDX-License-Identifier: LGPL-3.0-or-later
 */

/*
 * This file is intended to hold helper routines composed of SSE intrinsics.
 * They should be used in VOLK kernels to avoid copy-pasta.
 */

#ifndef INCLUDE_VOLK_VOLK_SSE_INTRINSICS_H_
#define INCLUDE_VOLK_VOLK_SSE_INTRINSICS_H_
#include <emmintrin.h>
#include <xmmintrin.h>

/*
 * Newton-Raphson refined reciprocal square root: 1/sqrt(a)
 * One iteration doubles precision from ~12-bit to ~24-bit
 * x1 = x0 * (1.5 - 0.5 * a * x0^2)
 *
 * The refinement is skipped (the raw _mm_rsqrt_ps estimate x0 is returned)
 * for lanes where it cannot produce a meaningful value:
 *   - a == +/-0 or a == +Inf: the step evaluates Inf * 0 and yields NaN
 *   - a denormal: per the Intel SDM, RSQRTPS treats a denormal source as a
 *     zero of the same sign (x0 = +/-Inf), and refining that estimate turns
 *     the result for a positive denormal into -Inf
 * Resulting edge cases: +0 → +Inf, -0 → -Inf, +Inf → 0
 * All other inputs (normal values, negative values, NaN) are refined.
 */
static inline __m128 _mm_rsqrt_nr_ps(const __m128 a)
{
    const __m128 HALF = _mm_set1_ps(0.5f);
    const __m128 THREE_HALFS = _mm_set1_ps(1.5f);
    // Bit pattern of +Inf; also the mask of the IEEE-754 exponent field
    const __m128i EXP_BITS = _mm_set1_epi32(0x7F800000);

    const __m128 x0 = _mm_rsqrt_ps(a); // hardware estimate, ~12-bit

    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
    const __m128 x1 = _mm_mul_ps(
        x0, _mm_sub_ps(THREE_HALFS, _mm_mul_ps(HALF, _mm_mul_ps(_mm_mul_ps(x0, x0), a))));

    const __m128i a_si = _mm_castps_si128(a);
    // Zero exponent field <=> +/-0 or denormal (sign and mantissa are ignored)
    const __m128i tiny_mask =
        _mm_cmpeq_epi32(_mm_and_si128(a_si, EXP_BITS), _mm_setzero_si128());
    // Exact +Inf; -Inf is left on the refined path
    const __m128i inf_mask = _mm_cmpeq_epi32(a_si, EXP_BITS);
    const __m128 special_mask = _mm_castsi128_ps(_mm_or_si128(tiny_mask, inf_mask));

    // SSE2-compatible blend: (x0 & mask) | (x1 & ~mask)
    return _mm_or_ps(_mm_and_ps(special_mask, x0), _mm_andnot_ps(special_mask, x1));
}

/*
 * Polynomial approximation of arctan(x), intended for x in [-1, 1]
 *
 * Odd polynomial of degree 13 written as x * P(x^2); P is evaluated with
 * Horner's method, starting from the highest-order coefficient.
 * Maximum relative error ~6.5e-7
 */
static inline __m128 _mm_arctan_poly_sse(const __m128 x)
{
    // Coefficient kN multiplies x^N
    const __m128 k1 = _mm_set1_ps(+0x1.ffffeap-1f);
    const __m128 k3 = _mm_set1_ps(-0x1.55437p-2f);
    const __m128 k5 = _mm_set1_ps(+0x1.972be6p-3f);
    const __m128 k7 = _mm_set1_ps(-0x1.1436ap-3f);
    const __m128 k9 = _mm_set1_ps(+0x1.5785aap-4f);
    const __m128 k11 = _mm_set1_ps(-0x1.2f3004p-5f);
    const __m128 k13 = _mm_set1_ps(+0x1.01a37cp-7f);

    const __m128 xx = _mm_mul_ps(x, x);

    // One Horner step per line: acc = xx * acc + next coefficient
    __m128 acc = _mm_add_ps(_mm_mul_ps(xx, k13), k11);
    acc = _mm_add_ps(_mm_mul_ps(xx, acc), k9);
    acc = _mm_add_ps(_mm_mul_ps(xx, acc), k7);
    acc = _mm_add_ps(_mm_mul_ps(xx, acc), k5);
    acc = _mm_add_ps(_mm_mul_ps(xx, acc), k3);
    acc = _mm_add_ps(_mm_mul_ps(xx, acc), k1);

    // Final multiply by x restores the odd symmetry
    return _mm_mul_ps(x, acc);
}

/*
 * Polynomial approximation of arcsin(x) for |x| <= 0.5
 *
 * asin(x) ~= x * P(x^2), where P is a cubic in x^2 evaluated with
 * Horner's method.
 * Maximum relative error ~1.5e-6
 */
static inline __m128 _mm_arcsin_poly_sse(const __m128 x)
{
    // Coefficients of P, lowest order first
    const __m128 k0 = _mm_set1_ps(0x1.ffffcep-1f);
    const __m128 k1 = _mm_set1_ps(0x1.55b648p-3f);
    const __m128 k2 = _mm_set1_ps(0x1.24d192p-4f);
    const __m128 k3 = _mm_set1_ps(0x1.0a788p-4f);

    const __m128 xx = _mm_mul_ps(x, x);

    // P(xx) = k0 + xx * (k1 + xx * (k2 + xx * k3)), innermost term first
    const __m128 inner = _mm_add_ps(_mm_mul_ps(xx, k3), k2);
    const __m128 middle = _mm_add_ps(_mm_mul_ps(xx, inner), k1);
    const __m128 outer = _mm_add_ps(_mm_mul_ps(xx, middle), k0);

    return _mm_mul_ps(x, outer);
}

/*
 * Squared magnitude of four interleaved complex values.
 *
 * Each input register holds two complex samples laid out as (I, Q, I, Q).
 * Returns I^2 + Q^2 per sample: lanes 0-1 come from cplxValue1 and
 * lanes 2-3 from cplxValue2.
 */
static inline __m128 _mm_magnitudesquared_ps(__m128 cplxValue1, __m128 cplxValue2)
{
    // Square every component while the data is still interleaved
    const __m128 squared1 = _mm_mul_ps(cplxValue1, cplxValue1);
    const __m128 squared2 = _mm_mul_ps(cplxValue2, cplxValue2);

    // Even lanes carry I^2, odd lanes carry Q^2: gather each set into one register
    const __m128 i_squared = _mm_shuffle_ps(squared1, squared2, _MM_SHUFFLE(2, 0, 2, 0));
    const __m128 q_squared = _mm_shuffle_ps(squared1, squared2, _MM_SHUFFLE(3, 1, 3, 1));

    return _mm_add_ps(i_squared, q_squared);
}

static inline __m128 _mm_magnitude_ps(__m128 cplxValue1, __m128 cplxValue2)
{
    return _mm_sqrt_ps(_mm_magnitudesquared_ps(cplxValue1, cplxValue2));
}

/*
 * Scaled squared distance: scalar * |symbols - points|^2
 *
 * The two symbol registers and the two point registers are subtracted
 * lane-wise, the differences are reduced with _mm_magnitudesquared_ps,
 * and the result is multiplied lane-wise by scalar.
 */
static inline __m128 _mm_scaled_norm_dist_ps_sse(const __m128 symbols0,
                                                 const __m128 symbols1,
                                                 const __m128 points0,
                                                 const __m128 points1,
                                                 const __m128 scalar)
{
    const __m128 sq_dist = _mm_magnitudesquared_ps(_mm_sub_ps(symbols0, points0),
                                                   _mm_sub_ps(symbols1, points1));
    return _mm_mul_ps(sq_dist, scalar);
}

/*
 * Adds a weighted squared term to an accumulator, lane-wise:
 *   result = sq_acc + rec * (aux * val - acc)^2
 */
static inline __m128 _mm_accumulate_square_sum_ps(
    __m128 sq_acc, __m128 acc, __m128 val, __m128 rec, __m128 aux)
{
    const __m128 deviation = _mm_sub_ps(_mm_mul_ps(aux, val), acc);
    // Square first, then apply the weight, to keep the rounding order fixed
    const __m128 weighted = _mm_mul_ps(_mm_mul_ps(deviation, deviation), rec);
    return _mm_add_ps(sq_acc, weighted);
}

/*
 * Minimax polynomial approximation of sin(x) on [-pi/4, pi/4]
 * Coefficients via Remez algorithm (Sollya)
 * Max |error| < 7.3e-9
 *
 * Evaluated as sin(x) = x + x^3 * (k3 + x^2 * (k5 + x^2 * k7))
 */
static inline __m128 _mm_sin_poly_sse(const __m128 x)
{
    // Coefficient kN multiplies x^N
    const __m128 k3 = _mm_set1_ps(-0x1.555552p-3f);
    const __m128 k5 = _mm_set1_ps(+0x1.110be2p-7f);
    const __m128 k7 = _mm_set1_ps(-0x1.9ab22ap-13f);

    const __m128 xx = _mm_mul_ps(x, x);
    const __m128 xxx = _mm_mul_ps(xx, x);

    // Horner evaluation of the inner polynomial in xx
    __m128 acc = _mm_mul_ps(xx, k7);
    acc = _mm_add_ps(acc, k5);
    acc = _mm_mul_ps(xx, acc);
    acc = _mm_add_ps(acc, k3);

    // Scale by x^3 and add the linear term
    acc = _mm_mul_ps(xxx, acc);
    return _mm_add_ps(acc, x);
}

/*
 * Minimax polynomial approximation of cos(x) on [-pi/4, pi/4]
 * Coefficients via Remez algorithm (Sollya)
 * Max |error| < 1.1e-7
 *
 * Evaluated as cos(x) = 1 + x^2 * (k2 + x^2 * (k4 + x^2 * k6))
 */
static inline __m128 _mm_cos_poly_sse(const __m128 x)
{
    // Coefficient kN multiplies x^N
    const __m128 k0 = _mm_set1_ps(1.0f);
    const __m128 k2 = _mm_set1_ps(-0x1.fffff4p-2f);
    const __m128 k4 = _mm_set1_ps(+0x1.554a46p-5f);
    const __m128 k6 = _mm_set1_ps(-0x1.661be2p-10f);

    const __m128 xx = _mm_mul_ps(x, x);

    // Horner evaluation in xx, highest order first
    __m128 acc = _mm_mul_ps(xx, k6);
    acc = _mm_add_ps(acc, k4);
    acc = _mm_mul_ps(xx, acc);
    acc = _mm_add_ps(acc, k2);
    acc = _mm_mul_ps(xx, acc);
    return _mm_add_ps(acc, k0);
}

/*
 * Degree-6 polynomial approximating log2(x)/(x-1) on [1, 2]
 * Generated with Sollya: remez(log2(x)/(x-1), 6, [1+1b-20, 2])
 * Max error: ~1.55e-6
 *
 * This returns only the polynomial value; the caller forms
 * log2(x) ≈ poly(x) * (x - 1) for x ∈ [1, 2].
 */
static inline __m128 _mm_log2_poly_sse(const __m128 x)
{
    // Coefficient kN multiplies x^N
    const __m128 k0 = _mm_set1_ps(+0x1.a8a726p+1f);
    const __m128 k1 = _mm_set1_ps(-0x1.0b7f7ep+2f);
    const __m128 k2 = _mm_set1_ps(+0x1.05d9ccp+2f);
    const __m128 k3 = _mm_set1_ps(-0x1.4d476cp+1f);
    const __m128 k4 = _mm_set1_ps(+0x1.04fc3ap+0f);
    const __m128 k5 = _mm_set1_ps(-0x1.c97982p-3f);
    const __m128 k6 = _mm_set1_ps(+0x1.57aa42p-6f);

    // Horner's method: k0 + x*(k1 + x*(k2 + ...)), highest order first
    __m128 acc = _mm_mul_ps(k6, x);
    acc = _mm_add_ps(acc, k5);
    acc = _mm_mul_ps(acc, x);
    acc = _mm_add_ps(acc, k4);
    acc = _mm_mul_ps(acc, x);
    acc = _mm_add_ps(acc, k3);
    acc = _mm_mul_ps(acc, x);
    acc = _mm_add_ps(acc, k2);
    acc = _mm_mul_ps(acc, x);
    acc = _mm_add_ps(acc, k1);
    acc = _mm_mul_ps(acc, x);
    return _mm_add_ps(acc, k0);
}

#endif /* INCLUDE_VOLK_VOLK_SSE_INTRINSICS_H_ */
