
// Cosine factors used by the 8-point DCT butterfly in DCT8(). Each one is
// sqrt(2) * cos(k * pi / 16) for the k given on its line; k = 4 is absent
// because sqrt(2) * cos(4 * pi / 16) == 1.
#define    C_a 1.3870398453221474618216191915664f  //a = sqrt(2) * cos(1 * pi / 16)
#define    C_b 1.3065629648763765278566431734272f  //b = sqrt(2) * cos(2 * pi / 16)
#define    C_c 1.1758756024193587169744671046113f  //c = sqrt(2) * cos(3 * pi / 16)
#define    C_d 0.78569495838710218127789736765722f //d = sqrt(2) * cos(5 * pi / 16)
#define    C_e 0.54119610014619698439972320536639f //e = sqrt(2) * cos(6 * pi / 16)
#define    C_f 0.27589937928294301233595756366937f //f = sqrt(2) * cos(7 * pi / 16)
// Common scale factor multiplied into every output coefficient of DCT8().
#define C_norm 0.35355339059327376220042218105242f //1 / sqrt(8)

// Edge length of one square DCT tile; also the number of samples DCT8() transforms.
#define BLOCK_SIZE 8

// Dimensions of the kernel's __local scratch array: l_Transpose has BLOCK_Y rows
// and BLOCK_X + 1 columns. BLOCK_X / BLOCK_SIZE is the number of tiles that one
// work-group covers along x.
#define BLOCK_X 64
#define BLOCK_Y 8

// In-place 8-point forward DCT (DCT-II with orthonormal scaling).
//
// D points to 8 consecutive floats. On return D[k] holds the k-th frequency
// coefficient of the 8 samples that were stored in D on entry.
//
// All eight inputs are read before any output is written, so overwriting D
// is safe. The even-indexed outputs are built from sums of mirrored sample
// pairs, the odd-indexed outputs from their differences.
inline void DCT8(float *D){
    const float in0 = D[0];
    const float in1 = D[1];
    const float in2 = D[2];
    const float in3 = D[3];
    const float in4 = D[4];
    const float in5 = D[5];
    const float in6 = D[6];
    const float in7 = D[7];

    // Stage 1: sums of mirrored pairs (n, 7 - n).
    const float sum07 = in0 + in7;
    const float sum16 = in1 + in6;
    const float sum25 = in2 + in5;
    const float sum34 = in3 + in4;

    // Stage 1: differences of mirrored pairs. Note the operand order:
    // two of them are (high - low), which is why the odd-output formulas
    // below carry mixed signs.
    const float dif07 = in0 - in7;
    const float dif61 = in6 - in1;
    const float dif25 = in2 - in5;
    const float dif43 = in4 - in3;

    // Stage 2: combine the sums once more for the even outputs.
    const float outerPlus  = sum07 + sum34;
    const float outerMinus = sum07 - sum34;
    const float innerPlus  = sum16 + sum25;
    const float innerMinus = sum16 - sum25;

    // Even-indexed coefficients.
    D[0] = C_norm * (outerPlus + innerPlus);
    D[2] = C_norm * (C_b * outerMinus + C_e * innerMinus);
    D[4] = C_norm * (outerPlus - innerPlus);
    D[6] = C_norm * (C_e * outerMinus - C_b * innerMinus);

    // Odd-indexed coefficients.
    D[1] = C_norm * (C_a * dif07 - C_c * dif61 + C_d * dif25 - C_f * dif43);
    D[3] = C_norm * (C_c * dif07 + C_f * dif61 - C_a * dif25 + C_d * dif43);
    D[5] = C_norm * (C_d * dif07 + C_a * dif61 + C_f * dif25 - C_c * dif43);
    D[7] = C_norm * (C_f * dif07 + C_d * dif61 + C_c * dif25 + C_a * dif43);
}

// 2D 8x8 DCT of overlapping image patches.
//
// For every patch origin (x, y) with x < numTilesW - 7 and y < numTilesH - 7,
// the kernel reads the 8x8 patch whose top-left sample is src[y * numTilesW + x]
// (row stride numTilesW), applies DCT8() along the rows and then along the
// columns, and writes the 64 resulting coefficients to
// dst[(y * numTilesW + x) * 64 ...] as a row-major 8x8 block.
//
// Parameters:
//   src       - input samples, addressed with a row stride of numTilesW floats.
//   dst       - output; 64 floats per patch origin, indexed as described above.
//   numTilesW - used both as the row stride of src and as the x bound.
//   numTilesH - y bound.
//
// Work distribution: 8 consecutive work-items along x cooperate on one patch
// (work-item m loads column m, transforms row m, then transforms column m), so
// one work-group handles BLOCK_X / BLOCK_SIZE patch origins along x. The y patch
// index is get_global_id(1).
//
// NOTE(review): l_Transpose has BLOCK_Y rows and BLOCK_X + 1 columns, so the
// indexing below is only in bounds when get_local_id(1) == 0 and
// get_local_id(0) < BLOCK_X, i.e. a local size of at most (BLOCK_X, 1). The
// kernel does not declare reqd_work_group_size; confirm the host enforces this.
// The extra column is presumably padding against local-memory bank conflicts.
__kernel void DCT8x8(
                     __global float *src,
                     __global float *dst,
                     int numTilesW,
                     int numTilesH)
{
    __local float l_Transpose[BLOCK_Y][BLOCK_X + 1];
    const uint    localX = get_local_id(0);
    const uint    localY = BLOCK_SIZE * get_local_id(1);
    const uint modLocalX = localX & (BLOCK_SIZE - 1);

    // l_V walks down this work-item's column of the patch (stride BLOCK_X + 1);
    // l_H walks along row modLocalX of the same patch (stride 1).
    __local float *l_V = &l_Transpose[localY +         0][localX +         0];
    __local float *l_H = &l_Transpose[localY + modLocalX][localX - modLocalX];

    const uint tile_index_x = (get_group_id(0) * (BLOCK_X / BLOCK_SIZE)) + (localX / BLOCK_SIZE);
    const uint tile_index_y = get_global_id(1);

    // Out-of-range work-items must not return early: every work-item of the
    // work-group has to reach the barriers below. They are masked out instead.
    // The flag is identical for all 8 work-items that share a patch.
    const int active = (int)tile_index_x < (numTilesW - (BLOCK_SIZE - 1)) &&
                       (int)tile_index_y < (numTilesH - (BLOCK_SIZE - 1));

    // Offsets are computed in size_t so that the "* 64" in the output index
    // does not wrap at 32 bits for large inputs. Only evaluated when active,
    // in which case numTilesW is known to be positive.
    size_t patchOrigin = 0;
    if (active) {
        patchOrigin = (size_t)tile_index_y * (size_t)numTilesW + (size_t)tile_index_x;
    }

    float D[BLOCK_SIZE];

    // Load: each work-item copies one column of the patch into local memory.
    if (active) {
        __global const float *in = src + patchOrigin + modLocalX;
        for(uint i = 0; i < BLOCK_SIZE; i++) {
            l_V[i * (BLOCK_X + 1)] = in[(size_t)i * (size_t)numTilesW];
        }
    }

    // The row read below consumes values written by the 7 neighbouring
    // work-items; local-memory writes are only guaranteed visible after a barrier.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Horizontal pass: transform one row of the patch in place.
    if (active) {
        for(uint i = 0; i < BLOCK_SIZE; i++) {
            D[i] = l_H[i];
        }

        DCT8(D);

        for(uint i = 0; i < BLOCK_SIZE; i++) {
            l_H[i] = D[i];
        }
    }

    // The column read below consumes the row results of the neighbouring work-items.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Vertical pass: transform one column and store it to the output block.
    if (active) {
        __global float *out = dst + patchOrigin * (size_t)(BLOCK_SIZE * BLOCK_SIZE) + modLocalX;

        for(uint i = 0; i < BLOCK_SIZE; i++) {
            D[i] = l_V[i * (BLOCK_X + 1)];
        }

        DCT8(D);

        for(uint i = 0; i < BLOCK_SIZE; i++) {
            out[i * BLOCK_SIZE] = D[i];
        }
    }
}

//# define PATCHSIZE 8
//# define PATCH_OFFSET 7
//# define GRP_W 16
//# define GRP_H 8
//
//#define    C_a 1.3870398453221474618216191915664f  //a = sqrt(2) * cos(1 * pi / 16)
//#define    C_b 1.3065629648763765278566431734272f  //b = sqrt(2) * cos(2 * pi / 16)
//#define    C_c 1.1758756024193587169744671046113f  //c = sqrt(2) * cos(3 * pi / 16)
//#define    C_d 0.78569495838710218127789736765722f //d = sqrt(2) * cos(5 * pi / 16)
//#define    C_e 0.54119610014619698439972320536639f //e = sqrt(2) * cos(6 * pi / 16)
//#define    C_f 0.27589937928294301233595756366937f //f = sqrt(2) * cos(7 * pi / 16)
//#define C_norm 0.35355339059327376220042218105242f //1 / sqrt(8)
//
//#define BLOCK_SIZE 8
//
//#define BLOCK_X 64
//#define BLOCK_Y 8
//
//inline void DCT8(float *D){
//    float X07P = D[0] + D[7];
//    float X16P = D[1] + D[6];
//    float X25P = D[2] + D[5];
//    float X34P = D[3] + D[4];
//    
//    float X07M = D[0] - D[7];
//    float X61M = D[6] - D[1];
//    float X25M = D[2] - D[5];
//    float X43M = D[4] - D[3];
//    
//    float X07P34PP = X07P + X34P;
//    float X07P34PM = X07P - X34P;
//    float X16P25PP = X16P + X25P;
//    float X16P25PM = X16P - X25P;
//    
//    D[0] = C_norm * (X07P34PP + X16P25PP);
//    D[2] = C_norm * (C_b * X07P34PM + C_e * X16P25PM);
//    D[4] = C_norm * (X07P34PP - X16P25PP);
//    D[6] = C_norm * (C_e * X07P34PM - C_b * X16P25PM);
//    
//    D[1] = C_norm * (C_a * X07M - C_c * X61M + C_d * X25M - C_f * X43M);
//    D[3] = C_norm * (C_c * X07M + C_f * X61M - C_a * X25M + C_d * X43M);
//    D[5] = C_norm * (C_d * X07M + C_a * X61M + C_f * X25M - C_c * X43M);
//    D[7] = C_norm * (C_f * X07M + C_d * X61M + C_c * X25M + C_a * X43M);
//}
//
//__kernel void DCT8x8(
//                     __global float *src,
//                     __global float *dst,
//                     int numTilesW,
//                     int numTilesH)
//{
//    __local float l_Transpose[BLOCK_Y][BLOCK_X + 1];
//    const uint    localX = get_local_id(0);
//    const uint    localY = BLOCK_SIZE * get_local_id(1);
//    const uint modLocalX = localX & (BLOCK_SIZE - 1);
//    
//    __local float *l_V = &l_Transpose[localY +         0][localX +         0];
//    __local float *l_H = &l_Transpose[localY + modLocalX][localX - modLocalX];
//    
//    const uint tile_index_x = (get_group_id(0) * (BLOCK_X / BLOCK_SIZE)) + (localX / BLOCK_SIZE);
//    const uint tile_index_y = get_global_id(1);
//    
//    if (tile_index_x >= numTilesW - 7 || tile_index_y >= numTilesH - 7) {
//        return;
//    }
//    
//    const uint offset = ((tile_index_y * numTilesW + tile_index_x) * 64) + modLocalX;
//    
//    src += offset;
//    dst += offset;
//    
//    float D[BLOCK_SIZE];
//    
//    for(uint i = 0; i < BLOCK_SIZE; i++) {
//        l_V[i * (BLOCK_X + 1)] = src[i * BLOCK_SIZE];
//    }
//    
//    for(uint i = 0; i < BLOCK_SIZE; i++) {
//        D[i] = l_H[i];
//    }
//    
//    DCT8(D);
//    
//    for(uint i = 0; i < BLOCK_SIZE; i++) {
//        l_H[i] = D[i];
//    }
//    
//    for(uint i = 0; i < BLOCK_SIZE; i++) {
//        D[i] = l_V[i * (BLOCK_X + 1)];
//    }
//    
//    DCT8(D);
//    
//    for(uint i = 0; i < BLOCK_SIZE; i++) {
//        dst[i * BLOCK_SIZE] = D[i];
//    }
//}
//
//__kernel __attribute__((reqd_work_group_size(GRP_W, GRP_H, 1)))
//void DCT2D8x8_new(
//                           global float*     input,
//                           global float*     output,
//                           const int  width,
//                           const int  height
//                           )
//{
//    const int g_x = get_global_id(0);
//    const int g_y = get_global_id(1);
//    const int l_x = get_local_id(0);
//    const int l_y = get_local_id(1);
//    const bool boundsCheck = g_x < width - PATCH_OFFSET && g_y < height - PATCH_OFFSET;
//    const int inputOffet = g_y * width + g_x;
//    const int outputOffet = inputOffet * PATCHSIZE * PATCHSIZE;
//    
//    __local float s_mem[GRP_H + PATCH_OFFSET][GRP_W + PATCH_OFFSET];
//    
//    if (boundsCheck) {
//        const bool needsRightLoad = l_x + PATCH_OFFSET > GRP_W - 1;
//        const bool needsBottomLoad = l_y + PATCH_OFFSET > GRP_H - 1;
//        
//        s_mem[l_y][l_x] = input[inputOffet];
//        
//        if (needsRightLoad) {
//            s_mem[l_y][l_x + PATCH_OFFSET] = input[inputOffet + PATCH_OFFSET];
//        }
//        
//        if (needsBottomLoad) {
//            s_mem[l_y + PATCH_OFFSET][l_x] = input[inputOffet + width * PATCH_OFFSET];
//        }
//       
//        if (needsRightLoad && needsBottomLoad) {
//            s_mem[l_y + PATCH_OFFSET][l_x + PATCH_OFFSET] = input[inputOffet + width * PATCH_OFFSET + PATCH_OFFSET];
//        }
//    }
//    
//    barrier(CLK_LOCAL_MEM_FENCE);
//    
//    if (boundsCheck) {
//        
//        float patch[PATCHSIZE][PATCHSIZE];
//        
//#pragma unroll
//        for (int j = 0; j < PATCHSIZE; ++j) {
//            __local float *D = s_mem[j + l_y] + l_x;
//            
//            const float X07P = D[0] + D[7];
//            const float X16P = D[1] + D[6];
//            const float X25P = D[2] + D[5];
//            const float X34P = D[3] + D[4];
//            
//            const float X07M = D[0] - D[7];
//            const float X61M = D[6] - D[1];
//            const float X25M = D[2] - D[5];
//            const float X43M = D[4] - D[3];
//            
//            const float X07P34PP = X07P + X34P;
//            const float X07P34PM = X07P - X34P;
//            const float X16P25PP = X16P + X25P;
//            const float X16P25PM = X16P - X25P;
//            
//            patch[0][j] = C_norm * (X07P34PP + X16P25PP);
//            patch[2][j] = C_norm * (C_b * X07P34PM + C_e * X16P25PM);
//            patch[4][j] = C_norm * (X07P34PP - X16P25PP);
//            patch[6][j] = C_norm * (C_e * X07P34PM - C_b * X16P25PM);
//            
//            patch[1][j] = C_norm * (C_a * X07M - C_c * X61M + C_d * X25M - C_f * X43M);
//            patch[3][j] = C_norm * (C_c * X07M + C_f * X61M - C_a * X25M + C_d * X43M);
//            patch[5][j] = C_norm * (C_d * X07M + C_a * X61M + C_f * X25M - C_c * X43M);
//            patch[7][j] = C_norm * (C_f * X07M + C_d * X61M + C_c * X25M + C_a * X43M);
//        }
//        
//        float8 res[PATCHSIZE];
//#pragma unroll
//        for (int j = 0; j < PATCHSIZE; ++j) {
//            
//            const float *D = patch[j];
//            
//            const float X07P = D[0] + D[7];
//            const float X16P = D[1] + D[6];
//            const float X25P = D[2] + D[5];
//            const float X34P = D[3] + D[4];
//            
//            const float X07M = D[0] - D[7];
//            const float X61M = D[6] - D[1];
//            const float X25M = D[2] - D[5];
//            const float X43M = D[4] - D[3];
//            
//            const float X07P34PP = X07P + X34P;
//            const float X07P34PM = X07P - X34P;
//            const float X16P25PP = X16P + X25P;
//            const float X16P25PM = X16P - X25P;
//            
//            res[j].s0 = C_norm * (X07P34PP + X16P25PP);
//            res[j].s2 = C_norm * (C_b * X07P34PM + C_e * X16P25PM);
//            res[j].s4 = C_norm * (X07P34PP - X16P25PP);
//            res[j].s6 = C_norm * (C_e * X07P34PM - C_b * X16P25PM);
//            
//            res[j].s1 = C_norm * (C_a * X07M - C_c * X61M + C_d * X25M - C_f * X43M);
//            res[j].s3 = C_norm * (C_c * X07M + C_f * X61M - C_a * X25M + C_d * X43M);
//            res[j].s5 = C_norm * (C_d * X07M + C_a * X61M + C_f * X25M - C_c * X43M);
//            res[j].s7 = C_norm * (C_f * X07M + C_d * X61M + C_c * X25M + C_a * X43M);
//        }
//
//#pragma unroll
//        for (int i = 0; i < PATCHSIZE; ++i) {
//            __global float *dst = output + outputOffet + PATCHSIZE * i;
//#pragma unroll
//            for (int j = 0; j < PATCHSIZE; ++j, ++dst) {
//                *dst = res[j][i];
//            }
//        }
//    }
//}
//
////
////inline void DCT8(float *D, float *O){
////    const float X07P = D[0] + D[7];
////    const float X16P = D[1] + D[6];
////    const float X25P = D[2] + D[5];
////    const float X34P = D[3] + D[4];
////    
////    const float X07M = D[0] - D[7];
////    const float X61M = D[6] - D[1];
////    const float X25M = D[2] - D[5];
////    const float X43M = D[4] - D[3];
////    
////    const float X07P34PP = X07P + X34P;
////    const float X07P34PM = X07P - X34P;
////    const float X16P25PP = X16P + X25P;
////    const float X16P25PM = X16P - X25P;
////    
////    O[0] = C_norm * (X07P34PP + X16P25PP);
////    O[2] = C_norm * (C_b * X07P34PM + C_e * X16P25PM);
////    O[4] = C_norm * (X07P34PP - X16P25PP);
////    O[6] = C_norm * (C_e * X07P34PM - C_b * X16P25PM);
////    
////    O[1] = C_norm * (C_a * X07M - C_c * X61M + C_d * X25M - C_f * X43M);
////    O[3] = C_norm * (C_c * X07M + C_f * X61M - C_a * X25M + C_d * X43M);
////    O[5] = C_norm * (C_d * X07M + C_a * X61M + C_f * X25M - C_c * X43M);
////    O[7] = C_norm * (C_f * X07M + C_d * X61M + C_c * X25M + C_a * X43M);
////}
////
////// 2D DCT of every 8x8 patch of `input`. The 64 coefficients of each patch are written to `output`.
////__kernel void DCT2D8x8(
////                       global float*     input,
////                       global float*     output,
////                       const int  width,
////                       const int  height
////                       )
////{
////    int x = get_global_id(0);
////    int y = get_global_id(1);
////    
////    if (x < width - 7 && y < height - 7) {
////        
////        const int side2 = PATCHSIZE*PATCHSIZE;
////        
////        const int offsetOut = (y*width + x)*side2;
////        
////        float patch[PATCHSIZE][PATCHSIZE];
////        float tmp[PATCHSIZE][PATCHSIZE];
////        
////        for (int j = 0; j < PATCHSIZE; j++) {
////            for (int i = 0; i < PATCHSIZE; i++) {
////                const int offsetIn = (y + j)*width + (x + i);
////                patch[j][i] = input[offsetIn];
////            }
////        }
////        
////        for (int j = 0; j < PATCHSIZE; j++) {
////            DCT8(patch[j], tmp[j]);
////        }
////        
////        for (int j = 0; j < PATCHSIZE; j++) {
////            for (int i = 0; i < PATCHSIZE; i++) {
////                patch[j][i] = tmp[i][j];
////            }
////        }
////        
////        for (int j = 0; j < PATCHSIZE; j++) {
////            DCT8(patch[j], tmp[j]);
////        }
////        
////        for (int j = 0; j < PATCHSIZE; j++) {
////            for (int i = 0; i < PATCHSIZE; i++) {
////                output[offsetOut + j*PATCHSIZE + i] = tmp[i][j];
////            }
////        }
////    }
////}
