// Patch edge length in pixels. The kernel below also uses it as the margin
// kept free on each image side when counting reference positions.
#define patchSide 8
// 3 == log2(patchSide); not referenced by the code visible in this file.
#define patchSideSh 3
// 6 == log2(side_2); not referenced by the code visible in this file.
#define side_2_sh 6
// Pixels per patch (64); only appears in the commented-out kaiser_window table.
#define side_2 (patchSide * patchSide)

// Number of patches in one stack (length of the stack-direction transform).
#define N 16
// log2(N): the kernel shifts a reference-position index by this amount to find
// that position's N consecutive entries in the offsets buffer.
#define NSHIFT 4

// Not referenced by the code visible in this file.
#define PATCHSHIFT 3
#define PATCHSIZE (1<<PATCHSHIFT)

// Distance in pixels between neighbouring reference positions.
#define step 5

// Forward declarations of the helpers that are defined below without `inline`
// (presumably present to keep missing-prototype warnings quiet - TODO confirm).
float16 _hadamard_transform_and_threshold(float16 v, float T);
float8 _wht(float8 v);
void walshV(__local float* _img, int idx);

// Prototype for the butterfly helper, matching the file's habit of declaring
// its non-inline helpers before defining them.
float16 _hadamard16_butterflies(float16 v);

// Runs the four add/subtract stages of an unnormalised 16-point Hadamard
// transform across the lanes of v. The partner of each lane is 8, 4, 2 and
// finally 1 lane away.
float16 _hadamard16_butterflies(float16 v)
{
    float16 t;

    // Stage 1: lanes 0-7 against lanes 8-F
    t.lo = v.lo + v.hi;
    t.hi = v.lo - v.hi;

    // Stage 2: groups of four lanes within each half
    v.s0123 = t.s0123 + t.s4567;
    v.s4567 = t.s0123 - t.s4567;
    v.s89AB = t.s89AB + t.sCDEF;
    v.sCDEF = t.s89AB - t.sCDEF;

    // Stage 3: pairs of lanes within each group of four
    t.s014589CD = v.s014589CD + v.s2367ABEF;
    t.s2367ABEF = v.s014589CD - v.s2367ABEF;

    // Stage 4: even lanes against their odd neighbours
    v.even = t.even + t.odd;
    v.odd  = t.even - t.odd;

    return v;
}

// Hard-thresholds the 16 lanes of v in the Hadamard domain.
//
//   v  input samples, one per lane
//   T  threshold for the SQUARED transform coefficient; a coefficient c is
//      kept only when c*c > T, otherwise it is replaced by zero
//
// Returns the inverse transform of the surviving coefficients, scaled by 1/16
// so that an input whose coefficients all survive is reproduced.
//inline
float16 _hadamard_transform_and_threshold(float16 v, float T)
{
    // Forward transform (no normalisation yet).
    float16 coeff = _hadamard16_butterflies(v);

    // A vector comparison yields an all-ones lane where it holds and zero
    // elsewhere, so AND-ing it with the raw bits keeps or clears each lane.
    const int16 keep = (coeff * coeff > T);
    coeff = as_float16(as_int16(coeff) & keep);

    // The same butterflies invert the transform up to a factor of 16.
    return _hadamard16_butterflies(coeff) * (1.0f / 16.0f);
}

// 8-point Walsh-Hadamard transform of the lanes of v, scaled by 1/sqrt(8).
// Three add/subtract stages pair lanes that are 4, 2 and 1 apart.
//inline
float8 _wht(float8 v)
{
    // Stage 1: lower half against upper half
    const float4 sum4  = v.lo + v.hi;
    const float4 diff4 = v.lo - v.hi;

    // Stage 2: pairs of lanes inside each half
    float8 s;
    s.s01 = sum4.lo  + sum4.hi;
    s.s23 = sum4.lo  - sum4.hi;
    s.s45 = diff4.lo + diff4.hi;
    s.s67 = diff4.lo - diff4.hi;

    // Stage 3: even lanes against their odd neighbours
    float8 r;
    r.even = s.even + s.odd;
    r.odd  = s.even - s.odd;

    // 0.35355339059327 == 1/sqrt(8)
    return r * 0.35355339059327f;
}

// In-place 8-point WHT of the idx-th group of 8 consecutive floats in _img
// (row idx when _img is an 8x8 row-major patch).
inline void walshH(__local float* _img, int idx)
{
    const float8 row = vload8(idx, _img);

    vstore8(_wht(row), idx, _img);
}

// In-place 8-point WHT of the 8 floats found at _img[idx], _img[idx + 8], ...,
// _img[idx + 56] (column idx when _img is an 8x8 row-major patch).
//inline
void walshV(__local float* _img, int idx)
{
    __local float* col = _img + idx;

    // Gather the strided column into one vector.
    float8 v = (float8)(col[8*0], col[8*1], col[8*2], col[8*3],
                        col[8*4], col[8*5], col[8*6], col[8*7]);

    v = _wht(v);

    // Scatter the transformed values back to the same strided locations.
    col[8*0] = v.s0;
    col[8*1] = v.s1;
    col[8*2] = v.s2;
    col[8*3] = v.s3;
    col[8*4] = v.s4;
    col[8*5] = v.s5;
    col[8*6] = v.s6;
    col[8*7] = v.s7;
}
/*
__constant float kaiser_window[side_2] =
{
    0.192400f, 0.298900f, 0.384600f, 0.432500f, 0.432500f, 0.384600f, 0.298900f, 0.192400f,
    0.298900f, 0.464200f, 0.597400f, 0.671700f, 0.671700f, 0.597400f, 0.464200f, 0.298900f,
    0.384600f, 0.597400f, 0.768800f, 0.864400f, 0.864400f, 0.768800f, 0.597400f, 0.384600f,
    0.432500f, 0.671700f, 0.864400f, 0.971800f, 0.971800f, 0.864400f, 0.671700f, 0.432500f,
    0.432500f, 0.671700f, 0.864400f, 0.971800f, 0.971800f, 0.864400f, 0.671700f, 0.432500f,
    0.384600f, 0.597400f, 0.768800f, 0.864400f, 0.864400f, 0.768800f, 0.597400f, 0.384600f,
    0.298900f, 0.464200f, 0.597400f, 0.671700f, 0.671700f, 0.597400f, 0.464200f, 0.298900f,
    0.192400f, 0.298900f, 0.384600f, 0.432500f, 0.432500f, 0.384600f, 0.298900f, 0.192400f
};
*/
// Work-group dimensions required by denoise_stack_wht (see its
// reqd_work_group_size attribute): the kernel indexes 8x8 patches directly by
// the local id, so these must stay 8 x 8.
#define WGS_W 8
#define WGS_H 8
// Work-items per work-group; not referenced by the code visible in this file.
#define ITEMS (WGS_W*WGS_H)

// Filters one stack of N 8x8 patches per work-group with a thresholded
// Walsh-Hadamard transform and accumulates the result.
//
// Each 8x8 work-group handles the reference position
//   (ind_i, ind_j) = (group_id(0)*5 + step_i, group_id(1)*5 + step_j)
// and performs these steps:
//   1. copy the 24x24 pixel window whose top-left corner is (i - 8, j - 8),
//      with i = ind_i*step + patchSide (same for j), into local memory;
//   2. gather the N patches listed for this position in `offsets` into `stack`;
//   3. 2D WHT of every patch, thresholded Hadamard transform along the stack
//      direction, then a second 2D WHT (the 1/sqrt(8)-scaled WHT undoes itself);
//   4. add every filtered pixel to `numerator_denominator`
//      (.s0 += value, .s1 += 1).
//
// Parameters:
//   step_i, step_j         phase added to 5 * group id to obtain the reference
//                          position index
//   offsets                N entries per reference position, starting at index
//                          (ind_j*w_ind_size + ind_i) << NSHIFT; each entry is
//                          decoded as y*width + x of a patch's top-left pixel
//   img                    source image, addressed as row*width + column
//   numerator_denominator  per-pixel accumulators, addressed like img
//   width                  row pitch of img and numerator_denominator in pixels
//   height                 unused; kept for interface compatibility
//   T                      threshold for the squared stack-transform coefficient
//
// NOTE(review): there are no bounds checks. The kernel assumes that the whole
// 24x24 window lies inside the image and that every listed patch lies inside
// that window - confirm against the host code.
// NOTE(review): work-groups of one dispatch are 5*step = 25 pixels apart, so
// their 24x24 windows are disjoint; presumably the host launches the kernel
// once per (step_i, step_j) phase so that different work-groups never write
// the same accumulator concurrently - confirm.
__kernel __attribute__((reqd_work_group_size(WGS_W, WGS_H, 1)))
void denoise_stack_wht(
                   int                   step_i,
                   int                   step_j,
                   __global unsigned*    offsets,
                   __global const float* img,
                   __global float2*      numerator_denominator,
                   int                   width,
                   int                   height,
                   float                 T)
{
    __local float stack[N][8][8];   // 16*8*8 floats = 4096 bytes
    __local float window[24][25];   // 24*25 floats  = 2400 bytes
                                    // total           6496 bytes
    // Only columns 0..23 of `window` are filled below; the 25th column looks
    // like row padding (presumably against bank conflicts - TODO confirm).

    // Number of reference positions per image row.
    const int w_ind_size = (width - 2*patchSide)/step;

    const int ind_i = get_group_id(0)*5 + step_i;
    const int ind_j = get_group_id(1)*5 + step_j;

    const int i = ind_i*step + patchSide;
    const int j = ind_j*step + patchSide;

    const int local_i = get_local_id(0);
    const int local_j = get_local_id(1);

    // Top-left corner of the 24x24 window in image coordinates.
    const int win_x = i - 8;
    const int win_y = j - 8;

    // window: every work-item copies its own 3x3 tile (8*3 = 24 per axis)

    for (int y = 0; y < 3; ++y)
        for (int x = 0; x < 3; ++x)
            window[local_j*3 + y][local_i*3 + x] = img[(win_y + local_j*3 + y)*width + (win_x + local_i*3 + x)];

    barrier(CLK_LOCAL_MEM_FENCE);

    // process offsets

    const unsigned ind_offset = ind_j * w_ind_size + ind_i;

    __global unsigned* _offsets = &offsets[ind_offset<<NSHIFT];

    // load stack: work-item (local_i, local_j) fetches its pixel of every patch

    for(int k = 0; k < N; k++)
    {
        int offset = _offsets[k];
        int pos_x = offset % width - win_x + local_i;
        int pos_y = offset / width - win_y + local_j;

        stack[k][local_j][local_i] = window[pos_y][pos_x];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // wht2D: 64 work-items cover 16 patches x 8 columns (then rows), i.e. each
    // work-item transforms line local_i of patches 2*local_j and 2*local_j + 1

    walshV((__local float*)stack[local_j*2+0], local_i);
    walshV((__local float*)stack[local_j*2+1], local_i);

    barrier(CLK_LOCAL_MEM_FENCE);

    walshH((__local float*)stack[local_j*2+0], local_i);
    walshH((__local float*)stack[local_j*2+1], local_i);

    barrier(CLK_LOCAL_MEM_FENCE);

    // whtZ and clip: one stack column per work-item

    float16 _zcolumn;

    _zcolumn.s0 = stack[ 0][local_j][local_i];
    _zcolumn.s1 = stack[ 1][local_j][local_i];
    _zcolumn.s2 = stack[ 2][local_j][local_i];
    _zcolumn.s3 = stack[ 3][local_j][local_i];
    _zcolumn.s4 = stack[ 4][local_j][local_i];
    _zcolumn.s5 = stack[ 5][local_j][local_i];
    _zcolumn.s6 = stack[ 6][local_j][local_i];
    _zcolumn.s7 = stack[ 7][local_j][local_i];
    _zcolumn.s8 = stack[ 8][local_j][local_i];
    _zcolumn.s9 = stack[ 9][local_j][local_i];
    _zcolumn.sA = stack[10][local_j][local_i];
    _zcolumn.sB = stack[11][local_j][local_i];
    _zcolumn.sC = stack[12][local_j][local_i];
    _zcolumn.sD = stack[13][local_j][local_i];
    _zcolumn.sE = stack[14][local_j][local_i];
    _zcolumn.sF = stack[15][local_j][local_i];

    _zcolumn = _hadamard_transform_and_threshold(_zcolumn, T);

    stack[ 0][local_j][local_i] = _zcolumn.s0;
    stack[ 1][local_j][local_i] = _zcolumn.s1;
    stack[ 2][local_j][local_i] = _zcolumn.s2;
    stack[ 3][local_j][local_i] = _zcolumn.s3;
    stack[ 4][local_j][local_i] = _zcolumn.s4;
    stack[ 5][local_j][local_i] = _zcolumn.s5;
    stack[ 6][local_j][local_i] = _zcolumn.s6;
    stack[ 7][local_j][local_i] = _zcolumn.s7;
    stack[ 8][local_j][local_i] = _zcolumn.s8;
    stack[ 9][local_j][local_i] = _zcolumn.s9;
    stack[10][local_j][local_i] = _zcolumn.sA;
    stack[11][local_j][local_i] = _zcolumn.sB;
    stack[12][local_j][local_i] = _zcolumn.sC;
    stack[13][local_j][local_i] = _zcolumn.sD;
    stack[14][local_j][local_i] = _zcolumn.sE;
    stack[15][local_j][local_i] = _zcolumn.sF;

    barrier(CLK_LOCAL_MEM_FENCE);

    // wht2D again, which brings the patches back to the pixel domain

    walshV((__local float*)stack[local_j*2+0], local_i);
    walshV((__local float*)stack[local_j*2+1], local_i);

    barrier(CLK_LOCAL_MEM_FENCE);

    walshH((__local float*)stack[local_j*2+0], local_i);
    walshH((__local float*)stack[local_j*2+1], local_i);

    barrier(CLK_LOCAL_MEM_FENCE);

    // apply: accumulate every filtered patch at its own image location

    for(int k = 0; k < N; k++)
    {
        const int offset = _offsets[k];
        const int dst = offset + local_j*width + local_i;

        numerator_denominator[dst].s0 += stack[k][local_j][local_i];
        numerator_denominator[dst].s1 += 1;

        // Patches of one stack may overlap, so a different work-item can
        // target this same element in the next iteration. Without this
        // barrier the two unsynchronised read-modify-write sequences race
        // and one of the updates can be lost.
        barrier(CLK_GLOBAL_MEM_FENCE);
    }
}
