// Patch geometry: each patch is patchSide x patchSide pixels and side_2 is
// the number of pixels per patch. The *Sh / *_sh macros are the matching
// log2 values and must be kept in sync by hand
// (patchSide == 1 << patchSideSh, side_2 == 1 << side_2_sh).
#define patchSide 8
#define patchSideSh 3
#define side_2_sh 6
#define side_2 (patchSide * patchSide)

// Number of patches stacked per 3D group; NSHIFT == log2(N).
// The Hadamard code in this file operates on exactly 16 lanes (float16).
#define N 16
#define NSHIFT 4

// PATCHSIZE duplicates patchSide via its log2 PATCHSHIFT.
// NOTE(review): neither is referenced in this file section; presumably used
// elsewhere (host code or other kernels) — confirm before removing.
#define PATCHSHIFT 3
#define PATCHSIZE (1<<PATCHSHIFT)

// Unnormalized 16-point Walsh-Hadamard transform of the lanes of v, computed
// as four radix-2 butterfly stages at lane distances 8, 4, 2 and 1.
// The transform is its own inverse up to a factor of 16: applying it twice
// multiplies every lane by 16.
inline float16 _walsh_hadamard16(float16 v)
{
    float16 t;

    // Stage 1: distance 8 (lanes 0-7 against lanes 8-15).
    t.lo = v.lo + v.hi;
    t.hi = v.lo - v.hi;

    // Stage 2: distance 4, inside each half.
    v.lo.lo = t.lo.lo + t.lo.hi;
    v.lo.hi = t.lo.lo - t.lo.hi;

    v.hi.lo = t.hi.lo + t.hi.hi;
    v.hi.hi = t.hi.lo - t.hi.hi;

    // Stage 3: distance 2.
    t.s014589CD = v.s014589CD + v.s2367ABEF;
    t.s2367ABEF = v.s014589CD - v.s2367ABEF;

    // Stage 4: distance 1.
    v.s02468ACE = t.s02468ACE + t.s13579BDF;
    v.s13579BDF = t.s02468ACE - t.s13579BDF;

    return v;
}

// Hard-thresholds the 16 lanes of v in the (unnormalized) Walsh-Hadamard
// domain, transforms back and scales the result.
//   v    - 16 values, one per stacked patch.
//   T    - threshold on the SQUARED unnormalized coefficient: a coefficient c
//          is kept only if c * c > T, otherwise it is set to 0.
//          (denoise_stack passes its threshold already squared.)
//   coef - factor applied to the reconstruction; denoise_stack passes 1/16,
//          which cancels the factor 16 of forward + inverse transforms.
// Returns the filtered 16 lanes.
inline float16 _hadamard_transform_and_threshold(float16 v, float T, float coef)
{
    // Forward transform.
    v = _walsh_hadamard16(v);

    // Threshold: vector comparisons yield -1 (all bits set) for true and 0
    // for false, so AND-ing the raw bits keeps a coefficient unchanged where
    // v * v > T and zeroes it elsewhere (including NaN lanes).
    v = as_float16(as_int16(v) & (v * v > T));

    // Inverse transform (same butterflies; scale applied below).
    v = _walsh_hadamard16(v);

    return v * coef;
}

// Work-group shape for denoise_stack. Each work-item handles one pixel of a
// patch (its linear local id indexes the side_2 pixels), so ITEMS is expected
// to equal side_2.
#define WGS_W 8
#define WGS_H 8
#define ITEMS (WGS_W*WGS_H)

// Compile-time checks for the layout assumptions denoise_stack relies on:
// one work-item per patch pixel, a stack of exactly 16 patches held in a
// float16, and log2 macros that match their counterparts.
#if ITEMS != side_2
#error "denoise_stack requires WGS_W * WGS_H == side_2 (one work-item per patch pixel)"
#endif
#if N != 16 || (1 << NSHIFT) != N || (1 << side_2_sh) != side_2
#error "denoise_stack requires N == 16 == (1 << NSHIFT) and side_2 == (1 << side_2_sh)"
#endif

// Filters one 3D group of N stacked patches per work-group.
// Group (ind_i, ind_j) = (get_group_id(0), get_group_id(1)) has linear index
// g = ind_j * w_ind_size + ind_i. Each work-item handles pixel local_idx of
// the patch and does the following:
//   - reads offsets[g*N .. g*N + N-1], the N patch offsets of the group;
//   - gathers table_2D[offset * side_2 + local_idx] for each offset
//     (presumably 2D-transformed patch coefficients, judging by the name —
//     confirm);
//   - runs _hadamard_transform_and_threshold along the stack with threshold
//     T = lambdaHard3D * sigma * sqrt(N) * img_mask[first offset of group];
//   - writes the result to stacks, where group g owns N * side_2 consecutive
//     floats and stack entry k of pixel p is at g*N*side_2 + k*side_2 + p.
// Presumably the hard-thresholding step of BM3D (by parameter naming) — confirm.
// NOTE(review): w_ind, h_ind, h_ind_size, weights, width and height are not
// read or written here; in particular `weights` is left untouched.
__kernel __attribute__((reqd_work_group_size(WGS_W, WGS_H, 1)))
void denoise_stack(
                   __global unsigned*    w_ind,
                   __global unsigned*    h_ind,
                   int                   w_ind_size,
                   int                   h_ind_size,
                   __global unsigned*    offsets,
                   __global float*       stacks,
                   __global float*       weights,
                   __global float*       table_2D,
                   __global const float* img_mask,
                   int                   width,
                   int                   height,
                   float                 lambdaHard3D,
                   float                 sigma)
{
    const int ind_i = get_group_id(0);
    const int ind_j = get_group_id(1);
    // Pixel index inside the patch handled by this work-item (0 .. side_2-1).
    const int local_idx = get_local_id(1) * WGS_W + get_local_id(0);

    // Linear index of this 3D group.
    const unsigned ind_offset = ind_j * w_ind_size + ind_i;

    // The N patch offsets of this group, and this work-item's first output
    // slot inside the group's N * side_2 floats of `stacks`.
    __global const unsigned* _offsets = &offsets[ind_offset << NSHIFT];
    __global float* stack = &stacks[(ind_offset * side_2) << NSHIFT] + local_idx;

    const float coef_norm = 4.0f;       // sqrt((float)N)
    const float coef = 1.0f / 16.0f;    // 1.0f / (float)N

    const float T = lambdaHard3D * sigma * coef_norm * img_mask[_offsets[0]];

    // Gather this pixel from each of the N patches of the group.
    float pix[N];
    for (int k = 0; k < N; ++k) {
        pix[k] = table_2D[local_idx + (_offsets[k] << side_2_sh)];
    }

    // The helper compares squared coefficients, so pass the squared threshold.
    const float16 filtered =
        _hadamard_transform_and_threshold(vload16(0, pix), T * T, coef);
    vstore16(filtered, 0, pix);

    // Scatter back: consecutive stack entries are side_2 floats apart.
    for (int k = 0; k < N; ++k) {
        stack[k * side_2] = pix[k];
    }
}

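// Unused reference implementations of the 16-point Walsh-Hadamard transform
// over 16 floats in __local memory, kept commented out: a swizzle-based
// version, a fully scalar version, and a generic loop over NSHIFT stages.
//inline void _hadamard_transform(__local float* data)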
//{
//    float16 v = vload16(0, data);
//    float16 v1;
//
//    // Step 1
//
//    v1.lo = v.lo + v.hi;
//    v1.hi = v.lo - v.hi;
//
//    // Step 2
//
//    v.lo.lo = v1.lo.lo + v1.lo.hi;
//    v.lo.hi = v1.lo.lo - v1.lo.hi;
//
//    v.hi.lo = v1.hi.lo + v1.hi.hi;
//    v.hi.hi = v1.hi.lo - v1.hi.hi;
//
//    // Step 3
//
//    v1.s014589CD = v.s014589CD + v.s2367ABEF;
//    v1.s2367ABEF = v.s014589CD - v.s2367ABEF;
//
//    // Step 4
//
//    v.s02468ACE = v1.s02468ACE + v1.s13579BDF;
//    v.s13579BDF = v1.s02468ACE - v1.s13579BDF;
//
//    vstore16(v, 0, data);
//}
//
//inline void _hadamard_transform(__local float* data)
//{
//    float16 v = vload16(0, data);
//    float16 v1;
//
//    // Step 1
//
//    v1.s0 = v.s0 + v.s8;
//    v1.s1 = v.s1 + v.s9;
//    v1.s2 = v.s2 + v.sA;
//    v1.s3 = v.s3 + v.sB;
//    v1.s4 = v.s4 + v.sC;
//    v1.s5 = v.s5 + v.sD;
//    v1.s6 = v.s6 + v.sE;
//    v1.s7 = v.s7 + v.sF;
//
//    v1.s8 = v.s0 - v.s8;
//    v1.s9 = v.s1 - v.s9;
//    v1.sA = v.s2 - v.sA;
//    v1.sB = v.s3 - v.sB;
//    v1.sC = v.s4 - v.sC;
//    v1.sD = v.s5 - v.sD;
//    v1.sE = v.s6 - v.sE;
//    v1.sF = v.s7 - v.sF;
//
//    // Step 2
//    v.s0 = v1.s0 + v1.s4;
//    v.s1 = v1.s1 + v1.s5;
//    v.s2 = v1.s2 + v1.s6;
//    v.s3 = v1.s3 + v1.s7;
//
//    v.s4 = v1.s0 - v1.s4;
//    v.s5 = v1.s1 - v1.s5;
//    v.s6 = v1.s2 - v1.s6;
//    v.s7 = v1.s3 - v1.s7;
//
//    v.s8 = v1.s8 + v1.sC;
//    v.s9 = v1.s9 + v1.sD;
//    v.sA = v1.sA + v1.sE;
//    v.sB = v1.sB + v1.sF;
//
//    v.sC = v1.s8 - v1.sC;
//    v.sD = v1.s9 - v1.sD;
//    v.sE = v1.sA - v1.sE;
//    v.sF = v1.sB - v1.sF;
//
//    // Step 3
//    v1.s0 = v.s0 + v.s2;
//    v1.s1 = v.s1 + v.s3;
//
//    v1.s2 = v.s0 - v.s2;
//    v1.s3 = v.s1 - v.s3;
//
//    v1.s4 = v.s4 + v.s6;
//    v1.s5 = v.s5 + v.s7;
//
//    v1.s6 = v.s4 - v.s6;
//    v1.s7 = v.s5 - v.s7;
//
//    v1.s8 = v.s8 + v.sA;
//    v1.s9 = v.s9 + v.sB;
//
//    v1.sA = v.s8 - v.sA;
//    v1.sB = v.s9 - v.sB;
//
//    v1.sC = v.sC + v.sE;
//    v1.sD = v.sD + v.sF;
//
//    v1.sE = v.sC - v.sE;
//    v1.sF = v.sD - v.sF;
//
//    // Step 4
//
//    v.s0 = v1.s0 + v1.s1;
//    v.s1 = v1.s0 - v1.s1;
//    v.s2 = v1.s2 + v1.s3;
//    v.s3 = v1.s2 - v1.s3;
//    v.s4 = v1.s4 + v1.s5;
//    v.s5 = v1.s4 - v1.s5;
//    v.s6 = v1.s6 + v1.s7;
//    v.s7 = v1.s6 - v1.s7;
//    v.s8 = v1.s8 + v1.s9;
//    v.s9 = v1.s8 - v1.s9;
//    v.sA = v1.sA + v1.sB;
//    v.sB = v1.sA - v1.sB;
//    v.sC = v1.sC + v1.sD;
//    v.sD = v1.sC - v1.sD;
//    v.sE = v1.sE + v1.sF;
//    v.sF = v1.sE - v1.sF;
//
//    vstore16(v, 0, data);
//}

//inline void _hadamard_transform(__local float* data)
//{
//    // Require n to be a power of 2
//    const int l2 = NSHIFT;
//
//    // Compute the WHT
//    for ( int i = 0; i < l2; ++i )
//    {
//        for ( int j = 0; j < N; j += (1 << (i+1)) )
//        {
//            for ( int k = 0; k < (int)(1 << i); ++k )
//            {
//                __local float* a = &data[j+k];
//                __local float* b = &data[j+k + (int)(1<<i)];
//                float A = *a;
//                float B = *b;
//                *a = A + B;
//                *b = A - B;
//            }
//        }
//    }
//}
