
/* Edge length of the square block processed by the 8-point transforms. */
#define BLOCK_SIZE 8

/*
 * Cosine constants for the 8-point inverse DCT, pre-scaled by sqrt(2):
 *
 *     C_x = sqrt(2) * cos(k * pi / 16)
 *
 * with k = 1, 2, 3, 5, 6, 7 for x = a, b, c, d, e, f respectively.
 * k = 4 would give exactly 1 and therefore has no constant.
 * C_norm = 1 / sqrt(8) is the common output scale of the orthonormal transform.
 */
#define C_a    1.3870398453221474618216191915664f    /* k = 1 */
#define C_b    1.3065629648763765278566431734272f    /* k = 2 */
#define C_c    1.1758756024193587169744671046113f    /* k = 3 */
#define C_d    0.78569495838710218127789736765722f   /* k = 5 */
#define C_e    0.54119610014619698439972320536639f   /* k = 6 */
#define C_f    0.27589937928294301233595756366937f   /* k = 7 */
#define C_norm 0.35355339059327376220042218105242f   /* 1 / sqrt(8) */

/*
 * In-place orthonormal 8-point inverse DCT (DCT-III) of D[0..7]:
 *
 *   out[n] = C_norm * ( D[0] + sum_{k=1..7} sqrt(2) * cos((2n+1) k pi / 16) * D[k] )
 *
 * The even-indexed inputs (0, 2, 4, 6) and odd-indexed inputs (1, 3, 5, 7)
 * are combined separately. Output n and output 7-n then share the same
 * even/odd partial sums with opposite sign on the odd part.
 * All eight inputs are read before any element of D is overwritten.
 */
inline void IDCT8(float *D)
{
    /* Even half. */
    const float sum04  = D[0] + D[4];
    const float diff04 = D[0] - D[4];
    const float rot26p = C_b * D[2] + C_e * D[6];
    const float rot26m = C_e * D[2] - C_b * D[6];

    const float even07 = sum04 + rot26p;    /* outputs 0 and 7 */
    const float even34 = sum04 - rot26p;    /* outputs 3 and 4 */
    const float even16 = diff04 + rot26m;   /* outputs 1 and 6 */
    const float even25 = diff04 - rot26m;   /* outputs 2 and 5 */

    /* Odd half (operand order kept identical for bit-exact float results). */
    const float odd07 = C_f * D[7] + C_a * D[1] + C_c * D[3] + C_d * D[5];
    const float odd43 = C_a * D[7] - C_f * D[1] + C_d * D[3] - C_c * D[5];
    const float odd16 = C_c * D[1] - C_d * D[7] - C_f * D[3] - C_a * D[5];
    const float odd25 = C_d * D[1] + C_c * D[7] - C_a * D[3] + C_f * D[5];

    /* Butterflies: odd43 enters output 4 with '+' and output 3 with '-'. */
    D[0] = C_norm * (even07 + odd07);
    D[1] = C_norm * (even16 + odd16);
    D[2] = C_norm * (even25 + odd25);
    D[3] = C_norm * (even34 - odd43);
    D[4] = C_norm * (even34 + odd43);
    D[5] = C_norm * (even25 - odd25);
    D[6] = C_norm * (even16 - odd16);
    D[7] = C_norm * (even07 - odd07);
}


#define BLOCK_X 64
#define BLOCK_Y 8

/*
 * Separable 2D 8x8 inverse DCT over an array of tiles: a row pass followed by
 * a column pass, each using IDCT8.
 *
 * Data layout (as indexed below): src and dst each hold numTilesW * numTilesH
 * tiles of BLOCK_SIZE * BLOCK_SIZE floats. Tile (tx, ty) starts at element
 * (ty * numTilesW + tx) * 64 and is row-major, with element (r, c) at +r*8 + c.
 *
 * Launch contract (assumed by the indexing, not checked here):
 *   - the local size must be exactly (BLOCK_X, BLOCK_Y / BLOCK_SIZE) = (64, 1);
 *     other sizes mis-index tiles or overrun l_Transpose;
 *   - dimension 0 uses 8 work-items per tile, so its global size should cover
 *     numTilesW * 8, rounded up to a multiple of 64;
 *   - dimension 1 uses one work-item per tile row, so its global size should
 *     be >= numTilesH.
 *
 * Work-item m of a tile (m = localX % 8) transforms row m and then column m
 * of that tile. Rows and columns are exchanged through local memory, so each
 * work-item reads values written by its neighbours; barriers separate the
 * load, row-pass and column-pass stages.
 */
__kernel void IDCT8x8(
                      __global const float *src,
                      __global float *dst,
                      int numTilesW,
                      int numTilesH)
{
    // Extra padding column; presumably to avoid local-memory bank conflicts
    // on the strided column accesses.
    __local float l_Transpose[BLOCK_Y][BLOCK_X + 1];
    const uint    localX = get_local_id(0);
    const uint    localY = BLOCK_SIZE * get_local_id(1);
    const uint modLocalX = localX & (BLOCK_SIZE - 1);

    // l_V walks column modLocalX of this tile (stride BLOCK_X + 1);
    // l_H walks row modLocalX of this tile (stride 1).
    __local float *l_V = &l_Transpose[localY +         0][localX +         0];
    __local float *l_H = &l_Transpose[localY + modLocalX][localX - modLocalX];

    const uint tile_index_x = (get_group_id(0) * (BLOCK_X / BLOCK_SIZE)) + (localX / BLOCK_SIZE);
    const uint tile_index_y = get_global_id(1);

    // Out-of-range work-items must NOT return early: barrier() has to be
    // reached by every work-item of the work-group. They only skip the
    // memory accesses instead.
    const bool active = (int)tile_index_x < numTilesW && (int)tile_index_y < numTilesH;

    // Widen to size_t before multiplying so the element offset does not wrap
    // at 2^32 on devices with a 64-bit size_t.
    const size_t offset = ((size_t)tile_index_y * (size_t)numTilesW + tile_index_x)
                          * (BLOCK_SIZE * BLOCK_SIZE) + modLocalX;

    float D[BLOCK_SIZE];

    // Stage 1: copy column modLocalX of the tile into local memory.
    if (active) {
        for (uint i = 0; i < BLOCK_SIZE; i++) {
            l_V[i * (BLOCK_X + 1)] = src[offset + i * BLOCK_SIZE];
        }
    }
    // Each row read in stage 2 was written by the tile's other work-items.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Stage 2: 1D IDCT along row modLocalX, written back in place.
    if (active) {
        for (uint i = 0; i < BLOCK_SIZE; i++) {
            D[i] = l_H[i];
        }

        IDCT8(D);

        for (uint i = 0; i < BLOCK_SIZE; i++) {
            l_H[i] = D[i];
        }
    }
    // Each column read in stage 3 was written by the tile's other work-items.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Stage 3: 1D IDCT along column modLocalX, stored to dst.
    if (active) {
        for (uint i = 0; i < BLOCK_SIZE; i++) {
            D[i] = l_V[i * (BLOCK_X + 1)];
        }

        IDCT8(D);

        for (uint i = 0; i < BLOCK_SIZE; i++) {
            dst[offset + i * BLOCK_SIZE] = D[i];
        }
    }
}

//__kernel void IDCT8x8(
//                      __global float *src,
//                      __global float *dst)
//{
//    int patch_idx = get_group_id(0);
//    int thread_idx = get_local_id(0);
//    
//    int offset = patch_idx * 64 + 8 * thread_idx;
//    
//    __local float patch[8*8];
//    __local float *patch_ptr = patch + thread_idx * 8;
//    
//    for (int i = 0; i < 8; ++i) {
//        patch[8 * i + thread_idx] = src[offset + i];
//    }
//    
//    float D[8];
//    
//    for (int i = 0; i < 8; ++i) {
//        D[i] = patch_ptr[i];
//    }
//    
//    IDCT8(D);
//    
//    for (int i = 0; i < 8; ++i) {
//        patch[8 * i + thread_idx] = D[i];
//    }
//    
//    for (int i = 0; i < 8; ++i) {
//        D[i] = patch_ptr[i];
//    }
//    
//    IDCT8(D);
//    
//    for (int i = 0; i < 8; ++i) {
//        dst[offset + i] = D[i];
//    }
//}

//#define patchSide 8
//#define patchSideSh 3
//#define side_2_sh 6
//#define side_2 (patchSide * patchSide)
//
//#define N 16
//#define NSHIFT 4
//
//#define PATCHSHIFT 3
//#define PATCHSIZE (1<<PATCHSHIFT)
//
//__constant float DCTv8matrix[] =
//{
//    0.3535533905932738f,  0.4903926402016152f,  0.4619397662556434f,  0.4157348061512726f,  0.3535533905932738f,  0.2777851165098011f,  0.1913417161825449f,  0.0975451610080642f,
//    0.3535533905932738f,  0.4157348061512726f,  0.1913417161825449f, -0.0975451610080641f, -0.3535533905932737f, -0.4903926402016152f, -0.4619397662556434f, -0.2777851165098011f,
//    0.3535533905932738f,  0.2777851165098011f, -0.1913417161825449f, -0.4903926402016152f, -0.3535533905932738f,  0.0975451610080642f,  0.4619397662556433f,  0.4157348061512727f,
//    0.3535533905932738f,  0.0975451610080642f, -0.4619397662556434f, -0.2777851165098011f,  0.3535533905932737f,  0.4157348061512727f, -0.1913417161825450f, -0.4903926402016153f,
//    0.3535533905932738f, -0.0975451610080641f, -0.4619397662556434f,  0.2777851165098009f,  0.3535533905932738f, -0.4157348061512726f, -0.1913417161825453f,  0.4903926402016152f,
//    0.3535533905932738f, -0.2777851165098010f, -0.1913417161825452f,  0.4903926402016153f, -0.3535533905932733f, -0.0975451610080649f,  0.4619397662556437f, -0.4157348061512720f,
//    0.3535533905932738f, -0.4157348061512727f,  0.1913417161825450f,  0.0975451610080640f, -0.3535533905932736f,  0.4903926402016152f, -0.4619397662556435f,  0.2777851165098022f,
//    0.3535533905932738f, -0.4903926402016152f,  0.4619397662556433f, -0.4157348061512721f,  0.3535533905932733f, -0.2777851165098008f,  0.1913417161825431f, -0.0975451610080625f
//};
//
//#define    C_a 1.3870398453221474618216191915664f  //a = sqrt(2) * cos(1 * pi / 16)
//#define    C_b 1.3065629648763765278566431734272f  //b = sqrt(2) * cos(2 * pi / 16)
//#define    C_c 1.1758756024193587169744671046113f  //c = sqrt(2) * cos(3 * pi / 16)
//#define    C_d 0.78569495838710218127789736765722f //d = sqrt(2) * cos(5 * pi / 16)
//#define    C_e 0.54119610014619698439972320536639f //e = sqrt(2) * cos(6 * pi / 16)
//#define    C_f 0.27589937928294301233595756366937f //f = sqrt(2) * cos(7 * pi / 16)
//#define C_norm 0.35355339059327376220042218105242f //1 / sqrt(8)
//
//inline void IDCT8(float *D, float *O)
//{
//    float Y04P   = D[0] + D[4];
//    float Y2b6eP = C_b * D[2] + C_e * D[6];
//    
//    float Y04P2b6ePP = Y04P + Y2b6eP;
//    float Y04P2b6ePM = Y04P - Y2b6eP;
//    float Y7f1aP3c5dPP = C_f * D[7] + C_a * D[1] + C_c * D[3] + C_d * D[5];
//    float Y7a1fM3d5cMP = C_a * D[7] - C_f * D[1] + C_d * D[3] - C_c * D[5];
//    
//    float Y04M   = D[0] - D[4];
//    float Y2e6bM = C_e * D[2] - C_b * D[6];
//    
//    float Y04M2e6bMP = Y04M + Y2e6bM;
//    float Y04M2e6bMM = Y04M - Y2e6bM;
//    float Y1c7dM3f5aPM = C_c * D[1] - C_d * D[7] - C_f * D[3] - C_a * D[5];
//    float Y1d7cP3a5fMM = C_d * D[1] + C_c * D[7] - C_a * D[3] + C_f * D[5];
//    
//    O[0] = C_norm * (Y04P2b6ePP + Y7f1aP3c5dPP);
//    O[7] = C_norm * (Y04P2b6ePP - Y7f1aP3c5dPP);
//    O[4] = C_norm * (Y04P2b6ePM + Y7a1fM3d5cMP);
//    O[3] = C_norm * (Y04P2b6ePM - Y7a1fM3d5cMP);
//    
//    O[1] = C_norm * (Y04M2e6bMP + Y1c7dM3f5aPM);
//    O[5] = C_norm * (Y04M2e6bMM - Y1d7cP3a5fMM);
//    O[2] = C_norm * (Y04M2e6bMM + Y1d7cP3a5fMM);
//    O[6] = C_norm * (Y04M2e6bMP - Y1c7dM3f5aPM);
//}
//
//inline void iDCT2D8x8_new(float* patch)
//{
//    const int side2 = PATCHSIZE*PATCHSIZE;
//    
//    float tmp[PATCHSIZE*PATCHSIZE];
//    
//#pragma unroll
//    for (int j = 0; j < PATCHSIZE; j++)
//        IDCT8(patch + (j<<PATCHSHIFT), tmp + (j<<PATCHSHIFT));
//    
//#pragma unroll
//    for (int j = 0; j < PATCHSIZE; j++)
//#pragma unroll
//        for (int i = 0; i < PATCHSIZE; i++)
//            patch[(j<<PATCHSHIFT) + i] = tmp[(i<<PATCHSHIFT) + j];
//    
//#pragma unroll
//    for (int j = 0; j < PATCHSIZE; j++)
//        IDCT8(patch + (j<<PATCHSHIFT), tmp + (j<<PATCHSHIFT));
//    
//#pragma unroll
//    for (int j = 0; j < PATCHSIZE; j++)
//#pragma unroll
//        for (int i = 0; i < PATCHSIZE; i++)
//            patch[(j<<PATCHSHIFT) + i] = tmp[(i<<PATCHSHIFT) + j];
//}
//
//__kernel
//void idct(
//          __global float*       stacks,
//          int                   size)
//{
//    const int idx = get_global_id(0);
//    if(idx >= size)
//        return;
//    
//    __global float* _patch = &stacks[idx*side_2];
//    
//    float patch[PATCHSIZE*PATCHSIZE];
//    
//#pragma unroll
//    for (unsigned k = 0; k < side_2; k++)
//        patch[k] = _patch[k];
//    
//    iDCT2D8x8_new(patch);
//    
//#pragma unroll
//    for (unsigned k = 0; k < side_2; k++)
//        _patch[k] = patch[k];
//}
//
//#define G_X 8
//#define G_Y 8
//#define G_Z 2
//
////__kernel void idct_new(__global float *buffer, int num)
////{
////    const int l_x = get_local_id(0);
////    const int l_y = get_local_id(1);
////    const int l_z = get_local_id(2);
////    const int offset = get_global_id(0) + get_global_id(1) * G_X + get_global_id(2) * G_X * G_Y;
////    const bool boundsCheck = get_group_id(2) * G_Z < num;
////    
////    __local float patch_1[G_Z][G_Y][G_X];
////    __local float patch_2[G_Z][G_Y][G_X];
////    
////    if (boundsCheck) {
////        patch_1[l_z][l_y][l_x] = buffer[offset];
////    }
////    
////    barrier(CLK_LOCAL_MEM_FENCE);
////    
////    if (boundsCheck) {
////        
////        float sum = 0.0f;
////#pragma unroll
////        for (int i = 0; i < 8; ++i) {
////            sum += patch_1[l_z][l_y][i] * DCTv8matrix[i + 8 * l_x];
////        }
////        
////        patch_2[l_z][l_x][l_y] = sum;
////    }
////    
////    barrier(CLK_LOCAL_MEM_FENCE);
////    
////    if (boundsCheck) {
////        float sum = 0.0f;
////#pragma unroll
////        for (int i = 0; i < 8; ++i) {
////            sum += patch_2[l_z][l_y][i] * DCTv8matrix[i + 8 * l_x];
////        }
////        
////        const int offset = get_global_id(1) + get_global_id(0) * G_X + get_global_id(2) * G_X * G_Y;
////        buffer[offset] = sum;
////    }
////}
//
//__kernel void idct_new2(__global float *buffer, int num)
//{
//    int n = get_local_id(1) + get_group_size(0) * get_group_id(0);
//    int y = get_local_id(0);
//    
//    
//    
//    
//    
//}
//
//__kernel void idct_new2(__global float *buffer, int num)
//{
//    const int x = get_local_id(0);
//    
//    const int l_x = x % 8;
//    const int l_y = x / 8;
//    
//    const int offset = get_global_id(0);
//    
//    __local float patch_1[64];
//    __local float patch_2[64];
//    
//    patch_1[x] = buffer[offset];
//    
//    barrier(CLK_LOCAL_MEM_FENCE);
//    
////    float sum = 0.0f;
////#pragma unroll
////    for (int i = 0; i < 8; ++i) {
////        sum += patch_1[l_y * 8 + i] * DCTv8matrix[i + 8 * l_x];
////    }
////    patch_2[l_x * 8 + l_y] = sum;
//    
//    
//    float8 vec1 = vload8(0, patch_1 + l_y * 8);
//    float8 vec2 = vload8(0, DCTv8matrix + l_x * 8);
//    
//    patch_2[l_x * 8 + l_y] = dot(vec1.lo, vec2.lo) + dot(vec1.hi, vec2.hi);
//    
//    barrier(CLK_LOCAL_MEM_FENCE);
//    
//    vec1 = vload8(0, patch_2 + l_y * 8);
//    vec2 = vload8(0, DCTv8matrix + l_x * 8);
//    
//    patch_1[l_x * 8 + l_y] = dot(vec1.lo, vec2.lo) + dot(vec1.hi, vec2.hi);
//    
//    //    sum = 0.0f;
//    //#pragma unroll
//    //    for (int i = 0; i < 8; ++i) {
//    //        sum += patch_2[l_y * 8 + i] * DCTv8matrix[i + 8 * l_x];
//    //    }
//    //    
////    patch_1[l_x * 8 + l_y] = sum;
//    barrier(CLK_LOCAL_MEM_FENCE);
//    
//    buffer[offset] = patch_1[x];
//}
