/*
 * Test harness for exercising candidate DCT/IDCT implementations.
 *
 */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/* Pi: angle scale for the cosine terms in the floating-point transforms. */
static const double PI = 3.14159265358979323;

/***************************************************************************/

/* System dependencies: high-quality random numbers */

/* If your system supports a higher-quality random number generator,
   plug that in here */
/* Seed the C library generator from the current calendar time. */
void InitRandom() {
  unsigned int seed = (unsigned int)time(NULL);
  srand(seed);
}
/* Next pseudo-random value, taken directly from rand() (0..RAND_MAX). */
int Random() {
  int value = rand();
  return value;
}


/* Timing primitives:
 * timestamp() returns a TIME_T value which somehow indicates the current time
 * timeelapsed() accepts a TIME_T value and returns a double indicating
 *   the number of elapsed seconds
 */

/*
 * Define ONLY ONE of the following:
 */

#define BSD_TIMES (1)  /* 1 => This system supports BSD-style times() */
#define ANSI_TIMES (0) /* 1 => This system supports ANSI time()/difftime() */


/* BSD (and POSIX) systems provide a times() call which reports CPU time
 * usage for the current process.  Historically, this was accurate to the
 * nearest 1/128 of a second.  Some newer systems may provide
 * millisecond resolution.
 *
 * times() counts in clock ticks.  The number of ticks per second is not
 * a compile-time constant: it must be queried with sysconf(_SC_CLK_TCK).
 * (It is NOT CLOCKS_PER_SEC, which is the unit of clock().)
 *
 * timestamp()   returns the user CPU time consumed so far, in seconds.
 * timeelapsed() returns the user CPU seconds consumed since `last`,
 *               where `last` is a value previously returned by timestamp().
 */
#if BSD_TIMES
#include <sys/times.h>
#include <unistd.h>
#define CLOCK_TCK (sysconf(_SC_CLK_TCK))
#define TIME_T double
TIME_T timestamp() {
  struct tms timeInfo;
  times(&timeInfo);
  /* Only user-mode CPU time is counted; system time is ignored. */
  return timeInfo.tms_utime/(double)CLOCK_TCK;
}
double timeelapsed(TIME_T last) {
  return timestamp()-last;
}
#endif

/* On most systems, the ANSI-standard time()/difftime() functions
 * only provide timing to the nearest whole second.
 * If you use this, you may need to use very large iteration
 * counts (possibly hundreds of millions) to accurately
 * determine the speed of a function.
 */
#if ANSI_TIMES
#define TIME_T time_t
/* Record the current calendar time as a starting mark. */
TIME_T timestamp() {
  TIME_T now = time(NULL);
  return now;
}
/* Seconds between `last` (a value from timestamp()) and now,
 * as computed by difftime(). */
double timeelapsed(TIME_T last) {
  TIME_T now = time(NULL);
  return difftime(now, last);
}
#endif

/***************************************************************************/

/*
 * Reference forward DCT for one 8x8 block, evaluated straight from the
 * defining double sum (no factorisation, no fixed-point arithmetic).
 * Very accurate, very slow.
 *
 * data is transformed in place.  Each output is the cosine-weighted sum
 * over all 64 inputs, divided by sqrt(2) once when the column index is 0
 * and once more when the row index is 0, then divided by 16 and rounded
 * with floor(v + 0.5).
 */
static void
dct2dReference(int (*data)[8]) {
  double sum[8][8] = {{0}};
  int u, v;   /* output column / output row */
  int r, c;   /* input row / input column */

  /* Pass 1: raw cosine sums.  data is only read here, so every sum
     sees the untouched input block. */
  for (v = 0; v < 8; v++) {
    for (u = 0; u < 8; u++) {
      for (r = 0; r < 8; r++) {
        for (c = 0; c < 8; c++) {
          double horiz = cos(PI * u * (2*c+1)/16.0);
          double vert = cos(PI * v * (2*r+1)/16.0);
          sum[v][u] += data[r][c] * horiz * vert;
        }
      }
    }
  }

  /* Pass 2: normalise, round, and overwrite the caller's block. */
  for (v = 0; v < 8; v++) {
    for (u = 0; u < 8; u++) {
      double coeff = sum[v][u];
      if (u == 0) {
        coeff /= sqrt(2);
      }
      if (v == 0) {
        coeff /= sqrt(2);
      }
      data[v][u] = (int)floor(coeff/16 + 0.5);
    }
  }
}

/***************************************************************************/

/*
 * 2-d forward DCT built from two passes of a direct 1-D DCT:
 * first along every row, then along every column of the row results.
 *
 * data is transformed in place.  Each 1-D pass divides its sums by 4
 * and additionally by sqrt(2) for output index 0; the final values are
 * rounded with floor(v + 0.5).
 */
static void
dct2dSeparable(int (*data)[8]) {
  double work[8][8];
  int r, c;   /* row / column being processed */
  int k, n;   /* 1-D output index / 1-D input index */

  /* Row pass: data -> work (kept in floating point). */
  for (r = 0; r < 8; r++) {
    for (k = 0; k < 8; k++) {
      double acc = 0;
      for (n = 0; n < 8; n++) {
        acc += data[r][n] * cos(PI * k * (2*n+1)/16.0);
      }
      acc /= 4.0; /* typical weighting */
      if (k == 0) {
        acc /= sqrt(2.0);
      }
      work[r][k] = acc;
    }
  }

  /* Column pass: work -> data, rounding on the way out. */
  for (c = 0; c < 8; c++) {
    for (k = 0; k < 8; k++) {
      double acc = 0;
      for (n = 0; n < 8; n++) {
        acc += work[n][c] * cos(PI * k * (2*n+1)/16.0);
      }
      if (k == 0) {
        acc /= sqrt(2.0);
      }
      acc /= 4.0;
      data[k][c] = (int)floor(acc + 0.5);
    }
  }
}

/***************************************************************************/

/*
 * Candidate fast forward DCT (integer, fixed point), transforming an
 * 8x8 block in place: one pass over the rows, then one over the columns.
 *
 * From Figure 1 of Loeffler, Ligtenberg, and Moschytz (LL&M).
 * ("Practical Fast 1-D DCT Algorithms with 11 Multiplications,"
 * Acoustics, Speech, and Signal Processing, 1989. ICASSP-89, 1989.
 * pp 988-991.)
 *
 * Note that the 1-D DCT algorithm in LL&M results in the output
 * scaled by 4*sqrt(2) (i.e., 2 1/2 bits).  After two passes,
 * I need to scale the output by 32 (>>5).
 *
 * The code right-shifts values that may be negative; the result of that
 * is implementation-defined in C, and the rounding here relies on it
 * being an arithmetic (sign-preserving) shift.
 */
static void
dct2dTest(int (*dctBlock)[8]) {
  /* Rotation constants, pre-multiplied by 2^10 (see per-constant notes).
   * NOTE(review): sqrt(2)*sin(6pi/16)*1024 is about 1337.96, so r2s6=1337
   * is truncated rather than rounded to nearest -- confirm this is intended. */
  static const int c1=1004 /*cos(pi/16)<<10*/, s1=200 /*sin(pi/16)<<10*/;
  static const int c3=851 /*cos(3pi/16)<<10*/, s3=569 /*sin(3pi/16)<<10*/;
  static const int r2c6=554 /*sqrt(2)*cos(6pi/16)<<10*/, r2s6=1337;
  static const int r2=181; /* sqrt(2)<<7 */
  int row,col;

  /* Row pass: each row is replaced by its 1-D transform.  Only the 2^10
   * (and, for outputs 3 and 5, the extra 2^7 from r2) constant scaling is
   * shifted out here; no rounding is applied in this pass. */
  for(row=0;row<8;row++) {
    int x0=dctBlock[row][0], x1=dctBlock[row][1], x2=dctBlock[row][2],
      x3=dctBlock[row][3], x4=dctBlock[row][4], x5=dctBlock[row][5],
      x6=dctBlock[row][6], x7=dctBlock[row][7], x8;

    /* Stage 1 */
    /* Sums and differences of mirrored inputs (0/7, 1/6, 2/5, 3/4). */
    x8=x7+x0; x0-=x7; x7=x1+x6; x1-=x6; x6=x2+x5; x2-=x5; x5=x3+x4; x3-=x4;

    /* Stage 2 */
    x4=x8+x5; x8-=x5; x5=x7+x6; x7-=x6;
    /* Two rotations, each done with three multiplies; expanded:
     *   x2 <- c1*x1 - s1*x2,  x1 <- s1*x1 + c1*x2
     *   x3 <- c3*x0 - s3*x3,  x0 <- s3*x0 + c3*x3
     * x6 is a scratch value holding the shared product. */
    x6=c1*(x1+x2); x2=(-s1-c1)*x2+x6; x1=(s1-c1)*x1+x6;
    x6=c3*(x0+x3); x3=(-s3-c3)*x3+x6; x0=(s3-c3)*x0+x6;

    /* Stage 3 */
    x6=x4+x5; x4-=x5; x5=x0+x2;x0-=x2; x2=x3+x1; x3-=x1;
    /* Same three-multiply rotation, here by the sqrt(2)-scaled pair:
     *   x7 <- r2c6*x8 - r2s6*x7,  x8 <- r2s6*x8 + r2c6*x7
     * x1 is the scratch value. */
    x1=r2c6*(x7+x8); x7=(-r2s6-r2c6)*x7+x1; x8=(r2s6-r2c6)*x8+x1;

    /* Stage 4 and output */
    /* Outputs 0 and 4 never met a 2^10 constant, so they are stored as-is. */
    dctBlock[row][0]=x6;  dctBlock[row][4]=x4;
    dctBlock[row][2]=x8>>10; dctBlock[row][6] = x7>>10;
    dctBlock[row][7]=(x2-x5)>>10; dctBlock[row][1]=(x2+x5)>>10;
    /* >>17 removes both the 2^10 rotation scale and the 2^7 scale of r2. */
    dctBlock[row][3]=(x3*r2)>>17; dctBlock[row][5]=(x0*r2)>>17;
  }

  /* Column pass: same butterfly down each column.  The output shifts are
   * 5 bits larger than in the row pass (the overall >>5 mentioned in the
   * header comment), and half of the divisor is added first so the shift
   * rounds instead of truncating. */
  for(col=0;col<8;col++) {
    int x0=dctBlock[0][col], x1=dctBlock[1][col], x2=dctBlock[2][col],
      x3=dctBlock[3][col], x4=dctBlock[4][col], x5=dctBlock[5][col],
      x6=dctBlock[6][col], x7=dctBlock[7][col], x8;

    /* Stage 1 */
    x8=x7+x0; x0-=x7; x7=x1+x6; x1-=x6; x6=x2+x5; x2-=x5; x5=x3+x4; x3-=x4;

    /* Stage 2 */
    x4=x8+x5; x8-=x5; x5=x7+x6; x7-=x6;
    x6=c1*(x1+x2); x2=(-s1-c1)*x2+x6; x1=(s1-c1)*x1+x6;
    x6=c3*(x0+x3); x3=(-s3-c3)*x3+x6; x0=(s3-c3)*x0+x6;

    /* Stage 3 */
    x6=x4+x5; x4-=x5; x5=x0+x2;x0-=x2; x2=x3+x1; x3-=x1;
    x1=r2c6*(x7+x8); x7=(-r2s6-r2c6)*x7+x1; x8=(r2s6-r2c6)*x8+x1;

    /* Stage 4 and output */
    /* +16 = half of 2^5, +16384 = half of 2^15: round to nearest. */
    dctBlock[0][col]=(x6+16)>>5;  dctBlock[4][col]=(x4+16)>>5;
    dctBlock[2][col]=(x8+16384)>>15; dctBlock[6][col] = (x7+16384)>>15;
    dctBlock[7][col]=(x2-x5+16384)>>15; dctBlock[1][col]=(x2+x5+16384)>>15;
    /* Total shift is 8+14 = 22 bits (2^10 rotation, 2^7 for r2, 2^5 final).
     * The 8-bit pre-shift is presumably there to keep the product with r2
     * inside int range -- TODO confirm. */
    dctBlock[3][col]=((x3>>8)*r2+8192)>>14;
    dctBlock[5][col]=((x0>>8)*r2+8192)>>14;
  }
}

/***************************************************************************/

/*
 * Reference inverse DCT for one 8x8 block, evaluated straight from the
 * defining double sum.
 * Very accurate, very slow.
 *
 * data is transformed in place.  Every input coefficient contributes a
 * product of two cosines; contributions from row index 0 and from column
 * index 0 are each divided by sqrt(2).  The sum is divided by 4 and
 * rounded with floor(v + 0.5).
 */
static void
idct2dReference(int (*data)[8]) {
  double sum[8][8];
  int px, py;   /* output column / output row */
  int fu, fv;   /* input column / input row */

  /* Pass 1: accumulate every output from the untouched input block. */
  for (py = 0; py < 8; py++) {
    for (px = 0; px < 8; px++) {
      double acc = 0.0;
      for (fv = 0; fv < 8; fv++) {
        for (fu = 0; fu < 8; fu++) {
          double term = data[fv][fu]
            * cos(PI * fu * (2*px+1)/16.0) * cos(PI * fv * (2*py+1)/16.0);
          if (fv == 0) {
            term /= sqrt(2);
          }
          if (fu == 0) {
            term /= sqrt(2);
          }
          acc += term;
        }
      }
      sum[py][px] = acc;
    }
  }

  /* Pass 2: scale, round accurately, and overwrite the caller's block. */
  for (py = 0; py < 8; py++) {
    for (px = 0; px < 8; px++) {
      double scaled = sum[py][px] / 4.0;
      data[py][px] = (int)floor(scaled + 0.5);
    }
  }
}

/***************************************************************************/

/*
 * 2-d IDCT built from two passes of a direct 1-D inverse transform:
 * first along every row, then along every column of the row results.
 *
 * data is transformed in place.  In each pass the index-0 input is
 * divided by sqrt(2).  The row pass is left unscaled; the column pass
 * divides by 4 and rounds with floor(v + 0.5).
 */
static void
idct2dSeparable(int (*data)[8]) {
  double work[8][8];
  int r, c;   /* row / column being processed */
  int k, n;   /* 1-D output index / 1-D input index */

  /* Row pass: data -> work (kept in floating point). */
  for (r = 0; r < 8; r++) {
    for (k = 0; k < 8; k++) {
      double acc = data[r][0] / sqrt(2.0);
      for (n = 1; n < 8; n++) {
        acc += data[r][n] * cos(PI * n * (2*k+1)/16.0);
      }
      work[r][k] = acc;
    }
  }

  /* Column pass: work -> data, scaling and rounding on the way out. */
  for (c = 0; c < 8; c++) {
    for (k = 0; k < 8; k++) {
      double acc = work[0][c] / sqrt(2.0);
      for (n = 1; n < 8; n++) {
        acc += work[n][c] * cos(PI * n * (2*k+1)/16.0);
      }
      acc /= 4.0;
      data[k][c] = (int)floor(acc + 0.5);
    }
  }
}

/***************************************************************************/

/*
 * Candidate fast inverse DCT (integer, fixed point), transforming an
 * 8x8 block in place: one pass over the rows, then one over the columns.
 * The stages run in the reverse order of dct2dTest (Stage 4 down to
 * Stage 1).
 *
 * Input scaling is done with multiplications rather than left shifts:
 * coefficients may be negative, and left-shifting a negative int is
 * undefined behaviour in C.
 *
 * The code still right-shifts values that may be negative; the result of
 * that is implementation-defined, and the rounding here relies on it
 * being an arithmetic (sign-preserving) shift.
 */
static void
idct2dTest(int (*dctBlock)[8]) {
  /* Rotation constants, shared by both passes. */
  static const int c1=251 /*cos(pi/16)<<8*/, s1=50 /*sin(pi/16)<<8*/;
  static const int c3=213 /*cos(3pi/16)<<8*/, s3=142 /*sin(3pi/16)<<8*/;
  static const int r2c6=277 /*cos(6pi/16)*sqrt(2)<<9*/, r2s6=669;
  static const int r2=181; /* sqrt(2)<<7 */
  int row,col;

  for(row=0;row<8;row++) {
    /* Stage 4 */
    /* Inputs 0/4 are scaled by 2^9, 1/7 by 2^7, 3/5 by r2 (sqrt(2)*2^7);
     * 2/6 pick up their 2^9 from the r2c6/r2s6 rotation in Stage 3. */
    int x0=dctBlock[row][0]*512, x1=dctBlock[row][1]*128, x2=dctBlock[row][2],
      x3=dctBlock[row][3]*r2, x4=dctBlock[row][4]*512, x5=dctBlock[row][5]*r2,
      x6=dctBlock[row][6], x7=dctBlock[row][7]*128;
    int x8=x7+x1; x1 -= x7;

    /* Stage 3 */
    x7=x0+x4; x0-=x4; x4=x1+x5; x1-=x5; x5=x3+x8; x8-=x3;
    /* Three-multiply rotation; expanded:
     *   x6 <- r2c6*x2 - r2s6*x6,  x2 <- r2s6*x2 + r2c6*x6 */
    x3=r2c6*(x2+x6);x6=x3+(-r2c6-r2s6)*x6;x2=x3+(-r2c6+r2s6)*x2;

    /* Stage 2 */
    x3=x7+x2; x7-=x2; x2=x0+x6; x0-= x6;
    /* Rotations by c3/s3 and c1/s1; >>6 after the 2^8-scaled multiply
     * brings the 2^7-scaled operands up to the 2^9 scale of the rest. */
    x6=c3*(x4+x5);x5=(x6+(-c3-s3)*x5)>>6;x4=(x6+(-c3+s3)*x4)>>6;
    x6=c1*(x1+x8);x1=(x6+(-c1-s1)*x1)>>6;x8=(x6+(-c1+s1)*x8)>>6;

    /* Stage 1 and output */
    /* +512 = half of 2^10, so the >>10 below rounds to nearest. */
    x7+=512; x2+=512; x0+=512; x3+=512;
    dctBlock[row][0]=(x3+x4)>>10;  dctBlock[row][1]=(x2+x8)>>10;
    dctBlock[row][2]=(x0+x1)>>10;  dctBlock[row][3]=(x7+x5)>>10;
    dctBlock[row][4]=(x7-x5)>>10;  dctBlock[row][5]=(x0-x1)>>10;
    dctBlock[row][6]=(x2-x8)>>10;  dctBlock[row][7]=(x3-x4)>>10;
  }

  for(col=0;col<8;col++) {
    /* Stage 4 */
    int x0=dctBlock[0][col]*512, x1=dctBlock[1][col]*128, x2=dctBlock[2][col],
      x3=((dctBlock[3][col]))*r2, x4=dctBlock[4][col]*512,
      x5=((dctBlock[5][col]))*r2, x6=dctBlock[6][col],
      x7=dctBlock[7][col]*128;
    int x8=x7+x1; x1 -= x7;

    /* Stage 3 */
    x7=x0+x4; x0-=x4; x4=x1+x5; x1-=x5; x5=x3+x8; x8-=x3;
    x3=r2c6*(x2+x6);x6=x3+(-r2c6-r2s6)*x6;x2=x3+(-r2c6+r2s6)*x2;

    /* Stage 2 */
    x3=x7+x2; x7-=x2; x2=x0+x6; x0-= x6;
    /* Unlike the row pass, the >>6 is applied BEFORE the rotation
     * multiplies here, so the rotated values end up 2^8 larger than the
     * pre-shift ones instead of 2^2. */
    x4>>=6;x5>>=6;x1>>=6;x8>>=6;
    x6=c3*(x4+x5);x5=(x6+(-c3-s3)*x5);x4=(x6+(-c3+s3)*x4);
    x6=c1*(x1+x8);x1=(x6+(-c1-s1)*x1);x8=(x6+(-c1+s1)*x8);

    /* Stage 1, rounding and output */
    x7+=1024; x2+=1024;x0+=1024;x3+=1024; /* For correct rounding */
    dctBlock[0][col]=(x3+x4)>>11;  dctBlock[1][col]=(x2+x8)>>11;
    dctBlock[2][col]=(x0+x1)>>11;  dctBlock[3][col]=(x7+x5)>>11;
    dctBlock[4][col]=(x7-x5)>>11;  dctBlock[5][col]=(x0-x1)>>11;
    dctBlock[6][col]=(x2-x8)>>11;  dctBlock[7][col]=(x3-x4)>>11;
  }
}

/***************************************************************************/

/*
 * Accuracy check: run testFunc and referenceFunc on the same random
 * 8x8 blocks and report how often, and by how much, they disagree.
 *
 *   maxIterations       number of random blocks to try
 *   testFunc            transform under test (works in place)
 *   testFuncName        label printed for testFunc
 *   referenceFunc       transform treated as ground truth (works in place)
 *   referenceFuncName   label printed for referenceFunc
 *
 * Prints, per coefficient, the fraction whose absolute error exceeds
 * 0, 1, 2 and 3; the mean square error over all coefficients; and the
 * largest per-block mean square error.  Any block whose total squared
 * error exceeds 100 is reported as a "Bad Example", showing the first
 * column of the input, reference and test blocks.
 *
 * If maxIterations is zero or negative nothing is measured and a short
 * notice is printed instead of dividing by zero.
 */
void test2dAccuracy(int maxIterations,
		    void (*testFunc)(int (*)[8]),
		    char *testFuncName,
		    void (*referenceFunc)(int (*)[8]),
		    char *referenceFuncName) {
  int input[8][8], reference[8][8], test[8][8];
  int iteration;
  /* Tallies are kept as doubles: they are only ever used in floating-point
     ratios, and an int would overflow once 64*maxIterations passes INT_MAX. */
  double totalCoefficients=0; /* Total number of coefficients tested */
  double errorCoefficients[4]={0}; /* # coefficients out of range */
  double squareError=0; /* Total squared error over all coefficients */
  double maxSquareError=0; /* Largest squared error for any block */
  int i,j,k;

  printf("Testing Accuracy: %s (%d iterations, comparing to %s)\n",
         testFuncName,maxIterations,referenceFuncName);

  if(maxIterations <= 0) {
    printf("   No iterations requested; nothing to report\n");
    return;
  }

  for(iteration=0;iteration<maxIterations;iteration++) {

    double thisSquareError = 0.0;

    /* Build random input values in range -128...127 */
    for(i=0;i<8;i++) {
      for(j=0;j<8;j++) {
        int t = Random() & 0xff;
        if(t > 127) t-= 256;
        input[i][j] = t;
      }
    }

    /* Compute reference version (on a private copy of the input) */
    memcpy(reference,input,sizeof(input));
    (*referenceFunc)(reference);

    /* Compute test version (on a private copy of the input) */
    memcpy(test,input,sizeof(input));
    (*testFunc)(test);

    /* Tally absolute errors exceeding 0, 1, 2 and 3 */
    totalCoefficients += 64;
    for(i=0;i<8;i++) {
      for(j=0;j<8;j++) {
        int err = test[i][j] - reference[i][j];
        double err2 = (double)err * (double)err;
        if(err < 0) err = -err;
        for(k=0;k<4;k++)
          if(err > k) errorCoefficients[k] += 1;
        squareError += err2;
        thisSquareError += err2;
      }
    }
    if(thisSquareError > maxSquareError)
      maxSquareError = thisSquareError;
    if(thisSquareError > 100) {
      /* Only the first column (index 0) of each block is shown. */
      printf("Bad Example: mean square error = %f\n",thisSquareError/64);
      printf("Input: ");  for(i=0;i<8;i++) printf("  %4d",input[i][0]);
      printf("\nRef:   ");  for(i=0;i<8;i++) printf("  %4d",reference[i][0]);
      printf("\nTest:  ");  for(i=0;i<8;i++) printf("  %4d",test[i][0]);
      printf("\n\n");
    }
  }

  printf("   Probability of error > 0: %g",
         errorCoefficients[0] / totalCoefficients);
  for(k=1;k<4;k++)
    printf(",  > %d: %g",k,
           errorCoefficients[k] / totalCoefficients);
  printf("\n");
  printf("   Overall mean square error: %f\n", squareError/totalCoefficients);
  printf("   Maximum mean square error: %f\n", maxSquareError / 64);
}

/***************************************************************************/

/*
 * Speed check: time maxIterations calls of testFunc and print the
 * average cost per call in microseconds, labelled with funcName.
 *
 * Since the Random() function might not be infinitely fast,
 * I choose one set of random values for every hundred calls
 * to the test function.  That way, my time measures the function being
 * tested, not the random number generator.
 *
 * The last batch is shortened when maxIterations is not a multiple of
 * one hundred, so exactly maxIterations calls are made and the printed
 * average is divided by the number of calls actually timed.  A zero or
 * negative maxIterations is reported as skipped rather than dividing
 * by zero.
 */

static void
test2dSpeed(int maxIterations, void (*testFunc)(int (*)[8]), char *funcName) {
  int i,j,iterations;
  static const int incr = 100;
  int input[8][8],work[8][8];
  TIME_T start = timestamp();

  printf("   %s: ",funcName); fflush(stdout);
  if(maxIterations <= 0) {
    printf("skipped (no iterations requested)\n");
    return;
  }

  iterations = 0;
  while(iterations < maxIterations) {
    /* Number of calls in this batch: incr, or whatever is left. */
    int batch = maxIterations - iterations;
    if(batch > incr) batch = incr;

    /* Build random input values in range -128...127 */
    for(i=0;i<8;i++) {
      for(j=0;j<8;j++) {
        int t = Random() & 0xff;
        if(t > 127) t-= 256;
        input[i][j] = t;
      }
    }
    /* The transform works in place, so each call gets a fresh copy. */
    for(i=0;i<batch;i++) {
      memcpy(work,input,sizeof(input));
      (*testFunc)(work);
    }
    iterations += batch;
  }
  printf("%f microseconds (based on %d iterations)\n",
         timeelapsed(start)/maxIterations * 1000000, maxIterations);
}

/***************************************************************************/

/*
 * Entry point.  Seeds the random generator, then for the forward DCT
 * and for the IDCT in turn:
 *   - checks the separable floating-point version against the direct
 *     reference version,
 *   - checks the fast integer version against the separable version,
 *   - times all three versions.
 * Takes no command-line options and always returns 0.
 */
int
main(int argc, char **argv) {
  /* The arguments are accepted but deliberately unused. */
  (void)argc;
  (void)argv;

  InitRandom();

  printf("Testing 8x8-Element 2-D Forward DCT Implementation\n\n");
  {
    /* Double-check that Separable and Reference versions agree. */
    test2dAccuracy(100,dct2dSeparable,"dct2dSeparable",
                   dct2dReference,"dct2dReference");
    /* Use faster separable version as reference now */
    test2dAccuracy(5000,dct2dTest,"dct2dTest",
                   dct2dSeparable,"dct2dSeparable");

    printf("Measuring Speed\n");
    test2dSpeed(100,dct2dReference,"2d Reference");
    test2dSpeed(1000,dct2dSeparable,"2d Separable");
    test2dSpeed(100000,dct2dTest,"2d Test");
  }
  printf("\n\nTesting 8x8-Element 2-D IDCT Implementation\n\n");
  {
    /* Double-check that Separable and Reference versions agree. */
    test2dAccuracy(100,idct2dSeparable,"idct2dSeparable",
                   idct2dReference,"idct2dReference");
    /* Use faster separable version as reference now */
    test2dAccuracy(5000,idct2dTest,"idct2dTest",
                   idct2dSeparable,"idct2dSeparable");

    printf("Measuring Speed\n");
    test2dSpeed(100,idct2dReference,"2d Reference");
    test2dSpeed(1000,idct2dSeparable,"2d Separable");
    test2dSpeed(100000,idct2dTest,"2d Test");
  }
  return 0;
}

