// Margin in pixels: blocks_match places its first reference position at
// (patchSide, patchSide) and subtracts 2*patchSide from `width` when
// computing how many reference positions there are per image row.
#define patchSide 8
// Distance in pixels between the reference positions of neighbouring
// work-groups (applied on both axes in blocks_match).
#define step 5

// Work-group dimensions required by blocks_match (reqd_work_group_size).
#define WGS_W 8
#define WGS_H 8

// Number of work-items per work-group. Not referenced in the code visible here.
#define ITEMS (WGS_W*WGS_H)

// Not referenced in the code visible here.
#define N 16
#define NSHIFT 4

// Not referenced in the code visible here.
// NOTE(review): bitonicSortLocal hard-codes 256 as its array length; confirm
// whether it was meant to use this constant.
#define LOCAL_SIZE_LIMIT 256

// One candidate match: `offset` is the linear index (row*width + column) of the
// candidate's top-left pixel in the image, `dist` is its sum of squared
// differences against the reference patch. The sort orders entries by `dist`.
typedef struct dist_t { unsigned offset; float dist; } dist_t;

// Forward declaration of the work-group-wide sort over a 256-entry local array.
void bitonicSortLocal(__local dist_t *array);

/*
 * Compare-exchange step of the sorting network.
 *
 * Swaps the two entries when (keyValA->dist > keyValB->dist) equals arrowDir:
 *   arrowDir == 1 : afterwards keyValA->dist <= keyValB->dist (ascending pair)
 *   arrowDir == 0 : afterwards keyValA->dist >= keyValB->dist (descending pair)
 * Entries are moved as a whole, so each `offset` stays with its `dist`.
 * A comparison involving NaN is false, so such pairs are swapped only when
 * arrowDir == 0.
 *
 * Declared `static inline`: a plain `inline` definition does not provide an
 * external definition under C99 inline rules, which can leave an unresolved
 * symbol on compilers that decide not to inline the call.
 */
static inline void ComparatorLocal(
                            __local dist_t *keyValA,
                            __local dist_t *keyValB,
                            uint arrowDir
                            ){
    const uint aIsLarger = (keyValA->dist > keyValB->dist);
    if (aIsLarger == arrowDir) {
        const dist_t held = *keyValA;
        *keyValA = *keyValB;
        *keyValB = held;
    }
}

/*
 * Sorts the 256 entries of `array` (local memory) into ascending `dist` order
 * with a bitonic sorting network executed cooperatively by the work-group.
 *
 * Requirements (hard-coded in the index arithmetic below):
 *   - the work-group is 8 work-items wide and has 64 work-items in total;
 *   - `array` holds exactly 256 entries;
 *   - every work-item of the group calls this function, because it contains
 *     barriers.
 *
 * A 256-entry network needs 128 compare-exchange operations per pass, so each
 * of the 64 work-items performs two of them (z = 0 and z = 1).
 *
 * On return the sorted array is visible to all work-items of the group.
 */
void bitonicSortLocal(__local dist_t *array)
{
    // Flattened index of this work-item inside the 8-wide work-group.
    const int l_id = get_local_id(0) + get_local_id(1) * 8;

    // Phase 1: build bitonic runs of length 2, 4, ..., 128.
    // Comparators 0..63 (z == 0) only touch entries 0..127 and comparators
    // 64..127 (z == 1) only touch entries 128..255 while size <= 128, so each
    // half can run through all of its strides before the other half starts.
    for (uint size = 2; size < 256; size <<= 1) {
        for (int z = 0; z < 2; ++z) {
            const int local_id = l_id + 64 * z;

            // Neighbouring runs are sorted in opposite directions so that two
            // of them together form a bitonic sequence for the next size.
            const uint dir = ((local_id & (size / 2)) != 0);

            for (uint stride = size / 2; stride > 0; stride >>= 1) {
                // Make the previous pass's swaps visible before reading.
                barrier(CLK_LOCAL_MEM_FENCE);
                const uint pos = 2 * local_id - (local_id & (stride - 1));
                ComparatorLocal(array + pos, array + (pos + stride), dir);
            }
        }
    }

    // Phase 2: final merge of the whole 256-entry bitonic sequence, ascending.
    // Here a pass spans both halves, so the stride loop is the outer one.
    for (uint stride = 256 / 2; stride > 0; stride >>= 1) {
        for (int z = 0; z < 2; ++z) {
            const int local_id = l_id + 64 * z;

            barrier(CLK_LOCAL_MEM_FENCE);
            const uint pos = 2 * local_id - (local_id & (stride - 1));
            ComparatorLocal(array + pos, array + (pos + stride), 1);
        }
    }

    // Publish the last pass: without this, entries swapped by the final
    // compare-exchange of another work-item are not guaranteed to be visible
    // to a caller that reads the array right after this function returns.
    barrier(CLK_LOCAL_MEM_FENCE);
}

/*
 * blocks_match: for one reference position per work-group, ranks the 16x16
 * candidate patches around it by similarity and writes out 16 patch offsets.
 *
 * Work distribution
 *   Work-group (ind_i, ind_j) handles the reference position
 *   (i, j) = (patchSide + ind_i*step, patchSide + ind_j*step), where i is the
 *   column and j the row. The group must be WGS_W x WGS_H work-items.
 *
 * Steps
 *   1. Copy the 24x24 pixel window whose top-left corner is (i - 8, j - 8)
 *      from `img` into local memory.
 *   2. The reference patch is the 8x8 block at window rows/columns 8..15,
 *      i.e. the patch whose top-left pixel is (i, j). For every candidate
 *      top-left position (di, dj) with di, dj in 0..15 inside the window,
 *      compute the sum of squared differences (SSD) against the reference.
 *      Each work-item handles one column di and four consecutive rows dj.
 *   3. Sort the 256 (offset, dist) records by ascending dist.
 *   4. Write 16 offsets to g_offsets[(ind_j*w_ind_size + ind_i)*16 ...]:
 *      slot 0 is always the reference offset j*width + i, slots 1..15 are
 *      sorted records 1..15.
 *
 * Parameters
 *   img        image pixels, addressed as img[row*width + column]
 *   width      row pitch of `img` in pixels
 *   height     not used by this kernel
 *   threshold  not used by this kernel
 *   g_offsets  output, 16 entries per work-group
 *
 * NOTE(review): global reads are not checked against `width`/`height`. The
 * window of the last work-groups reaches up to column
 * patchSide + ind_i*step + 15 (and the same for rows); confirm that the host
 * sizes the launch and the image buffer so that this stays inside `img`.
 */
__kernel __attribute__((reqd_work_group_size(WGS_W, WGS_H, 1)))
void blocks_match(
              __global float*      img,
              int                  width,
              int                  height,
              float                threshold,
              __global unsigned*   g_offsets
              )
{
    enum {
        PATCH_ROWS    = 8,                // rows per patch; one row is one float8
        SEARCH_SIDE   = 16,               // candidate positions per axis
        HALF_SEARCH   = SEARCH_SIDE / 2,  // reference patch position inside the window
        WINDOW_SIDE   = 24,               // SEARCH_SIDE + patch side
        ROWS_PER_ITEM = 4,                // candidate rows handled by one work-item
        MATCHES_OUT   = 16                // offsets written per work-group
    };

    __local float  window[WINDOW_SIDE][WINDOW_SIDE];    // 2304 bytes
    __local dist_t _dists[SEARCH_SIDE * SEARCH_SIDE];   // 2048 bytes

    const int w_ind_size = (width - 2*patchSide)/step;

    const int ind_i = get_group_id(0);
    const int ind_j = get_group_id(1);

    const int i = patchSide + ind_i*step;
    const int j = patchSide + ind_j*step;

    const int offsetOrg = j*width + i;

    const int local_i   = get_local_id(0);
    const int local_j   = get_local_id(1);
    const int local_idx = local_j*WGS_W + local_i;

    // Top-left corner of the window in image coordinates.
    const int win_x = i - HALF_SEARCH;
    const int win_y = j - HALF_SEARCH;

    // 1. Cooperative load: the work-group tiles the window, one pixel per
    //    work-item per tile.
    for (int wy = local_j; wy < WINDOW_SIDE; wy += WGS_H) {
        for (int wx = local_i; wx < WINDOW_SIDE; wx += WGS_W) {
            window[wy][wx] = img[(win_y + wy)*width + (win_x + wx)];
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // 2. SSD of this work-item's four candidates.
    const int di     = local_idx % SEARCH_SIDE;                    // candidate column
    const int djBase = (local_idx / SEARCH_SIDE) * ROWS_PER_ITEM;  // first candidate row

    // vload8 only needs float alignment. Dereferencing a float address cast to
    // a float8 pointer would require 32-byte alignment, which `window` does
    // not guarantee.
    float8 refRows[PATCH_ROWS];
    for (int r = 0; r < PATCH_ROWS; ++r) {
        refRows[r] = vload8(0, &window[HALF_SEARCH + r][HALF_SEARCH]);
    }

    // The four candidates are vertically adjacent, so together they span
    // PATCH_ROWS + ROWS_PER_ITEM - 1 window rows; load each row only once.
    float8 candRows[PATCH_ROWS + ROWS_PER_ITEM - 1];
    for (int r = 0; r < PATCH_ROWS + ROWS_PER_ITEM - 1; ++r) {
        candRows[r] = vload8(0, &window[djBase + r][di]);
    }

    for (int c = 0; c < ROWS_PER_ITEM; ++c) {
        float dist = 0;

        for (int r = 0; r < PATCH_ROWS; ++r) {
            const float8 dp = candRows[c + r] - refRows[r];
            dist += dot(dp.lo, dp.lo) + dot(dp.hi, dp.hi);
        }

        const int dj = djBase + c;

        __local dist_t* d = &_dists[dj*SEARCH_SIDE + di];
        d->dist = dist;
        d->offset = (win_y + dj)*width + (win_x + di);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // 3. Ascending sort by dist.
    bitonicSortLocal(_dists);

    // 4. The first MATCHES_OUT work-items each write one output slot.
    if (local_idx < MATCHES_OUT)
    {
        const int dst_idx = (ind_j * w_ind_size + ind_i)*MATCHES_OUT;

        // Slot 0 is forced to the reference patch itself, whatever record
        // ended up first after the sort.
        g_offsets[dst_idx + local_idx] = local_idx == 0 ? offsetOrg : _dists[local_idx].offset;
    }
}
