#include "libscl.h"
#include "cuda.h"
#include "cuda_runtime.h"
#include "cublas.h" 
#include <ctime>

extern "C" {
  #include "cblas.h"
}

using namespace scl;
using namespace std;

int main(int argc, char** argp, char** envp)
{ 
    clock_t time_start, time_stop, sub_time_start, sub_time_stop;

    INT_32BIT seed = 12345;

    const INTEGER arows = 400;
    const INTEGER acols = 10000;
    const INTEGER brows = acols;
    const INTEGER bcols = 400;
    const INTEGER crows = arows;
    const INTEGER ccols = bcols;
    const INTEGER drows = crows;
    const INTEGER dcols = ccols;

    const INTEGER asize = arows*acols;
    const INTEGER bsize = brows*bcols;
    const INTEGER csize = crows*ccols;
    const INTEGER dsize = drows*dcols;

    float* a = new(nothrow) float[asize];
    if (a == 0) error("Error, operator new failed.");

    float* b = new(nothrow) float[bsize];
    if (b == 0) error("Error, operator new failed.");

    float* c = new(nothrow) float[csize];
    if (c == 0) error("Error, operator new failed.");

    float* d = new(nothrow) float[dsize];
    if (d == 0) error("Error, operator new failed.");

    for (INTEGER i=0; i<asize; ++i) a[i] = unsk(seed);
    for (INTEGER i=0; i<bsize; ++i) b[i] = unsk(seed);

    cout << '\n';

    time_start = clock();

    for (INTEGER i=0; i<csize; ++i) c[i] = 0;

    for (INTEGER j=0; j<ccols; ++j) {
      for (INTEGER k=0; k<acols; ++k) {
         for (INTEGER i=0; i<crows; ++i) {
           c[j*crows + i] += a[k*arows+i]*b[j*brows + k];
         }
      }
    }

    time_stop = clock();
    cout<<"crude mult time = "<<time_stop - time_start<< '\n';

    time_start = clock();

    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
      arows, bcols, acols, 1.0, 
      a, arows, b, acols, 0.0, 
      c, arows);

    time_stop = clock();
    cout<<"cblas_sgemm time = "<<time_stop - time_start<< '\n';

    for (INTEGER i=0; i<csize; ++i) d[i] = c[i];

    cublasStatus stat; 

    time_start = clock();

    cublasInit(); 

    time_stop = clock();
    cout<<"cublasInit time = "<<time_stop - time_start<< '\n';

    time_start = clock();

    float* devPtrA; 

    stat = cublasAlloc(asize, sizeof(a[0]), (void**)(&devPtrA));
    if (stat != CUBLAS_STATUS_SUCCESS) { 
        printf ("device memory allocation failed"); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    stat = cublasSetMatrix(arows,acols,sizeof(a[0]),a,arows,devPtrA,arows);
    if (stat != CUBLAS_STATUS_SUCCESS) { 
        printf ("data download failed"); 
        cublasFree (devPtrA); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    float* devPtrB; 

    stat = cublasAlloc(bsize, sizeof(b[0]), (void**)(&devPtrB));
    if (stat != CUBLAS_STATUS_SUCCESS) { 
        printf ("device memory allocation failed"); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    stat = cublasSetMatrix(brows,bcols,sizeof(b[0]),b,brows,devPtrB,brows);
    if (stat != CUBLAS_STATUS_SUCCESS) { 
        printf ("data download failed"); 
        cublasFree (devPtrB); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    float* devPtrC; 

    stat = cublasAlloc (csize, sizeof(c[0]), (void**)(&devPtrC));
    if (stat != CUBLAS_STATUS_SUCCESS) { 
        printf ("device memory allocation failed"); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    sub_time_start = clock();

    cublasSgemm ('n', 'n', arows, bcols, acols, 1.0,
       devPtrA, arows, devPtrB, brows, 0.0,
       devPtrC, arows);

  cudaError_t err = cudaThreadSynchronize();
    if (err != cudaSuccess) error("Error, cudaThreadSynchronize failed");

    sub_time_stop = clock();
    cout << "cublasSgemm time = "<< sub_time_stop - sub_time_start << '\n';

    if (cublasGetError() != CUBLAS_STATUS_SUCCESS) { 
        printf ("cublasDgemm failed"); 
        cublasFree (devPtrA); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    stat = cublasGetMatrix(crows,ccols,sizeof(c[0]),devPtrC,crows,c,crows);
    if (stat != CUBLAS_STATUS_SUCCESS) { 
        printf ("data upload failed"); 
        cublasFree (devPtrC); 
        cublasShutdown(); 
        return EXIT_FAILURE; 
    } 

    cublasFree (devPtrA); 
    cublasShutdown(); 

    time_stop = clock();

    cout<<"cublas copy overhead = "
        << time_stop - time_start - sub_time_stop + sub_time_start << '\n';

    cout << '\n';
    cout << "arows = " << arows << ", acols = " << acols 
         << ", bcols = " << bcols << '\n';

    /*
    for (INTEGER i=0; i<csize; ++i) {
      if (i % arows == 0) cout << '\n';
      cout << d[i] - c[i] << ' ';
    }
    */

    cout << '\n';

    /*
    float x;
    double y;
    cout << sizeof(x) << ' ' << sizeof(y) << ' ' << sizeof(*devPtrB) << '\n';
    */

    return EXIT_SUCCESS; 
} 

