







#[cfg(feature="std")]
use core::cell::UnsafeCell;
use core::cmp::min;
use core::mem::size_of;
use core::slice;

use crate::aligned_alloc::Alloc;

use crate::ptr::Ptr;
use crate::util::range_chunk;
use crate::util::round_up_to;

use crate::kernel::Element;
use crate::kernel::GemmKernel;
use crate::kernel::GemmSelect;
#[cfg(feature = "cgemm")]
use crate::kernel::{c32, c64};
use crate::threading::{get_thread_pool, ThreadPoolCtx, LoopThreadConfig};
use crate::sgemm_kernel;
use crate::dgemm_kernel;
#[cfg(feature = "cgemm")]
use crate::cgemm_kernel;
#[cfg(feature = "cgemm")]
use crate::zgemm_kernel;
use rawpointer::PointerExt;

















/// General matrix multiplication for `f32`: C ← α A B + β C.
///
/// + `m`, `k`, `n`: dimensions; A is `m` × `k`, B is `k` × `n`, C is `m` × `n`
/// + `a`, `b`, `c`: pointer to the first element of each matrix
/// + `rsx`, `csx`: row stride and column stride of matrix `x`
///
/// The arguments are bundled into a `GemmParameters` value which is handed to
/// `sgemm_kernel::detect`.
///
/// # Safety
///
/// Every element reachable from `a`, `b` and `c` with the given dimensions
/// and strides must be valid to access, and `c` must be valid for writes.
/// The strides of C must not make distinct elements alias; `gemm_loop`
/// debug-asserts that they are non-zero whenever C has more than one row or
/// column.
pub unsafe fn sgemm(
    m: usize, k: usize, n: usize,
    alpha: f32,
    a: *const f32, rsa: isize, csa: isize,
    b: *const f32, rsb: isize, csb: isize,
    beta: f32,
    c: *mut f32, rsc: isize, csc: isize)
{
    let params = GemmParameters {
        m, k, n,
        alpha, beta,
        a, rsa, csa,
        b, rsb, csb,
        c, rsc, csc,
    };
    sgemm_kernel::detect(params)
}

















/// General matrix multiplication for `f64`: C ← α A B + β C.
///
/// + `m`, `k`, `n`: dimensions; A is `m` × `k`, B is `k` × `n`, C is `m` × `n`
/// + `a`, `b`, `c`: pointer to the first element of each matrix
/// + `rsx`, `csx`: row stride and column stride of matrix `x`
///
/// The arguments are bundled into a `GemmParameters` value which is handed to
/// `dgemm_kernel::detect`.
///
/// # Safety
///
/// Every element reachable from `a`, `b` and `c` with the given dimensions
/// and strides must be valid to access, and `c` must be valid for writes.
/// The strides of C must not make distinct elements alias; `gemm_loop`
/// debug-asserts that they are non-zero whenever C has more than one row or
/// column.
pub unsafe fn dgemm(
    m: usize, k: usize, n: usize,
    alpha: f64,
    a: *const f64, rsa: isize, csa: isize,
    b: *const f64, rsb: isize, csb: isize,
    beta: f64,
    c: *mut f64, rsc: isize, csc: isize)
{
    let params = GemmParameters {
        m, k, n,
        alpha, beta,
        a, rsa, csa,
        b, rsb, csb,
        c, rsc, csc,
    };
    dgemm_kernel::detect(params)
}




/// Per-operand option accepted by the `cgemm` and `zgemm` entry points.
///
/// The enum is `#[non_exhaustive]`, so more variants may be added later.
#[cfg(feature = "cgemm")]
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
pub enum CGemmOption {
    /// The only option currently defined. `cgemm` and `zgemm`, as written in
    /// this file, accept it without inspecting it.
    Standard,
}

/// General matrix multiplication over `c32` elements: C ← α A B + β C.
///
/// + `flaga`, `flagb`: options for A and B; currently accepted but ignored
/// + `m`, `k`, `n`: dimensions; A is `m` × `k`, B is `k` × `n`, C is `m` × `n`
/// + `a`, `b`, `c`: pointer to the first element of each matrix
/// + `rsx`, `csx`: row stride and column stride of matrix `x`
///
/// The arguments are bundled into a `GemmParameters` value which is handed to
/// `cgemm_kernel::detect`.
///
/// # Safety
///
/// Every element reachable from `a`, `b` and `c` with the given dimensions
/// and strides must be valid to access, and `c` must be valid for writes.
/// The strides of C must not make distinct elements alias; `gemm_loop`
/// debug-asserts that they are non-zero whenever C has more than one row or
/// column.
#[cfg(feature = "cgemm")]
pub unsafe fn cgemm(
    flaga: CGemmOption, flagb: CGemmOption,
    m: usize, k: usize, n: usize,
    alpha: c32,
    a: *const c32, rsa: isize, csa: isize,
    b: *const c32, rsb: isize, csb: isize,
    beta: c32,
    c: *mut c32, rsc: isize, csc: isize)
{
    // The option flags are part of the signature but are not consulted here.
    let _ = flaga;
    let _ = flagb;

    let params = GemmParameters {
        m, k, n,
        alpha, beta,
        a, rsa, csa,
        b, rsb, csb,
        c, rsc, csc,
    };
    cgemm_kernel::detect(params)
}

/// General matrix multiplication over `c64` elements: C ← α A B + β C.
///
/// + `flaga`, `flagb`: options for A and B; currently accepted but ignored
/// + `m`, `k`, `n`: dimensions; A is `m` × `k`, B is `k` × `n`, C is `m` × `n`
/// + `a`, `b`, `c`: pointer to the first element of each matrix
/// + `rsx`, `csx`: row stride and column stride of matrix `x`
///
/// The arguments are bundled into a `GemmParameters` value which is handed to
/// `zgemm_kernel::detect`.
///
/// # Safety
///
/// Every element reachable from `a`, `b` and `c` with the given dimensions
/// and strides must be valid to access, and `c` must be valid for writes.
/// The strides of C must not make distinct elements alias; `gemm_loop`
/// debug-asserts that they are non-zero whenever C has more than one row or
/// column.
#[cfg(feature = "cgemm")]
pub unsafe fn zgemm(
    flaga: CGemmOption, flagb: CGemmOption,
    m: usize, k: usize, n: usize,
    alpha: c64,
    a: *const c64, rsa: isize, csa: isize,
    b: *const c64, rsb: isize, csb: isize,
    beta: c64,
    c: *mut c64, rsc: isize, csc: isize)
{
    // The option flags are part of the signature but are not consulted here.
    let _ = flaga;
    let _ = flagb;

    let params = GemmParameters {
        m, k, n,
        alpha, beta,
        a, rsa, csa,
        b, rsb, csb,
        c, rsc, csc,
    };
    zgemm_kernel::detect(params)
}

/// All arguments of one matrix multiplication call, bundled so they can be
/// passed as a single value to kernel selection (`GemmSelect::select`).
struct GemmParameters<T> {
    /// Number of rows of A and C.
    m: usize,
    /// Number of columns of A and rows of B.
    k: usize,
    /// Number of columns of B and C.
    n: usize,
    /// Scalar applied to the product A B.
    alpha: T,
    /// First element of A.
    a: *const T,
    /// Row stride of A.
    rsa: isize,
    /// Column stride of A.
    csa: isize,
    /// Scalar applied to the existing contents of C.
    beta: T,
    /// First element of B.
    b: *const T,
    /// Row stride of B.
    rsb: isize,
    /// Column stride of B.
    csb: isize,
    /// First element of C (written to).
    c: *mut T,
    /// Row stride of C.
    rsc: isize,
    /// Column stride of C.
    csc: isize,
}

impl<T> GemmSelect<T> for GemmParameters<T> {
    /// Runs the multiplication described by `self` with kernel type `K`.
    ///
    /// The `_kernel` value itself is unused; only its type picks the
    /// `gemm_loop` instantiation.
    fn select<K>(self, _kernel: K)
       where K: GemmKernel<Elem=T>,
             T: Element,
    {
        let p = self;

        // SAFETY: the parameters come unchanged from one of the public
        // `unsafe fn` entry points in this file, whose callers are responsible
        // for pointer validity and for the strides.
        unsafe {
            gemm_loop::<K>(
                p.m, p.k, p.n,
                p.alpha,
                p.a, p.rsa, p.csa,
                p.b, p.rsb, p.csb,
                p.beta,
                p.c, p.rsc, p.csc)
        }
    }
}






/// Asserts the limits on kernel parameters that the rest of this file
/// depends on.
///
/// # Panics
///
/// Panics if kernel `K` violates any of the limits checked below.
#[inline(always)]
fn ensure_kernel_params<K>()
    where K: GemmKernel
{
    let mr = K::MR;
    let nr = K::NR;
    // The kernel tile may be at most 8 x 8 elements and 8 * 4 * 8 = 256 bytes,
    // the same byte count as KERNEL_MAX_SIZE that sizes the mask buffer.
    assert!(mr > 0 && mr <= 8);
    assert!(nr > 0 && nr <= 8);
    assert!(mr * nr * size_of::<K::Elem>() <= 8 * 4 * 8);
    // Same bound as KERNEL_MAX_ALIGN.
    assert!(K::align_to() <= 32);
    // The requested alignment may not exceed the byte length of the shorter
    // side of the kernel tile.
    // NOTE(review): presumably because packed panels start at multiples of
    // MR / NR elements inside the packing buffer -- confirm.
    let max_align = size_of::<K::Elem>() * min(mr, nr);
    assert!(K::align_to() <= max_align);

    // Block sizes must be ordered MR <= mc <= kc <= nc, with nc <= 65536.
    assert!(K::MR <= K::mc());
    assert!(K::mc() <= K::kc());
    assert!(K::kc() <= K::nc());
    assert!(K::nc() <= 65536);
}





/// Outer blocking loops of the matrix multiplication for kernel `K`.
///
/// Splits `n`, `k` and `m` into chunks of at most `K::nc()`, `K::kc()` and
/// `K::mc()`, packs the current B and A chunks into a temporary buffer and
/// calls `gemm_packed` for every (n, k, m) chunk combination.
///
/// + `m`, `k`, `n`: dimensions; A is `m` × `k`, B is `k` × `n`, C is `m` × `n`
/// + `a`, `b`, `c`: pointer to the first element of each matrix
/// + `rsx`, `csx`: row stride and column stride of matrix `x`
///
/// If any dimension is zero, C is only scaled by `beta` (`c_to_beta_c`).
///
/// # Panics
///
/// Panics if `K` violates the limits checked by `ensure_kernel_params`.
///
/// # Safety
///
/// The pointers must be valid for every element addressed with the given
/// dimensions and strides, and `c` must be valid for writes.
#[inline(never)]
unsafe fn gemm_loop<K>(
    m: usize, k: usize, n: usize,
    alpha: K::Elem,
    a: *const K::Elem, rsa: isize, csa: isize,
    b: *const K::Elem, rsb: isize, csb: isize,
    beta: K::Elem,
    c: *mut K::Elem, rsc: isize, csc: isize)
    where K: GemmKernel
{
    // A zero stride in C would make several elements share one address as
    // soon as there is more than one row (or column) to write.
    debug_assert!(m <= 1 || n == 0 || rsc != 0);
    debug_assert!(m == 0 || n <= 1 || csc != 0);

    // If A or B has no elements there is no product to add: C <- beta C.
    if m == 0 || k == 0 || n == 0 {
        return c_to_beta_c(m, n, beta, c, rsc, csc);
    }

    // Block sizes for the n, k and m dimensions.
    let knc = K::nc();
    let kkc = K::kc();
    let kmc = K::mc();
    ensure_kernel_params::<K>();

    // NOTE(review): `Ptr` presumably exists so the pointers can be captured
    // by the (possibly parallel) loop closures below -- confirm in crate::ptr.
    let a = Ptr(a);
    let b = Ptr(b);
    let c = Ptr(c);

    let (nthreads, tp) = get_thread_pool();
    let thread_config = LoopThreadConfig::new::<K>(m, k, n, nthreads);
    // Number of separate packing areas for A.
    let nap = thread_config.num_pack_a();

    // One allocation: `nap` packing areas for A of `ap_size` elements each,
    // followed by the packing area for B. `packing_buffer` stays bound until
    // the end of the function, so `app` / `bpp` remain usable throughout.
    let (mut packing_buffer, ap_size, bp_size) = make_packing_buffer::<K>(m, k, n, nap);
    let app = Ptr(packing_buffer.ptr_mut());
    let bpp = app.add(ap_size * nap);

    // LOOP 5: split n into chunks of nc columns (advances B and C by columns).
    for (l5, nc) in range_chunk(n, knc) {
        dprint!("LOOP 5, {}, nc={}", l5, nc);
        let b = b.stride_offset(csb, knc * l5);
        let c = c.stride_offset(csc, knc * l5);

        // LOOP 4: split k into chunks of kc (advances B by rows, A by columns).
        // This loop runs sequentially: every iteration writes to the same
        // chunk of C.
        for (l4, kc) in range_chunk(k, kkc) {
            dprint!("LOOP 4, {}, kc={}", l4, kc);
            let b = b.stride_offset(rsb, kkc * l4);
            let a = a.stride_offset(csa, kkc * l4);

            // Pack the current chunk of B. Note the stride order: column
            // stride first, then row stride.
            K::pack_nr(kc, nc, slice::from_raw_parts_mut(bpp.ptr(), bp_size),
                       b.ptr(), csb, rsb);

            // The first k-chunk applies the caller's `beta` to C; later
            // chunks use one so their contribution is accumulated.
            let betap = if l4 == 0 { beta } else { <_>::one() };

            // LOOP 3: split m into chunks of mc rows (advances A and C by
            // rows), run according to `thread_config.loop3`.
            range_chunk(m, kmc)
                .parallel(thread_config.loop3, tp)
                .thread_local(move |i, _nt| {
                    // Each thread index `i` gets its own packing area for A.
                    debug_assert!(i < nap);
                    app.add(ap_size * i)
                })
                .for_each(move |tp, &mut app, l3, mc| {
                    dprint!("LOOP 3, {}, mc={}", l3, mc);
                    let a = a.stride_offset(rsa, kmc * l3);
                    let c = c.stride_offset(rsc, kmc * l3);

                    // Pack the current chunk of A into this thread's area.
                    K::pack_mr(kc, mc, slice::from_raw_parts_mut(app.ptr(), ap_size),
                               a.ptr(), rsa, csa);

                    // LOOP 2 and 1: multiply the packed chunks into C.
                    gemm_packed::<K>(nc, kc, mc,
                                     alpha,
                                     app.to_const(), bpp.to_const(),
                                     betap,
                                     c, rsc, csc,
                                     tp, thread_config);
                });
        }
    }
}


/// Largest kernel output tile, in bytes, that the mask buffer can hold
/// (`8 * 8 * 4` = 256); checked by an assertion in `gemm_packed`.
const KERNEL_MAX_SIZE: usize = 8 * 8 * 4;
/// Largest kernel alignment, in bytes, that `gemm_packed` accepts.
const KERNEL_MAX_ALIGN: usize = 32;
/// Byte size of the mask buffer: one maximal tile plus `KERNEL_MAX_ALIGN - 1`
/// spare bytes, so the start can be moved forward to an aligned address.
const MASK_BUF_SIZE: usize = KERNEL_MAX_SIZE + KERNEL_MAX_ALIGN - 1;





// Scratch storage for one kernel output tile, used by the masked kernel path.
//
// `repr(align(32))` is requested on every target except macos. Users of the
// buffer re-align the pointer with `align_ptr` in any case, which is why the
// buffer carries `KERNEL_MAX_ALIGN - 1` spare bytes.
// NOTE(review): the reason macos is excluded is not visible here -- confirm.
#[cfg_attr(not(target_os = "macos"), repr(align(32)))]
struct MaskBuffer {
    buffer: [u8; MASK_BUF_SIZE],
}



// With the "std" feature every thread owns one zero-initialised mask buffer.
// It sits in an `UnsafeCell` because `gemm_packed` takes a raw mutable
// pointer into it. Without "std", `gemm_packed` uses a local `MaskBuffer`
// instead.
#[cfg(feature = "std")]
thread_local! {
    static MASK_BUF: UnsafeCell<MaskBuffer> =
        UnsafeCell::new(MaskBuffer { buffer: [0; MASK_BUF_SIZE] });
}








/// Loops 2 and 1 of the multiplication, working on already packed chunks.
///
/// + `nc`, `kc`, `mc`: sizes of the current chunk in the n, k and m dimensions
/// + `app`: packed chunk of A, `bpp`: packed chunk of B
/// + `c`: first element of the chunk of C to update, with strides `rsc`, `csc`
/// + `tp`, `thread_config`: thread pool context and loop threading settings
///
/// Each tile of C of at most `K::MR` × `K::NR` elements is computed either by
/// `K::kernel` directly or, for partial tiles (or when `K::always_masked()`),
/// by `masked_kernel` via a scratch buffer.
///
/// # Panics
///
/// Panics if the kernel tile exceeds `KERNEL_MAX_SIZE` bytes or the kernel
/// alignment exceeds `KERNEL_MAX_ALIGN`.
///
/// # Safety
///
/// `app` and `bpp` must point to packed data covering the given sizes, and
/// `c` must be valid for reads and writes of the `mc` × `nc` chunk.
unsafe fn gemm_packed<K>(nc: usize, kc: usize, mc: usize,
                         alpha: K::Elem,
                         app: Ptr<*const K::Elem>, bpp: Ptr<*const K::Elem>,
                         beta: K::Elem,
                         c: Ptr<*mut K::Elem>, rsc: isize, csc: isize,
                         tp: ThreadPoolCtx, thread_config: LoopThreadConfig)
    where K: GemmKernel,
{
    let mr = K::MR;
    let nr = K::NR;
    // The mask buffer must be able to hold one full kernel tile at the
    // kernel's alignment.
    assert!(mr * nr * size_of::<K::Elem>() <= KERNEL_MAX_SIZE && K::align_to() <= KERNEL_MAX_ALIGN);

    // Without "std" there is no thread-local buffer; use one on the stack.
    #[cfg(not(feature = "std"))]
    let mut mask_buf = MaskBuffer { buffer: [0; MASK_BUF_SIZE] };

    // LOOP 2: step through the packed B chunk in panels of `nr` columns.
    range_chunk(nc, nr)
        .parallel(thread_config.loop2, tp)
        .thread_local(|_i, _nt| {
            // Produce this thread's scratch tile as a slice of elements.
            let mut ptr;
            #[cfg(not(feature = "std"))]
            {
                // A single local buffer is only enough for one thread.
                debug_assert_eq!(_nt, 1);
                ptr = mask_buf.buffer.as_mut_ptr();
            }
            #[cfg(feature = "std")]
            {
                ptr = MASK_BUF.with(|buf| (*buf.get()).buffer.as_mut_ptr());
            }
            // Move the start forward to the kernel's alignment; the buffer
            // has KERNEL_MAX_ALIGN - 1 spare bytes for this.
            ptr = align_ptr(K::align_to(), ptr);
            slice::from_raw_parts_mut(ptr as *mut K::Elem, KERNEL_MAX_SIZE / size_of::<K::Elem>())
        })
        .for_each(move |_tp, mask_buf, l2, nr_| {
            // Panel `l2` of packed B starts `kc * nr` elements after the
            // previous one; C advances by `nr` columns.
            let bpp = bpp.stride_offset(1, kc * nr * l2);
            let c = c.stride_offset(csc, nr * l2);

            // LOOP 1: step through the packed A chunk in panels of `mr` rows
            // while the B panel stays fixed.
            for (l1, mr_) in range_chunk(mc, mr) {
                let app = app.stride_offset(1, kc * mr * l1);
                let c = c.stride_offset(rsc, mr * l1);

                // Partial tiles (fewer than `mr` rows or `nr` columns) must go
                // through the masked path so that only the elements inside C
                // are written; kernels may also ask for it unconditionally.
                if K::always_masked() || nr_ < nr || mr_ < mr {
                    masked_kernel::<_, K>(kc, alpha, app.ptr(), bpp.ptr(),
                                          beta, c.ptr(), rsc, csc,
                                          mr_, nr_, mask_buf);
                    continue;
                } else {
                    K::kernel(kc, alpha, app.ptr(), bpp.ptr(), beta, c.ptr(), rsc, csc);
                }
            }
        });
}











/// Allocates the buffer that holds the packed chunks of A and B.
///
/// The dimensions are first clamped to the kernel's block sizes (`K::mc()`,
/// `K::kc()`, `K::nc()`). The allocation has room for `na` packed A chunks
/// followed by one packed B chunk.
///
/// Returns `(buffer, apack_size, bpack_size)` where the sizes are element
/// counts of a single packed A chunk and of the packed B chunk.
///
/// # Safety
///
/// NOTE(review): the function is `unsafe` but the requirements of
/// `Alloc::new` are not visible here -- confirm what callers must uphold.
unsafe fn make_packing_buffer<K>(m: usize, k: usize, n: usize, na: usize)
    -> (Alloc<K::Elem>, usize, usize)
    where K: GemmKernel,
{
    // No chunk is ever larger than one block in each dimension.
    let rows = min(m, K::mc());
    let depth = min(k, K::kc());
    let cols = min(n, K::nc());

    debug_assert_ne!(na, 0);
    debug_assert!(na <= 128);

    // Packed chunks are padded to whole kernel panels (multiples of MR / NR).
    let padded_rows = round_up_to(rows, K::MR);
    let padded_cols = round_up_to(cols, K::NR);
    let apack_size = depth * padded_rows;
    let bpack_size = depth * padded_cols;
    let nelem = apack_size * na + bpack_size;

    dprint!("packed nelem={}, apack={}, bpack={},
             m={} k={} n={}, na={}",
             nelem, apack_size, bpack_size,
             rows,depth,cols, na);

    let buffer = Alloc::new(nelem, K::align_to());
    (buffer, apack_size, bpack_size)
}



#[inline]
unsafe fn align_ptr<T>(mut align_to: usize, mut ptr: *mut T) -> *mut T {

    if cfg!(target_os = "macos") {
        align_to = Ord::max(align_to, 8);
    }

    if align_to != 0 {
        let cur_align = ptr as usize % align_to;
        if cur_align != 0 {
            ptr = ptr.offset(((align_to - cur_align) / size_of::<T>()) as isize);
        }
    }
    ptr
}









/// Computes one kernel tile into `mask_buf` and then merges the top-left
/// `rows` × `cols` part of it into C.
///
/// + `k`, `alpha`, `a`, `b`: passed to `K::kernel` unchanged
/// + `beta`, `c`, `rsc`, `csc`: scaling factor, destination and strides of C
/// + `rows`, `cols`: number of rows and columns of the tile that lie inside C
/// + `mask_buf`: scratch space for a full `K::MR` × `K::NR` tile
///
/// # Safety
///
/// `a` and `b` must satisfy the requirements of `K::kernel`, `mask_buf` must
/// hold at least `K::MR * K::NR` elements, and `c` must be valid for the
/// `rows` × `cols` elements addressed through `rsc` and `csc`.
#[inline(never)]
unsafe fn masked_kernel<T, K>(k: usize, alpha: T,
                              a: *const T,
                              b: *const T,
                              beta: T,
                              c: *mut T, rsc: isize, csc: isize,
                              rows: usize, cols: usize,
                              mask_buf: &mut [T])
    where K: GemmKernel<Elem=T>, T: Element,
{
    // The scratch tile is written with row stride 1 and column stride MR
    // (column major), using a zero beta so the kernel does not need its
    // previous contents.
    let scratch = mask_buf.as_mut_ptr();
    let scratch_col_stride = K::MR as isize;
    K::kernel(k, alpha, a, b, T::zero(), scratch, 1, scratch_col_stride);

    // Merge the valid part of the tile into C.
    let tile: &[T] = &*mask_buf;
    c_to_masked_ab_beta_c::<_, K>(beta, c, rsc, csc, rows, cols, tile);
}




#[inline]
unsafe fn c_to_masked_ab_beta_c<T, K>(beta: T,
                                      c: *mut T, rsc: isize, csc: isize,
                                      rows: usize, cols: usize,
                                      mask_buf: &[T])
    where K: GemmKernel<Elem=T>, T: Element,
{


    let mr = K::MR;
    let nr = K::NR;
    let mut ab = mask_buf.as_ptr();
    for j in 0..nr {
        for i in 0..mr {
            if i < rows && j < cols {
                let cptr = c.stride_offset(rsc, i)
                            .stride_offset(csc, j);
                if beta.is_zero() {
                    *cptr = *ab; // initialize
                } else {
                    (*cptr).mul_assign(beta);
                    (*cptr).add_assign(*ab);
                }
            }
            ab.inc();
        }
    }
}


/// Scales the `m` × `n` matrix C in place: C ← β C.
///
/// When `beta` is zero every element is set to zero without being read.
///
/// # Safety
///
/// `c` must be valid for writes (and for reads when `beta` is non-zero) of
/// every element `(i, j)` with `i < m`, `j < n` addressed through the row
/// stride `rsc` and the column stride `csc`.
#[inline(never)]
unsafe fn c_to_beta_c<T>(m: usize, n: usize, beta: T,
                         c: *mut T, rsc: isize, csc: isize)
    where T: Element
{
    for row in 0..m {
        let row_start = c.stride_offset(rsc, row);
        for col in 0..n {
            let elem = row_start.stride_offset(csc, col);
            if !beta.is_zero() {
                (*elem).mul_assign(beta);
            } else {
                // C may be uninitialized here: write without reading.
                *elem = T::zero();
            }
        }
    }
}
