#include "libscl.h"
#include <pthread.h>   // Header for pthread
#include <unistd.h>    // Header for sysconf
#include <vector>      // std::vector for per-thread handles and arguments
using namespace scl;

namespace {

  struct arg_type {
    INTEGER col;
    realmat* r_ptr;
    const realmat* a_ptr;
    const realmat* b_ptr;
    arg_type() { }
    arg_type(INTEGER j, realmat* rp, const realmat* ap, const realmat* bp) 
      : col(j), r_ptr(rp), a_ptr(ap), b_ptr(bp) 
      { }
  };

  void* mult(void* arg_ptr)
  {
     arg_type* arg = (arg_type*)(arg_ptr); 
     INTEGER j = arg->col;
     for (INTEGER k=1; k<=(*(arg->a_ptr)).ncol(); ++k) {
       REAL b_kj = (*(arg->b_ptr))(k,j);
       for (INTEGER i=1; i<=(*(arg->a_ptr)).nrow(); ++i) {
         (*(arg->r_ptr))(i,j) += b_kj * (*(arg->a_ptr))(i,k);
       }
     }
     pthread_exit((void*) 0);
  }
}

int main(int argc, char** argp, char** envp)
{
  realmat a,b;
  if(!vecread("a.dat",a) || !vecread("b.dat",b)) error("Read failed");
  if (a.ncol() != b.nrow()) error("Not conformable");
  
  realmat r(a.nrow(),b.ncol(),0.0);

  INTEGER num_threads = b.ncol();

  pthread_t threads[num_threads];
  arg_type  args[num_threads];

  pthread_attr_t attr;
  pthread_attr_init(&attr);
  pthread_attr_setdetachstate(&attr,PTHREAD_CREATE_JOINABLE);

  for (INTEGER t=0; t<num_threads; ++t) {
    args[t] = arg_type(t+1,&r,&a,&b);
    int rc = pthread_create(&threads[t], &attr, &mult, (void*)(&args[t]));
    if (rc) error("Cannot create threads");
  }
  
  pthread_attr_destroy(&attr);

  void* status;
  for (INTEGER t=0; t<num_threads; ++t) {
    int rc = pthread_join(threads[t], &status);
    if (rc) error("Cannot join threads");
  }

  std::cout << a << b << r << '\n';

  pthread_exit(NULL);
}


