/*-----------------------------------------------------------------------------

Copyright (C) 2004, 2006, 2007.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libsnp.h"
#include "libsmm.h"
#include "snpusr.h"
#include "snp.h"
#include "emm.h"
#include "emmusr.h"

using namespace scl;
using namespace libsnp;
using namespace libsmm;
using namespace snp;
using namespace emm;
using namespace std;

namespace {

  snpll nlopt_ll;

  class objective_function : public nleqns_base {
  private:
    INTEGER count;
  public:
    objective_function() : count(0) { }
    void reset_count() { count = 0; }
    INTEGER get_count() { return count; }
    bool get_f(const realmat& rho, realmat& f)
    {
      if (f.nrow()!=1) f.resize(1,1);
      nlopt_ll.set_rho(rho);
      realmat dllwrho;
      REAL log_like = nlopt_ll.log_likehood(dllwrho);
      f[1] = -log_like/nlopt_ll.get_datparms().n;
      ++count;
      return true;
    }
    bool get_F(const realmat& rho, realmat& f, realmat& F)
    {
      if (f.nrow()!=1) f.resize(1,1);
      nlopt_ll.set_rho(rho);
      realmat dllwrho;
      REAL log_like = nlopt_ll.log_likehood(dllwrho);
      f[1] = -log_like/nlopt_ll.get_datparms().n;
      F = -dllwrho/nlopt_ll.get_datparms().n;
      ++count;
      return true;
    }
  };
}

// Builds the EMM objective function from observed data and the SNP
// parmfile whose lines are supplied in obj_pfvec.
//
//   data      observed series; must have M rows and n columns as stated
//             in the parmfile's datparms, otherwise error() is called
//   obj_pfvec lines of the SNP parmfile
//   detail    stream that receives the optional printed report
//
// num_mod_parms, num_mod_funcs, mod_pfvec, mod_alvec and obj_alvec are
// accepted but not used here.
//
// On return tr, ll, iter, toler, rho_infmat and V are set.
emm::emm_objfun::emm_objfun
  (const realmat& data, INTEGER num_mod_parms, INTEGER num_mod_funcs,
   const std::vector<std::string>& mod_pfvec,
   const std::vector<std::string>& mod_alvec,
   const std::vector<std::string>& obj_pfvec,
   const std::vector<std::string>& obj_alvec, 
   std::ostream& detail)
{
  parmfile pf;

  if(!pf.set_parms(obj_pfvec,detail)) {
     if (pf.get_optparms().print) detail.flush();
     error("Error, emm_objfun, cannot read parmfile");
  }

  string msg = "Warning, emm_objfun, ";

  // Iteration limit used later by set_data; values below 15 are raised.
  iter = pf.get_optparms().itmax0; 
  if (iter < 15) {
    warn(msg + string("value for itmax0 too small, iter set to 15"));
    iter = 15;
  }

  // Solution tolerance used later by set_data; must lie in
  // [sqrt(EPS), 0.1], otherwise it is reset to sqrt(EPS).
  toler = pf.get_optparms().toler; 
  if (toler < sqrt(EPS) || toler > 0.1) {
    warn(msg + string("value for toler invalid, reset to root EPS"));
    toler = sqrt(EPS);
  }
    
  realmat Y;
  
  datread_type dr; 
  dr.initialize(pf.get_datparms());
  
  Y = data;
  if((Y.get_rows()!=pf.get_datparms().M)||(Y.get_cols()!=pf.get_datparms().n)){ 
    error("Error, emm_objfun, data rows or cols differ from snp parmfile");
  }

  // NOTE(review): the printed reports select columns 1..12 and n-11..n;
  // behaviour for n < 12 depends on realmat/seq -- confirm.
  if (pf.get_optparms().print) {
    detail << starbox("/First 12 observations//");
    detail << Y("",seq(1,12));
    detail << starbox("/Last 12 observations//");
    detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
  }

  trnfrm tr_tmp(pf.get_datparms(), pf.get_tranparms(), Y);
  tr = tr_tmp;

  // Feed the transformation parameters held by tr back into the parmfile
  // so that the mean and variance printed below come from tr.
  pf.set_tranparms(tr.get_tranparms());
    
  if (pf.get_optparms().print) {
    detail << starbox("/Mean and variance of data//");
    if (pf.get_tranparms().diag) {
      detail << "(Variance has been diagonalized.)\n";
    }
    detail << pf.get_tranparms().mean << pf.get_tranparms().variance;
  }

  // Y stays normalized only; X is normalized and then, depending on
  // squash, passed through the spline (1) or logistic (2) transform.
  realmat X = Y;

  tr.normalize(Y);
  tr.normalize(X);

  if (pf.get_tranparms().squash == 1) {
    tr.spline(X);
  } else if (pf.get_tranparms().squash == 2) {
    tr.logistic(X);
  }

  if (pf.get_optparms().print) {
    detail << starbox("/First 12 normalized observations//");
    detail << Y("",seq(1,12));
    detail << starbox("/Last 12 normalized observations//");
    detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
    if (pf.get_tranparms().squash > 0) {
      // The transformed series is X; Y only carries the normalization.
      detail << starbox("/First 12 transformed observations//");
      detail << X("",seq(1,12));
      detail << starbox("/Last 12 transformed observations//");
      detail << X("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
    }
    detail.flush();
  }

  snpll ll_tmp(pf.get_datparms(),
    pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(),
    pf.get_rfunc(), pf.get_afunc_mask(), pf.get_ufunc_mask(),
    pf.get_rfunc_mask());

  ll = ll_tmp;
  // NOTE(review): ll is handed the addresses of the locals X and Y;
  // set_data and operator() call set_XY again on a copy before evaluating
  // the likelihood -- confirm nothing reads ll's data after this returns.
  ll.set_XY(&X,&Y);

  // Information matrix at the parmfile's rho, divided by n; V is the
  // weighting matrix later used by operator().
  realmat dllwrho;
  ll.log_likehood(dllwrho, rho_infmat);
  rho_infmat = rho_infmat/ll.get_datparms().n;
  V = invpsd(rho_infmat);
}

// Default construction is not a supported way to obtain an emm_objfun:
// the listed members are value-initialized and error() is then called.
emm::emm_objfun::emm_objfun()
  : tr(),
    ll(),
    V(),
    iter(),
    toler()
{
  error("Error, emm_objfun, default constructor got called");
}

// Copy constructor.  Besides the transform, likelihood, weighting matrix
// and optimizer settings, the diagnostic state (rho_infmat, mean_score)
// is copied too, so that write_diagnostics on a copy -- for instance one
// produced by new_objfun -- reports the same values as on the source.
emm::emm_objfun::emm_objfun(const emm_objfun& obj)
: tr(obj.tr), ll(obj.ll), V(obj.V), iter(obj.iter), toler(obj.toler) 
{
  // Assigned in the body rather than the initializer list so the result
  // does not depend on where these members are declared in the class.
  rho_infmat = obj.rho_infmat;
  mean_score = obj.mean_score;
}

// Copy assignment; self-assignment is a no-op.  The diagnostic state
// (rho_infmat, mean_score) is assigned along with the other members so
// that write_diagnostics on the target reports the same values as on obj.
emm::emm_objfun& emm_objfun::operator=(const emm_objfun& obj)
{
  if (this != &obj) {
    tr = obj.tr; ll = obj.ll; V = obj.V; iter = obj.iter; toler = obj.toler;
    rho_infmat = obj.rho_infmat;
    mean_score = obj.mean_score;
  }
  return *this;
}

void emm::emm_objfun::set_data(const realmat& data)
{
  nlopt_ll = ll;

  realmat Y = data;

  nlopt_ll.set_n(Y.get_cols());
  
  tr.normalize(Y);

  realmat X;
  
  if (tr.get_tranparms().squash == 1) {
    X = Y;
    tr.spline(X);
    nlopt_ll.set_XY(&X,&Y);
  }
  else if (tr.get_tranparms().squash == 2) {
    X = Y;
    tr.logistic(X);
    nlopt_ll.set_XY(&X,&Y);
  }
  else {
    nlopt_ll.set_XY(&Y,&Y);
  }
  
  objective_function obj;
  nlopt minimizer(obj);  
  
  minimizer.set_iter_limit(iter);
  minimizer.set_solution_tolerance(toler);

  realmat rho_start = nlopt_ll.get_rho();
  realmat rho_stop  = rho_start; 

  minimizer.minimize(rho_start, rho_stop);

  ll.set_rho(rho_stop);
  nlopt_ll.set_rho(rho_stop);

  realmat dllwrho;
  nlopt_ll.log_likehood(dllwrho, rho_infmat);
  rho_infmat = rho_infmat/nlopt_ll.get_datparms().n;

  #if defined GNU_GPP_COMPILER
    bool bad_sim = false;
    for (INTEGER i=1; i<=rho_infmat.size(); ++i) {
        if (!finite(rho_infmat[i])) bad_sim = true;
    }
    if (bad_sim) { 
      warn("Warning, emm_objfun, set_data, bad simulation");
      V = inv(rho_infmat);
    }
    else {
      V = invpsd(rho_infmat);
    }  
  #else
    V = invpsd(rho_infmat);
  #endif
}

// Evaluates the EMM criterion for a simulated series.
//
// The file-local nlopt_ll is overwritten with a copy of ll (so the score
// is taken at ll's current rho) and pointed at the normalized and, if
// squash is 1 or 2, transformed simulation.  The mean score is kept in
// mean_score for write_diagnostics.  The value returned is
// n * mean_score * V * mean_score', with n taken from ll's datparms.
//
// The arguments rho and stats are not used.
REAL emm::emm_objfun::operator()
  (const realmat& rho, const realmat& sim, const realmat& stats) const
{
  nlopt_ll = ll;

  realmat Y = sim;
  nlopt_ll.set_n(Y.get_cols());
  tr.normalize(Y);

  // X is only filled in when a squashing transform is requested.
  realmat X;
  const bool use_spline   = (tr.get_tranparms().squash == 1);
  const bool use_logistic = !use_spline && (tr.get_tranparms().squash == 2);

  if (use_spline || use_logistic) {
    X = Y;
    if (use_spline) {
      tr.spline(X);
    }
    else {
      tr.logistic(X);
    }
    nlopt_ll.set_XY(&X,&Y);
  }
  else {
    nlopt_ll.set_XY(&Y,&Y);
  }

  realmat score;
  nlopt_ll.log_likehood(score);
  mean_score = score/Y.get_cols();

  realmat quad_form = (mean_score*V)*T(mean_score);

  return quad_form[1]*ll.get_datparms().n;
}  

bool emm::emm_objfun::write_diagnostics(const char* filename) const
{
  ofstream os(filename);
  if (!os) {
     std::string msg("Error, realmat, write_diagnostics, cannot open ");
     error(msg + filename);
  }

  os << "Score diagnostics:\n";

  realmat rho = ll.get_rho();
  realmat theta = ll.get_theta();

  snpden f=ll.get_snpden();
  afunc af=ll.get_afunc();
  ufunc uf=ll.get_ufunc();
  rfunc rf=ll.get_rfunc();

  REAL n = ll.get_datparms().n;

  realmat se(theta.size(),1,0.0);
  realmat tv(theta.size(),1,0.0);
  realmat sc(theta.size(),1,0.0);

  vector< std::pair<INTEGER,INTEGER> > rt = ll.get_rt();

  vector< pair<INTEGER,INTEGER> >::const_iterator rt_ptr;
  for (rt_ptr = rt.begin(); rt_ptr != rt.end(); ++rt_ptr) {
    se[rt_ptr->second] = sqrt(rho_infmat(rt_ptr->first,rt_ptr->first));
    sc[rt_ptr->second] = sqrt(n)*mean_score[rt_ptr->first];
  }
  for (INTEGER i=1; i<=theta.size(); ++i) {
    tv[i] = (se[i]==0.0 ? 0.0 : sc[i]/se[i]);
  }

  os << "           normalized        standard\n";
  os << "Index      mean score           error     t-statistic    descriptor\n";
  INTEGER count=0;
  for (INTEGER i=1; i<=af.get_nrowa0(); ++i) {
    ++count;
    os << fmt('d',5,count) << ' '
       << fmt('f',15,5,sc[count]) << ' '
       << fmt('f',15,5,se[count]) << ' '
       << fmt('f',15,5,tv[count])
       << "    a0["<<i<<"]   ";
    for(INTEGER k=1; k<=f.get_ly(); ++k) {
      os << af.get_alpha()[af.get_odx()[i]-1][k];
    }
    os << '\n';
  }
  for (INTEGER j=1; j<=af.get_ncolA(); ++j) {
    for (INTEGER i=1; i<=af.get_nrowA(); ++i) {
      ++count;
      os << fmt('d',5,count) << ' '
	 << fmt('f',15,5,sc[count]) << ' ' 
	 << fmt('f',15,5,se[count]) << ' '
	 << fmt('f',15,5,tv[count])
         << "    A("<<i<<','<<j<<")  ";
      for(INTEGER k=1; k<=f.get_ly(); ++k) {
        os << af.get_alpha()[af.get_idx()[i]-1][k];
      }
      os << " ";
      for(INTEGER k=1; k<=af.get_lx()*af.get_lag(); ++k) {
        os << af.get_beta()[j-1][k];
      }
      os << '\n';
    }
  }
  if (uf.is_intercept()) {
    for (INTEGER i=1; i<=uf.get_ly(); ++i) {
      ++count;
      os << fmt('d',5,count) << ' '
	 << fmt('f',15,5,sc[count]) << ' ' 
	 << fmt('f',15,5,se[count]) << ' '
	 << fmt('f',15,5,tv[count])
         << "    b0["<<i<<"] ";
      os << '\n';
    }
  }
  if (uf.is_regression()) {
    for (INTEGER j=1; j<=uf.get_lx()*uf.get_lag(); ++j) {
      for (INTEGER i=1; i<=uf.get_ly(); ++i) {
	++count;
        os << fmt('d',5,count) << ' '
	   << fmt('f',15,5,sc[count]) << ' ' 
	   << fmt('f',15,5,se[count]) << ' '
	   << fmt('f',15,5,tv[count])
	   << "    B("<<i<<','<<j<<") ";
	os << '\n';
      }
    }
  }
  for (INTEGER i=1; i<=rf.get_lR(); ++i) {
    ++count;
    os << fmt('d',5,count) << ' '
       << fmt('f',15,5,sc[count]) << ' ' 
       << fmt('f',15,5,se[count]) << ' '
       << fmt('f',15,5,tv[count])
       << "    R0["<<i<<"] ";
    os << '\n';
  }
  for (INTEGER j=1; j<=rf.get_colsP(); ++j) {
    for (INTEGER i=1; i<=rf.get_rowsP(); ++i) {
      ++count;
      os << fmt('d',5,count) << ' '
	 << fmt('f',15,5,sc[count]) << ' ' 
	 << fmt('f',15,5,se[count]) << ' '
	 << fmt('f',15,5,tv[count])
         << "    P("<<i<<','<<j<<")  " << rf.get_Ptype();
      os << '\n';
    }
  }
  for (INTEGER j=1; j<=rf.get_colsQ(); ++j) {
    for (INTEGER i=1; i<=rf.get_rowsQ(); ++i) {
      ++count;
      os << fmt('d',5,count) << ' '
	 << fmt('f',15,5,sc[count]) << ' ' 
	 << fmt('f',15,5,se[count]) << ' '
	 << fmt('f',15,5,tv[count])
         << "    Q("<<i<<','<<j<<")  " << rf.get_Qtype();
      os << '\n';
    }
  }
  for (INTEGER j=1; j<=rf.get_colsV(); ++j) {
    for (INTEGER i=1; i<=rf.get_rowsV(); ++i) {
      ++count;
      os << fmt('d',5,count) << ' '
	 << fmt('f',15,5,sc[count]) << ' ' 
	 << fmt('f',15,5,se[count]) << ' '
	 << fmt('f',15,5,tv[count])
         << "    V("<<i<<','<<j<<")  " << rf.get_Vtype();
      os << '\n';
    }
  }
  for (INTEGER j=1; j<=rf.get_colsW(); ++j) {
    for (INTEGER i=1; i<=rf.get_rowsW(); ++i) {
      ++count;
      os << fmt('d',5,count) << ' '
	 << fmt('f',15,5,sc[count]) << ' ' 
	 << fmt('f',15,5,se[count]) << ' '
	 << fmt('f',15,5,tv[count])
         << "    W("<<i<<','<<j<<")  " << rf.get_Wtype();
      os << '\n';
    }
  }

  os.flush();
  return os.good();
}

// Returns a heap-allocated copy of this object as an objfun_base pointer.
// Allocation uses the nothrow form of new; a null result is reported via
// error().  The matching release function is delete_objfun.
objfun_base* emm::emm_objfun::new_objfun()
{
  emm_objfun* clone = new(nothrow) emm_objfun(*this);
  if (!clone) {
    error("Error, emm_objfun, operator new failed");
  }
  return clone;
}

// Releases an object previously obtained from new_objfun.
// NOTE(review): deletes through the base pointer, which presumes that
// objfun_base declares a virtual destructor -- confirm in libsmm.
void emm::emm_objfun::delete_objfun(objfun_base* objfun_ptr)
{
  delete objfun_ptr;
}

// Builds the MLE objective function by constructing the underlying model
// from the data, the model dimensions, the model parmfile lines, the
// model's additional lines, and the detail stream.  obj_pfvec and
// obj_alvec are accepted but not used.
emm::mle_objfun::mle_objfun
  (const realmat& dat, INTEGER len_mod_parm, INTEGER len_mod_func,
   const std::vector<std::string>& mod_pfvec,
   const std::vector<std::string>& mod_alvec,
   const std::vector<std::string>& obj_pfvec,
   const std::vector<std::string>& obj_alvec, 
   std::ostream& detail)
  : model(dat, len_mod_parm, len_mod_func, mod_pfvec, mod_alvec, detail)
{
}

// Copy constructor; the wrapped model is the only state that is copied.
emm::mle_objfun::mle_objfun(const mle_objfun& obj)
  : model(obj.model)
{
}

// Copy assignment; assigns the wrapped model unless obj is this object.
mle_objfun& emm::mle_objfun::operator=(const mle_objfun& obj)
{
  if (this == &obj) {
    return *this;
  }
  model = obj.model;
  return *this;
}

// Returns the negative of the log density reported by the model's
// likelihood after setting the model's parameter to rho.  The arguments
// sim and stats are not used.
REAL emm::mle_objfun::operator()
  (const realmat& rho, const realmat& sim, const realmat& stats) const
{
  model.set_rho(rho);

  // Outputs of likelihood() that are not needed for the objective value.
  realmat fitted;
  realmat resid;

  return -model.likelihood(fitted,resid).log_den;
}  

// Returns a heap-allocated copy of this object as an objfun_base pointer.
// Allocation uses the nothrow form of new; a null result is reported via
// error().  The matching release function is delete_objfun.
objfun_base* emm::mle_objfun::new_objfun() 
{
  mle_objfun* clone = new(nothrow) mle_objfun(*this);
  if (!clone) {
    error("Error, mle_objfun, operator new failed");
  }
  return clone;
}

// Releases an object previously obtained from new_objfun.
// NOTE(review): deletes through the base pointer, which presumes that
// objfun_base declares a virtual destructor -- confirm in libsmm.
void emm::mle_objfun::delete_objfun(objfun_base* objfun_ptr)
{
  delete objfun_ptr;
}

// Writes the model's residuals to the named file.
//
// The likelihood is evaluated at the model's current parameter setting.
// Returns false if the result is not flagged positive, true if there are
// no residuals to write (no file is opened in that case), and otherwise
// the state of the stream after writing and flushing.  If the file cannot
// be opened error() is called.
bool emm::mle_objfun::write_diagnostics(const char* filename) const
{
  realmat predicted, residuals;

  if (!model.likelihood(predicted,residuals).positive) return false;

  // Nothing to report; leave any existing file untouched.
  if (residuals.size() <= 0) return true;

  ofstream os(filename);
  if (!os) {
     std::string msg("Error, mle_objfun, write_diagnostics, cannot open ");
     error(msg + filename);
  }

  vecwrite(os,residuals);

  os.flush();
  return os.good();
}

