#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <string.h>
#include <cblas_64.h>

/* Short integer aliases used by the signatures below.  Declared with typedef
 * rather than object-like macros so they are real, scoped types; the
 * underlying types are exactly the ones the macros used to expand to. */
typedef unsigned long long int u64;
typedef signed int i32;

/* Exclusive upper bound of the random integers used for the inputs.
 * With values below 100 every product a*x[i] is at most 99*99 = 9801, which
 * a float represents exactly; presumably this is what keeps the exact (==)
 * comparison against the BLAS result valid — confirm before raising it. */
#define MOD 100

/* Name of the compiler building this file; printed at the start of every
 * benchmark line.  __clang__ is tested first because Clang also defines the
 * GCC compatibility macro __GNUC__.
 * NOTE(review): any compiler that is neither Clang nor GCC-compatible is
 * labelled "CompCert" — accurate only if CompCert is the sole other compiler
 * this benchmark is built with. */
#ifdef __clang__
#  define COMP "Clang"
#else
#  ifdef __GNUC__
#    define COMP "GCC"
#  else
#    define COMP "CompCert"
#  endif
#endif

extern void b_saxpy(i32, float, float*, i32, float*, i32);
extern void b_saxpy_assert(i32, float, float*, i32, float*, i32);
extern void b_saxpy_unrolling(i32, float, float*, i32, float*, i32);
extern void b_saxpy_unrolling_assert(i32, float, float*, i32, float*, i32);

/* Writes `value` to stdout as an unsigned decimal number followed by a
 * newline. */
void print_int(u64 value) {
  fprintf(stdout, "%llu\n", value);
}

/*
 * Allocates a vector of N floats and fills it with pseudo-random integer
 * values in [0, MOD), drawn from rand().
 *
 * Returns a heap buffer the caller must free(), or NULL if the allocation
 * fails (nothing is written in that case).
 */
float* rand_vector(CBLAS_INT N) {
  float* x = malloc(sizeof(float) * N);
  if (x == NULL) {
    return NULL;
  }
  for (CBLAS_INT i = 0; i < N; i++) {
    x[i] = (float)(rand() % MOD);
  }
  return x;
}

/*
 * Prints the first N elements of x to stdout in the form "[ e0 e1 ... ]",
 * each formatted with %g.  No trailing newline is written.
 */
void print_vector(CBLAS_INT N, float* x) {
  fputs("[ ", stdout);
  CBLAS_INT k = 0;
  while (k < N) {
    printf("%g ", (double)x[k]);
    ++k;
  }
  fputs("]", stdout);
}

/* Number of times each saxpy kernel is applied inside its timing loop. */
#define ITER 4000000

/*
 * Compares the first n elements of res and res2 for exact equality.
 * On the first mismatch it prints the index and both values, then terminates
 * the process with exit status 1; otherwise it returns normally.
 */
static void check_results(long long n, const float* res, const float* res2) {
  for (long long i = 0; i < n; i++) {
    if (res[i] != res2[i]) {
      /* %lld matches the long long index (the old %d with a 64-bit index was
       * undefined behaviour); the newline keeps the report on its own line. */
      printf("Error: Results are different! res[%lld] = %g, res2[%lld] = %g\n",
             i, (double)res[i], i, (double)res2[i]);
      exit(1);
    }
  }
}

/* Kept as a macro for the existing call sites; each argument is evaluated
 * exactly once and type-checked by check_results. */
#define CHECK(N, res, res2) check_results((N), (res), (res2))

/* Converts a clock() interval into seconds of processor time. */
static double elapsed_seconds(clock_t start, clock_t stop) {
  return (double)(stop - start) / (double)CLOCKS_PER_SEC;
}

/*
 * Benchmarks four hand-written saxpy kernels against cblas_saxpy_64.
 *
 * Every kernel starts from the same random inputs.  Each one is applied ITER
 * times to a fresh copy of sy.  Its final vector is then checked
 * element-wise against the BLAS result with CHECK, which exits with status 1
 * on a mismatch.  Its processor time is printed as a ratio to the BLAS time
 * (above 1.00 means slower than BLAS).
 *
 * Returns 0 on success and 1 if any buffer allocation fails.
 */
int main(void) {
  srand((unsigned)time(NULL));

  CBLAS_INT N = 2000;
  size_t bytes = sizeof(float) * N;
  clock_t t1, t2;
  float a = (float)(rand() % MOD);
  float* sx = rand_vector(N);
  float* sy = rand_vector(N);
  float* res = malloc(bytes);
  float* res2 = malloc(bytes);
  if (sx == NULL || sy == NULL || res == NULL || res2 == NULL) {
    fprintf(stderr, "Error: out of memory\n");
    free(sx);
    free(sy);
    free(res);
    free(res2);
    return 1;
  }

  /* BLAS: its result (res) and time (ref_time) are the reference. */

  memcpy(res, sy, bytes);
  t1 = clock();
  for (int i = 0; i < ITER; i++) {
    cblas_saxpy_64(N, a, sx, 1, res, 1);
  }
  t2 = clock();
  double ref_time = elapsed_seconds(t1, t2);

  /* ORIGINAL */

  memcpy(res2, sy, bytes);
  t1 = clock();
  for (int i = 0; i < ITER; i++) {
    b_saxpy(N, a, sx, 1, res2, 1);
  }
  t2 = clock();
  printf("%8s -- Original:\t\t", COMP);
  CHECK(N, res, res2);
  printf("%.2lf\n", elapsed_seconds(t1, t2) / ref_time);

  /* LOOP UNROLLING */

  memcpy(res2, sy, bytes);
  t1 = clock();
  for (int i = 0; i < ITER; i++) {
    b_saxpy_unrolling(N, a, sx, 1, res2, 1);
  }
  t2 = clock();
  printf("%8s -- Unrolling:\t\t", COMP);
  CHECK(N, res, res2);
  printf("%.2lf\n", elapsed_seconds(t1, t2) / ref_time);

  /* ASSERTIONS ADDED */

  memcpy(res2, sy, bytes);
  t1 = clock();
  for (int i = 0; i < ITER; i++) {
    b_saxpy_assert(N, a, sx, 1, res2, 1);
  }
  t2 = clock();
  printf("%8s -- Assertions added:\t", COMP);
  CHECK(N, res, res2);
  printf("%.2lf\n", elapsed_seconds(t1, t2) / ref_time);

  /* LOOP UNROLLING + ASSERT */

  memcpy(res2, sy, bytes);
  t1 = clock();
  for (int i = 0; i < ITER; i++) {
    b_saxpy_unrolling_assert(N, a, sx, 1, res2, 1);
  }
  t2 = clock();
  printf("%8s -- Unrolling + assert:\t", COMP);
  CHECK(N, res, res2);
  printf("%.2lf\n", elapsed_seconds(t1, t2) / ref_time);

  free(sx);
  free(sy);
  free(res);
  free(res2);
  return 0;
}