/* ----------------------------------------------------------------------------

Copyright (C) 2009.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libsmm.h"
#include "emm.h"

using namespace std;
using namespace scl;
using namespace emm;
using namespace libsmm;

namespace emm {

  /*
  Collects the simulated series held in this object into one matrix.

  mv      -- output; on return each column holds one series, in the order
             given by the returned names (column j <-> names[j-1]).
  returns -- the 14 column labels.

  The last column is the log of price_dividend_ratio.  Its entries 1-3 are
  set to 0.0 instead of being logged; entries 4..nt are logged.
  */
  vector<string> 
  prospect_usrmod_variables::get_prospect_usrmod_variables(realmat& mv)
  {
    realmat log_pd = price_dividend_ratio;
    for (INTEGER t=1; t<=3; ++t) log_pd[t] = 0.0;
    for (INTEGER t=4; t<=nt; ++t) log_pd[t] = log(log_pd[t]);

    const INTEGER ncol = 14;

    const char* const label[ncol] = {
      "log_consumption_growth",
      "log_dividend_growth",
      "consumption_growth",
      "dividend_growth",
      "consumption",
      "dividend",
      "state_variable",
      "marginal_rate_of_substitution",
      "price_dividend_ratio",
      "gross_stock_return",
      "gross_risk_free_rate",
      "geometric_stock_return",
      "geometric_risk_free_rate",
      "log_price_dividend_ratio"
    };

    const realmat* const series[ncol] = {
      &log_consumption_growth,
      &log_dividend_growth,
      &consumption_growth,
      &dividend_growth,
      &consumption,
      &dividend,
      &state_variable,
      &marginal_rate_of_substitution,
      &price_dividend_ratio,
      &gross_stock_return,
      &gross_risk_free_rate,
      &geometric_stock_return,
      &geometric_risk_free_rate,
      &log_pd
    };

    // Glue the series side by side, left to right.
    mv = *series[0];
    for (INTEGER j=1; j<ncol; ++j) mv = cbind(mv, *series[j]);

    return vector<string>(label, label + ncol);
  }

  /*
  Constructor.

  dat          -- data matrix; its row count must equal n_datum.
  len_mod_parm -- not referenced in this constructor.
  len_mod_func -- not referenced in this constructor.
  pfvec        -- not referenced in this constructor.
  alvec        -- ancillary strings from the parmfile.  alvec[0] is skipped;
                  the first 12 characters of alvec[1], alvec[2], alvec[3]
                  are parsed with atoi as n0, n, n1 and passed to the
                  prospect_usrmod_variables constructor.
  detail       -- stream kept in debug and used for diagnostic output.

  Loads a tabulated policy function (price-dividend ratio on a grid of the
  state z) into BHSpdval.  make_sim uses it as the starting value of its
  policy function iteration.  "BHS" presumably refers to Barberis, Huang
  and Santos -- confirm.

  When DEBUG_EMMUSR is defined, also prints the parameters, runs one
  simulation and prints summary statistics.
  */
  prospect_usrmod::prospect_usrmod 
    (const scl::realmat& dat, INTEGER len_mod_parm,
     INTEGER len_mod_func, const std::vector<std::string>& pfvec,
     const std::vector<std::string>& alvec, std::ostream& detail)
  :  debug(&detail), debug_switch(false), data(dat),
     bootstrap_seed(770116), simulation_seed(740726)
  {
    INTEGER n_dat_vars = dat.nrow();
    if (n_dat_vars != n_datum) error("Error, prospect_usrmod constr, bad data");

    // alvec[0] is skipped and n0, n, n1 come from alvec[1..3], so at least
    // four entries are needed before the iterator below is dereferenced.
    if (alvec.size() < 4) error("Error, prospect_usrmod constr, bad alvec");

    vector<string>::const_iterator alvec_ptr = alvec.begin();

    INTEGER n0 = atoi((++alvec_ptr)->substr(0,12).c_str());
    INTEGER n  = atoi((++alvec_ptr)->substr(0,12).c_str());
    INTEGER n1 = atoi((++alvec_ptr)->substr(0,12).c_str());

    mv = prospect_usrmod_variables(n0,n,n1);

    #if defined MONTHLY_FREQUENCY
      realmat zgrid(25,1), pdval(25,1);  // BHS policy function
      zgrid[ 1] =   0.00000000;    pdval[ 1] = 258.72732878;
      zgrid[ 2] =   0.12500000;    pdval[ 2] = 245.87959445;
      zgrid[ 3] =   0.25000000;    pdval[ 3] = 232.53165967;
      zgrid[ 4] =   0.37500000;    pdval[ 4] = 218.70383521;
      zgrid[ 5] =   0.50000000;    pdval[ 5] = 204.45439853;
      zgrid[ 6] =   0.62500000;    pdval[ 6] = 189.87568804;
      zgrid[ 7] =   0.75000000;    pdval[ 7] = 175.18917072;
      zgrid[ 8] =   0.87500000;    pdval[ 8] = 160.81054675;
      zgrid[ 9] =   1.00000000;    pdval[ 9] = 147.29287752;
      zgrid[10] =   1.12500000;    pdval[10] = 135.62185027;
      zgrid[11] =   1.25000000;    pdval[11] = 125.70039475;
      zgrid[12] =   1.37500000;    pdval[12] = 117.16652409;
      zgrid[13] =   1.50000000;    pdval[13] = 109.77985157;
      zgrid[14] =   1.62500000;    pdval[14] = 103.41626585;
      zgrid[15] =   1.75000000;    pdval[15] =  97.79463396;
      zgrid[16] =   1.87500000;    pdval[16] =  92.91188451;
      zgrid[17] =   2.00000000;    pdval[17] =  88.64337357;
      zgrid[18] =   2.12500000;    pdval[18] =  84.87435019;
      zgrid[19] =   2.25000000;    pdval[19] =  81.52632358;
      zgrid[20] =   2.37500000;    pdval[20] =  78.60517301;
      zgrid[21] =   2.50000000;    pdval[21] =  76.05170764;
      zgrid[22] =   2.62500000;    pdval[22] =  73.86498187;
      zgrid[23] =   2.75000000;    pdval[23] =  72.05294919;
      zgrid[24] =   2.87500000;    pdval[24] =  70.66232082;
      zgrid[25] =   3.00000000;    pdval[25] =  69.79166808;
    #else 
      /*
      realmat zgrid(25,1), pdval(25,1);  // BHS policy function k=3, b0=2
      zgrid[ 1] =   0.00000000;    pdval[ 1] =  25.06554158;
      zgrid[ 2] =   0.12500000;    pdval[ 2] =  24.50138586;
      zgrid[ 3] =   0.25000000;    pdval[ 3] =  23.86267408;
      zgrid[ 4] =   0.37500000;    pdval[ 4] =  23.13210994;
      zgrid[ 5] =   0.50000000;    pdval[ 5] =  22.28755163;
      zgrid[ 6] =   0.62500000;    pdval[ 6] =  21.31146215;
      zgrid[ 7] =   0.75000000;    pdval[ 7] =  20.20242733;
      zgrid[ 8] =   0.87500000;    pdval[ 8] =  18.98071047;
      zgrid[ 9] =   1.00000000;    pdval[ 9] =  17.76186000;
      zgrid[10] =   1.12500000;    pdval[10] =  16.88967777;
      zgrid[11] =   1.25000000;    pdval[11] =  16.10988329;
      zgrid[12] =   1.37500000;    pdval[12] =  15.45116426;
      zgrid[13] =   1.50000000;    pdval[13] =  14.86533370;
      zgrid[14] =   1.62500000;    pdval[14] =  14.34379620;
      zgrid[15] =   1.75000000;    pdval[15] =  13.88507505;
      zgrid[16] =   1.87500000;    pdval[16] =  13.47101164;
      zgrid[17] =   2.00000000;    pdval[17] =  13.09749580;
      zgrid[18] =   2.12500000;    pdval[18] =  12.76048742;
      zgrid[19] =   2.25000000;    pdval[19] =  12.45640235;
      zgrid[20] =   2.37500000;    pdval[20] =  12.18462237;
      zgrid[21] =   2.50000000;    pdval[21] =  11.94223687;
      zgrid[22] =   2.62500000;    pdval[22] =  11.72358237;
      zgrid[23] =   2.75000000;    pdval[23] =  11.52811216;
      zgrid[24] =   2.87500000;    pdval[24] =  11.36063130;
      zgrid[25] =   3.00000000;    pdval[25] =  11.23684009;
      */
      realmat zgrid(25,1), pdval(25,1);  // BHS policy function k=10, b0=2
      zgrid[ 1] =   0.00000000;    pdval[ 1] =  24.34592322;
      zgrid[ 2] =   0.16666667;    pdval[ 2] =  23.11133334;
      zgrid[ 3] =   0.33333333;    pdval[ 3] =  21.73186478;
      zgrid[ 4] =   0.50000000;    pdval[ 4] =  20.18263388;
      zgrid[ 5] =   0.66666667;    pdval[ 5] =  18.46095412;
      zgrid[ 6] =   0.83333333;    pdval[ 6] =  16.64923323;
      zgrid[ 7] =   1.00000000;    pdval[ 7] =  14.94249330;
      zgrid[ 8] =   1.16666667;    pdval[ 8] =  13.44889792;
      zgrid[ 9] =   1.33333333;    pdval[ 9] =  12.31454786;
      zgrid[10] =   1.50000000;    pdval[10] =  11.42766074;
      zgrid[11] =   1.66666667;    pdval[11] =  10.72459638;
      zgrid[12] =   1.83333333;    pdval[12] =  10.15292951;
      zgrid[13] =   2.00000000;    pdval[13] =   9.67992618;
      zgrid[14] =   2.16666667;    pdval[14] =   9.28737845;
      zgrid[15] =   2.33333333;    pdval[15] =   8.95317719;
      zgrid[16] =   2.50000000;    pdval[16] =   8.66760745;
      zgrid[17] =   2.66666667;    pdval[17] =   8.42351995;
      zgrid[18] =   2.83333333;    pdval[18] =   8.20979904;
      zgrid[19] =   3.00000000;    pdval[19] =   8.02649042;
      zgrid[20] =   3.16666667;    pdval[20] =   7.86898123;
      zgrid[21] =   3.33333333;    pdval[21] =   7.73017328;
      zgrid[22] =   3.50000000;    pdval[22] =   7.60920214;
      zgrid[23] =   3.66666667;    pdval[23] =   7.50884334;
      zgrid[24] =   3.83333333;    pdval[24] =   7.42234139;
      zgrid[25] =   4.00000000;    pdval[25] =   7.34677307;
    #endif

    // Starting values for the policy function iteration in make_sim.
    BHSpdval.update(zgrid,pdval);
    
    #if defined DEBUG_EMMUSR

      debug_switch = true;

      detail << starbox("/values read from parmfile//");
      detail << '\n';
      detail << "\t\t\t\t n0 = " << n0 << '\n';
      detail << "\t\t\t\t n = " << n << '\n';
      detail << "\t\t\t\t n1 = " << n1 << '\n';
      detail.flush();

      realmat theta;
      mp.get_theta(theta);
      detail << starbox("/theta from mp struct//");
      detail << theta;
      #if defined MONTHLY_FREQUENCY
        detail << starbox("/frequency is monthly/parameter values//");
        detail << '\n';
        detail << "\t\t\t\t" << "g_c     " << mp.g_c    << '\n';
        detail << "\t\t\t\t" << "g_d     " << mp.g_d    << '\n';
        detail << "\t\t\t\t" << "sig_c   " << mp.sig_c  << '\n';
        detail << "\t\t\t\t" << "sig_d   " << mp.sig_d  << '\n';
        detail << "\t\t\t\t" << "omega   " << mp.omega  << '\n';
        detail << "\t\t\t\t" << "gamma   " << mp.gamma  << '\n';
        detail << "\t\t\t\t" << "rho     " << mp.rho    << '\n';
        detail << "\t\t\t\t" << "lambda  " << mp.lambda << '\n';
        detail << "\t\t\t\t" << "k       " << mp.k      << '\n';
        detail << "\t\t\t\t" << "b_0     " << mp.b_0    << '\n';
        detail << "\t\t\t\t" << "eta     " << mp.eta    << '\n';
        detail << '\n';
        detail << starbox("/annualized parameter values//");
        detail << "\t\t\t\t" << "g_c     " << 100.0*12.0*mp.g_c << '\n';
        detail << "\t\t\t\t" << "g_d     " << 100.0*12.0*mp.g_d << '\n';
        detail << "\t\t\t\t" << "sig_c   " << 100.0*sqrt(12)*mp.sig_c << '\n';
        detail << "\t\t\t\t" << "sig_d   " << 100.0*sqrt(12)*mp.sig_d << '\n';
        detail << "\t\t\t\t" << "omega   " << mp.omega << '\n';
        detail << "\t\t\t\t" << "gamma   " << mp.gamma << '\n';
        detail << "\t\t\t\t" << "rho     " << pow(mp.rho,12) << '\n';
        detail << "\t\t\t\t" << "lambda  " << mp.lambda << '\n';
        detail << "\t\t\t\t" << "k       " << mp.k << '\n';
        detail << "\t\t\t\t" << "b_0     " << mp.b_0*12.0 << '\n';
        detail << "\t\t\t\t" << "eta     " << pow(mp.eta,12) << '\n';
      #else
        detail << starbox("/frequency is annual/parameter values//");
        detail << '\n';
        detail << "\t\t\t\t" << "g_c     " << mp.g_c    << '\n';
        detail << "\t\t\t\t" << "g_d     " << mp.g_d    << '\n';
        detail << "\t\t\t\t" << "sig_c   " << mp.sig_c  << '\n';
        detail << "\t\t\t\t" << "sig_d   " << mp.sig_d  << '\n';
        detail << "\t\t\t\t" << "omega   " << mp.omega  << '\n';
        detail << "\t\t\t\t" << "gamma   " << mp.gamma  << '\n';
        detail << "\t\t\t\t" << "rho     " << mp.rho    << '\n';
        detail << "\t\t\t\t" << "lambda  " << mp.lambda << '\n';
        detail << "\t\t\t\t" << "k       " << mp.k      << '\n';
        detail << "\t\t\t\t" << "b_0     " << mp.b_0    << '\n';
        detail << "\t\t\t\t" << "eta     " << mp.eta    << '\n';
      #endif
      detail << '\n';
      detail.flush();

      realmat sim, stats;
      gen_sim(sim,stats);

      realmat variables;
      vector<string> names = mv.get_prospect_usrmod_variables(variables);

      INTEGER r = variables.nrow();
      INTEGER c = variables.ncol();

      realmat sum(c,1);
      
      // Sample means over observations n0+1 .. n0+n of each column.
      fill(sum,0.0);
      for (INTEGER j=1; j<=c; ++j) {
        for (INTEGER i=mv.n0+1; i<=mv.n0+mv.n; ++i) {
          sum[j] += variables(i,j);
        }
      }
      realmat mean = sum/mv.n;

      // Sample variances (divisor n) over the same range.
      fill(sum,0.0);
      for (INTEGER j=1; j<=c; ++j) {
        for (INTEGER i=mv.n0+1; i<=mv.n0+mv.n; ++i) {
          sum[j] += pow(variables(i,j)-mean[j],2);
        }
      }
      realmat var = sum/mv.n;

      detail << starbox("/means and standard errors of model variables//");
      detail << '\n';
      for (INTEGER j=1; j<=c; ++j) {
        // Levels are printed in scientific notation; everything else fixed.
        if (names[j-1] == "consumption") {
          detail << '\t' << names[j-1] << ' ' 
            << fmt('e',45-names[j-1].size(),5,mean[j]) 
            << fmt('e',15,5,sqrt(var[j])) << '\n';
        }
        else if (names[j-1] == "dividend") {
          detail << '\t' << names[j-1] << ' ' 
            << fmt('e',45-names[j-1].size(),5,mean[j]) 
            << fmt('e',15,5,sqrt(var[j])) << '\n';
        }
        else {
          detail << '\t' << names[j-1] << ' ' 
            << fmt('f',45-names[j-1].size(),5,mean[j]) 
            << fmt('f',15,5,sqrt(var[j])) << '\n';
        }
      }
      detail.flush();

      // Row-wise means and variances of the simulated data matrix.
      r = sim.nrow();
      c = sim.ncol();
      sum.resize(r,1);

      fill(sum,0.0);
      for (INTEGER j=1; j<=c; ++j) {
        for (INTEGER i=1; i<=r; ++i) {
          sum[i] += sim(i,j);
        }
      }
      mean = sum/c;

      fill(sum,0.0);
      for (INTEGER j=1; j<=c; ++j) {
        for (INTEGER i=1; i<=r; ++i) {
          sum[i] += pow(sim(i,j)-mean[i],2);
        }
      }
      var = sum/c;

      detail << starbox("/means and standard errors of annual simulation//");
      detail << '\n';
      detail << "\t\t" << "consumption growth  " 
             << fmt('f',9,5,mean[1]) <<  fmt('f',9,5,sqrt(var[1])) << '\n'; 
      detail << "\t\t" << "stock returns       " 
             << fmt('f',9,5,mean[2]) <<  fmt('f',9,5,sqrt(var[2])) << '\n'; 

      debug_switch = false;
    #endif
  }
  
  /*
  Simulates the model at the parameter values held in mp.

  seed    -- random number seed, advanced by the unsk draws below.
  sim     -- output; resized to n_datum rows.  Row 1 is a log
             dividend-minus-consumption measure plus mudc, row 2 log
             consumption growth, row 3 log price-dividend ratio, row 4
             geometric stock return.  With MONTHLY_FREQUENCY the columns
             are 12-month aggregates (mv.n/12 of them); otherwise one
             column per period (mv.n of them).  Both use observations
             after the first mv.n0.
  returns -- true if the policy function iteration converged.

  Steps:
   1. Draw mv.nt shock pairs with unsk, correlate them through R
      (correlation omega), and fill mv's growth and level series.
   2. Build an npoints quadrature rule with hquad.  Nodes are scaled by
      sqrt(2) and weights by 1/sqrt(pi), the Gauss-Hermite form of a
      normal expectation (presumably what hquad returns -- confirm).
   3. Solve for the price-dividend ratio on a grid of the state z by damped
      policy function iteration, starting from BHSpdval.
   4. Iterate z forward and record prices and returns in mv.
   5. Pack the sample into sim.
  */
  bool prospect_usrmod::make_sim(INT_32BIT& seed, realmat& sim)
  {
    realmat theta(mp.p,1);
    get_rho(theta);

    #if defined DEBUG_EMMUSR
    if (debug_switch) {
      *debug << starbox("/Entered make_sim with these parameter values//");
      *debug << theta << '\n';
      *debug << boolalpha;
      debug->flush();
    }
    #endif

    // R maps independent shocks (e1,e2) to (sqrt(1-omega^2)e1 + omega e2, e2).
    realmat R(2,2);
    R(1,1) = sqrt(1.0 - pow(mp.omega,2));
    R(2,1) = 0.0;
    R(1,2) = mp.omega;
    R(2,2) = 1.0;

    REAL del_c;
    REAL del_d;
    
    REAL cg;
    REAL dg;
    
    realmat e(2,1);
    
    // Levels start at REAL_MIN and are built up multiplicatively.
    REAL C = REAL_MIN;
    REAL D = REAL_MIN;

    for (INTEGER t=1; t <=mv.nt; t++) {
  
      e[1] = unsk(seed);
      e[2] = unsk(seed);
  
      e = R*e;
   
      del_c = mp.g_c + mp.sig_c*e[1];
      del_d = mp.g_d + mp.sig_d*e[2];
  
      cg = exp(del_c);
      dg = exp(del_d);
    
      C *= cg;
      D *= dg;
  
      mv.log_consumption_growth[t] = del_c;
      mv.log_dividend_growth[t] = del_d;

      mv.consumption_growth[t] = cg;
      mv.dividend_growth[t] = dg;

      mv.consumption[t] = C;
      mv.dividend[t] = D;
    }

    if (!IsFinite(C)) error("Error, usrmod::make_sim, C not finite");
    if (!IsFinite(D)) error("Error, usrmod::make_sim, D not finite");
  
    const INTEGER npoints = 9;
    realmat x, w;
    if( hquad(npoints,x,w) ) error("Error, usrmod, make_sim, hquad failed");
    
    const REAL root2 = sqrt(2.0);              // 1.4142135623730951
    const REAL rootpi = sqrt(4.0*atan(1.0));   // 1.7724538509055161
    
    // Quadrature values of dividend growth (dgval), of the dividend-shock
    // factor entering the Euler equation (dmval), and their probabilities.
    realmat dgval(npoints,1);
    realmat dmval(npoints,1);
    realmat dgprob(npoints,1);
    for (INTEGER i=1; i<=npoints; ++i) {
      dgval[i] = exp(mp.g_d + mp.sig_d*root2*x[i]);
      dmval[i] = exp((mp.sig_d - mp.gamma*mp.omega*mp.sig_c)*root2*x[i]);
      dgprob[i] = w[i]/rootpi;
    }
    
    #if defined DEBUG_EMMUSR
    if (debug_switch) {
      *debug << starbox("/quadrature points/x, w//");
      *debug << cbind(x,w);
      *debug << starbox("/quadrature points/dgval, dgprob//");
      *debug << cbind(dgval,dgprob);
      debug->flush();
    }
    #endif

    // Residual whose root in z_new is next period's state, given z_old,
    // dividend growth dg_new, benchmark return Rbar and memory parameter eta.
    class rootfunc : public nleqn_base {
    private:
      linear_interpolater f;
      REAL Rbar;
      REAL eta;
      REAL z_old;
      REAL dg_new;
    public:
      rootfunc
        (const linear_interpolater& f0, REAL R0, REAL e0, REAL z0, REAL dg1) 
      {
        update(f0,R0,e0);
        update(z0,dg1);
      }
      void update(const linear_interpolater& f0, REAL R0, REAL e0)
      {
        f = f0; Rbar = R0; eta = e0;
      }
      void update(REAL z0, REAL dg1) 
      {
        z_old = z0; dg_new = dg1;
      }
      REAL operator() (REAL z_new)
      {
        REAL R1 = (1.0 + f(z_new))*dg_new/f(z_old);
        REAL lhs = (z_new - 1.0 + eta)*R1;
        REAL rhs = eta*z_old*Rbar;
        return lhs - rhs;
      }
      REAL get_zmin() { return f.get_xmin();}
      REAL get_zmax() { return f.get_xmax();}
    };

    class hfunc {  // Arguments are (z_t, dg_{t+1}). 
    private:       // This differs from BHS's (z_t, \epsilon_{t+1}).
      rootfunc rf;
    public:
      hfunc(rootfunc rootf) : rf(rootf) { }
      void update(rootfunc rootf) { rf = rootf; }
      REAL operator()(REAL z_old, REAL dg_new)
      { 
        rf.update(z_old,dg_new);
        return nlroot(rf.get_zmin(), rf.get_zmax(), rf, 1.0e-10); 
      }
    };

    // Residual 1 - median(z) over a fixed-seed simulated path of z;
    // its root is the Rbar that makes the median state equal to 1.
    class Rbarfunc : public nleqn_base {
    private:
      REAL g_d;
      REAL sig_d;
      REAL eta;
      linear_interpolater f;
      rootfunc rf;
      hfunc h;
      INTEGER M;
      INTEGER N;
      realmat dg;
      realmat sv;
    public:
      Rbarfunc(REAL d, REAL s, REAL e, 
        linear_interpolater f0, rootfunc rf0, hfunc h0) 
      : g_d(d), sig_d(s), eta(e), f(f0), rf(rf0), h(h0),
        M(100), N(5001), dg(N+M,1), sv(N,1) 
      { 
        INT_32BIT seed = 411029;
        for (INTEGER t=1; t<=N+M; ++t) {
          REAL del_d = g_d + sig_d*unsk(seed);
          dg[t] = exp(del_d);
        }
      }
      void update(linear_interpolater f0, rootfunc rf0, hfunc h0)
      {
        f = f0; rf = rf0; h = h0;
      }
      REAL operator() (REAL Rbar)
      { 
        rf.update(f,Rbar,eta); 
        h.update(rf);
    
        REAL z0 = 1.0;
        REAL z1 = 1.0;
        for (INTEGER t=2; t<=M; ++t) {        // burn-in
          z1 = h(z0,dg[t]);
          z0 = z1;
        }
        for (INTEGER t=1; t<=N; ++t) {
          z1 = h(z0,dg[M+t]);
          sv[t] = z0 = z1;
        }
        sv.sort();
    
        return 1.0 - sv[N/2 + 1];
      }
    };
    
    // Prospect-theory utility of the return R1 relative to the
    // risk-free benchmark, as a function of the state z0.
    class vfunc {
    private:
      REAL Rf;
      REAL lambda;
      REAL k;
    public:
      vfunc (REAL Rf0, REAL lam, REAL k0) : Rf(Rf0), lambda(lam), k(k0) { }
      REAL operator() (REAL R1, REAL z0)
      {
        if (z0 <= 1.0) {
          return R1 >= z0*Rf ? R1 - Rf : lambda*R1 + (z0 - 1.0 - lambda*z0)*Rf;
        }
        else {
          return R1 >= Rf ? R1 - Rf : (lambda + k*(z0 - 1.0))*(R1 - Rf);
        }
      }
    };

    // Euler equation residual in the price-dividend ratio pdv at state z0,
    // with expectations taken over the quadrature points.
    class eulerfunc : public nleqn_base {
    private:
      prospect_usrmod_parameters mp;
      linear_interpolater f;
      hfunc h;
      vfunc v;
      realmat dgv;
      realmat dmv;
      realmat dgp;
      REAL fac1;
      REAL fac2;
      REAL z0;
    public:
      eulerfunc( prospect_usrmod_parameters m, linear_interpolater ff, 
        hfunc hf, vfunc vf, realmat dv, realmat dm, realmat p)
      : mp(m), f(ff), h(hf), v(vf), dgv(dv), dmv(dm), dgp(p), z0(1.0)
      {
        REAL term1 = mp.g_d - mp.gamma*mp.g_c;
        term1 += pow(mp.gamma*mp.sig_c,2)*(1.0 - pow(mp.omega,2))/2.0;
        fac1 = mp.rho*exp(term1);
        fac2 = mp.b_0*mp.rho;
      }
      void update_z0(REAL z) { z0 = z; }
      REAL operator() (REAL pdv)
      {
        INTEGER npoints = dgv.size();
        REAL sum1 = 0.0;
        REAL sum2 = 0.0;
        for (INTEGER i=1; i<=npoints; ++i) {
          REAL z1 = h(z0,dgv[i]);
          REAL ratio = (1.0 + f(z1))/pdv;
          sum1 += ratio*dmv[i]*dgp[i];
          sum2 += v(ratio*dgv[i],z0)*dgp[i];
        }
        sum1 *= fac1;
        sum2 *= fac2;
        return 1.0 - sum1 - sum2;
      }
    };

    // State grid: intervals+1 equally spaced points on [zmin, zmax].
    const REAL zmin = 0.0;
    #if defined MONTHLY_FREQUENCY
      const REAL zmax = 3.5;
    #else
      const REAL zmax = 4.0;
      //const REAL zmax = 3.0;
    #endif
    const INTEGER intervals = 24;
    REAL increment = zmax/intervals;
    realmat zgrid;
    REAL zi = zmin;
    for (INTEGER i=0; i<=intervals; ++i) {
      zgrid.push_back(zi);
      zi += increment;
    }
    INTEGER ngrid = zgrid.size();

    realmat pdval(ngrid,1);

    for (INTEGER i=1; i<=ngrid; ++i) {
      pdval[i] = BHSpdval(zgrid[i]);
    }
    realmat pdval_lag = pdval;

    linear_interpolater f(zgrid,pdval); 

    REAL Rf = exp(mp.gamma*mp.g_c - 0.5*pow(mp.gamma*mp.sig_c,2))/mp.rho;
    REAL dg_median = exp(mp.g_d);

    REAL Rbar = (1.0 + f(1.0))*dg_median/f(1.0);
    REAL Rbar_lag = Rbar;

    const INTEGER maxiter = 2000;
    const REAL tol = 1.0e-3;
    bool converge = false;

    #if defined DEBUG_EMMUSR
    INTEGER plotlimit = 2000;
    INTEGER npts = 50;
    realmat rtmp(npts,1);
    realmat ftmp(npts,1);
    realmat rplot(npts,1);
    realmat fplot(npts,1);
    
    if (debug_switch) {
      REAL increment = (zmax - zmin)/npts;
      REAL z = zmin;
      for (INTEGER i=1; i<= npts; ++i) {
        rplot[i] = z;
        fplot[i] = z;
        z += increment;
      }
    }
    #endif

    #if defined DEBUG_EMMUSR
      string rname = "pro.rootf.dat";
      string fname = "pro.f.dat";
    #endif

    #if defined DEBUG_EMMUSR
    if (debug_switch) {
      string msg = "/Policy function iteration plot data written to//";
      msg += rname + "/" + fname + "//";
      *debug << starbox(msg.c_str());
      *debug << starbox("/Policy function iteration statistics//") << '\n';
    }
    #endif

    rootfunc rootf(f,Rbar,mp.eta,1.0,dg_median);
    hfunc h(rootf);
    Rbarfunc Rbarf(mp.g_d, mp.sig_d, mp.eta, f, rootf, h);

    // Damped policy function iteration: each new Rbar and pdval is
    // averaged with its previous value before being used.
    for (INTEGER iter=1; iter<=maxiter; ++iter) {
      
      #if defined USE_ACCURATE_RBAR
        Rbarf.update(f,rootf,h);
        Rbar = nlroot(Rbar - 0.02, Rbar + 0.02, Rbarf, 1.0e-10);
      #else
        Rbar = (1.0 + f(1.0))*dg_median/f(1.0);
      #endif

      Rbar = (Rbar + Rbar_lag)/2.0;

      rootf.update(f,Rbar,mp.eta);
      rootf.update(1.0,dg_median);

      h.update(rootf);

      #if defined DEBUG_EMMUSR
      if (debug_switch && iter <= plotlimit) {
        for (INTEGER i=1; i<= npts; ++i) {
          REAL z = rplot(i,1);
          rtmp[i] = rootf(z);
          ftmp[i] = f(z);
        }
        rplot = cbind(rplot,rtmp);
        fplot = cbind(fplot,ftmp);
      }
      #endif

      vfunc v(Rf,mp.lambda,mp.k);

      eulerfunc euler(mp, f, h, v, dgval, dmval, dgprob);

      for (INTEGER i=1; i<=ngrid; ++i) {
        euler.update_z0(zgrid[i]);
        #if defined MONTHLY_FREQUENCY
          pdval[i] = nlroot(12.0, 600.0, euler, 1.0e-10);
        #else
          pdval[i] = nlroot(1.0, 50.0, euler, 1.0e-10);
        #endif
      }

      for (INTEGER i=1; i<=ngrid; ++i) {
        pdval[i] = (pdval[i] + pdval_lag[i])/2.0;
      }

      f.update(zgrid,pdval);

      // Converged when every grid point moved by at most tol (relative).
      converge = true;
      REAL relerr = 0.0;
      INTEGER erridx = 0;
      for (INTEGER i=1; i<=ngrid; ++i) {
        relerr = fabs(pdval[i] - pdval_lag[i])/fabs(pdval_lag[i] + tol);
        if (relerr > tol) {
          erridx = i;
          converge = false;
          break;
        }
      }

      #if defined DEBUG_EMMUSR
      if (debug_switch) {
        *debug << "      ";
        *debug << "iter " << iter << ' ';
        *debug << "erridx " << erridx << ' ';
        //*debug << "pdval " << pdval[erridx] << ' ';
        *debug << "relerr " << relerr << ' ';
        //*debug << "dg_median " << dg_median << ' ';
        *debug << "Rbar " << Rbar << ' ';
        //*debug << "converge " << converge <<'\n';
        *debug << '\n';
      }
      #endif

      for (INTEGER i=1; i<=ngrid; ++i) {
        pdval_lag[i] = pdval[i];
      }
      Rbar_lag = Rbar;

      if (converge) break;
    }

    #if defined DEBUG_EMMUSR
    if (debug_switch) {
      vecwrite(rname.c_str(),rplot);
      vecwrite(fname.c_str(),fplot);
      *debug << starbox("/sgrid, pdval//");
      for (INTEGER i=1; i<=ngrid; ++i) {
       *debug << "      ";
       *debug << "zgrid["<<fmt('d',2,i)<<"] = "<<fmt('f',12,8,zgrid[i])<<';'; 
       *debug << "    ";
       *debug << "pdval["<<fmt('d',2,i)<<"] = "<<fmt('f',12,8,pdval[i])<<';'; 
       *debug << '\n';
      }
      debug->flush();
    }
    #endif

    if (!converge) {
      warn("Warning, convergence failed at this parameter value");
      cerr << theta;
    }

    // Iterate the state forward from mv.state_variable[1] (its initial
    // value is set outside this function -- presumably by the
    // prospect_usrmod_variables constructor; confirm).
    REAL lRf = log(Rf);
    for (INTEGER t=2; t <= mv.nt; ++t) {
      REAL z0 = mv.state_variable[t-1];
      REAL z1 = h(z0,mv.dividend_growth[t]);
      mv.state_variable[t] = z1;
      mv.price_dividend_ratio[t] = f(mv.state_variable[t]); 
      mv.gross_stock_return[t] = (1.0 + f(z1))*mv.dividend_growth[t]/f(z0);
      mv.gross_risk_free_rate[t] = Rf;
      mv.geometric_stock_return[t] = log(mv.gross_stock_return[t]);
      mv.geometric_risk_free_rate[t] = lRf;
    }

    #if defined DEBUG_EMMUSR
    if (debug_switch) {
      realmat sv = mv.state_variable(seq(mv.n0+1,mv.n0+mv.n),1);
      sv.sort();
      INTEGER N = sv.size();
      *debug << starbox("/quantiles of state vector//") << '\n';
      *debug << "\t\t" << "N      = " << N  << '\n';
      *debug << "\t\t" << "min    = " << sv[1]  << '\n';
      *debug << "\t\t" << "1%     = " << sv[INTEGER(0.01*N)+1] << '\n';
      *debug << "\t\t" << "25%    = " << sv[INTEGER(0.25*N)] << '\n';
      *debug << "\t\t" << "median = " << sv[INTEGER(0.50*N)] << '\n';
      *debug << "\t\t" << "75%    = " << sv[INTEGER(0.75*N)] << '\n';
      *debug << "\t\t" << "99%    = " << sv[INTEGER(0.99*N)] << '\n';
      *debug << "\t\t" << "max    = " << sv[N]  << '\n';
      // Count sorted values >= 3 from the top; stop at index 1 so that a
      // sample lying entirely at or above 3 does not read sv[0].
      INTEGER count = 0;
      for (INTEGER i=N; i>=1 && sv[i] >= 3.0; --i) ++count;
      *debug << "\t\t" << "# >= 3 = " << count << '\n';
      *debug << "\t\t" << "Rbar   = " << Rbar   << '\n';
      *debug << "\t\t" << "dg_med = " << dg_median << '\n';
      *debug << "\t\t" << "G_d    = " << exp(mp.g_d) << '\n';
      debug->flush();
    }
    #endif

    #if defined MONTHLY_FREQUENCY
      // Aggregate each block of 12 months into one annual observation.
      sim.resize(n_datum,mv.n/12);
      REAL r0, C0, C12, D0, D12, PD;     
      INTEGER tt = 0;
      for (INTEGER i=1; i <= mv.n/12; ++i) {
        r0 = C0 = C12 = D0 = D12 = 0.0;
        for (INTEGER j=1; j <= 12; j++) {
          ++tt;
          r0  += mv.geometric_stock_return[mv.n0 + tt];
          C0  += mv.consumption[mv.n0 + tt];
          C12 += mv.consumption[mv.n0 + tt - 12];
          D0  += mv.dividend[mv.n0 + tt];
          D12 += mv.dividend[mv.n0 + tt - 12];
        }
        PD = mv.price_dividend_ratio[mv.n0 + tt]*mv.dividend[mv.n0 + tt]/D0;
        sim(1,i) = log(D0) - log(C0) + mp.mudc;
        sim(2,i) = log(C0) - log(C12);
        sim(3,i) = log(PD);
        sim(4,i) = r0;
      }
    #else
      sim.resize(n_datum,mv.n);
      INTEGER tt = 0;
      for (INTEGER i=1; i <= mv.n; ++i) {
        ++tt;
        sim(1,i) = mv.log_dividend_growth[mv.n0 + tt]
                   - mv.log_consumption_growth[mv.n0 + tt] + mp.mudc;
        sim(2,i) = mv.log_consumption_growth[mv.n0 + tt];
        sim(3,i) = log(mv.price_dividend_ratio[mv.n0 + tt]);
        sim(4,i) = mv.geometric_stock_return[mv.n0 + tt];
      }
    #endif

    #if defined DEBUG_EMMUSR
    if (debug_switch) {
    #if defined MONTHLY_FREQUENCY
      *debug << starbox("/First 24 simulated monthly values//");
      realmat monthly;
      vector<string> names = mv.get_prospect_usrmod_variables(monthly);
      intvec idx = seq(mv.n0+1, mv.n0+24);
      monthly = monthly(idx,"");
      for (vector<string>::size_type i=0; i<names.size(); ++i ){ 
        *debug << '\n';
        *debug << names[i] << monthly("",i+1);
      }
      // sim holds 12-month aggregates here, not monthly values.
      *debug << starbox("/First 24 simulated annual values//");
      idx = seq(1,24);
      *debug << sim("",idx);
      vecwrite("pro.simulation.dat",sim);
      debug->flush();
    #else
      *debug << starbox("/First 5 simulated annual values//");
      intvec idx = seq(1,5);
      *debug << sim("",idx);
      vecwrite("pro.simulations.dat",sim);
      debug->flush();
    #endif
    }
    #endif    
    
    return converge; 
  }
  
  /*
  Simulates the model and summarizes the sample.

  seed    -- random number seed, passed to make_sim(seed,sim).
  sim     -- output of make_sim(seed,sim).
  stats   -- output; written only when the simulation succeeds.  It is
             resized to n_funcs x 1 with
               stats[1] = mean geometric risk-free rate,
               stats[2] = its variance (divisor n),
               stats[3] = mean geometric stock return,
               stats[4] = its variance (divisor n),
             computed over observations mv.n0+1 .. mv.n0+mv.n.
  returns -- the success flag from make_sim(seed,sim).

  In the annual build all four entries are divided by 12, presumably to
  match the monthly scale of the bond target in prior() -- confirm.
  On failure stats is left untouched; the rescaling is done inside the
  success branch so an unsized stats is never indexed.
  */
  bool prospect_usrmod::make_sim(INT_32BIT& seed, realmat& sim, realmat& stats)
  {
    bool success = make_sim(seed,sim);
  
    if (success) {
    
      REAL mean_rf = 0.0;
      REAL mean_sr = 0.0;
  
      for (INTEGER t=mv.n0+1; t <= mv.n0+mv.n; t++) {
        mean_rf += mv.geometric_risk_free_rate[t];
        mean_sr += mv.geometric_stock_return[t];
      }
  
      mean_rf /= REAL(mv.n);
      mean_sr /= REAL(mv.n);
    
      REAL var_rf = 0.0;
      REAL var_sr = 0.0;
      
      for (INTEGER t=mv.n0+1; t <= mv.n0+mv.n; t++) {
        var_rf += pow(mv.geometric_risk_free_rate[t]-mean_rf, 2);
        var_sr += pow(mv.geometric_stock_return[t]-mean_sr, 2);
      }
  
      var_rf /= REAL(mv.n);
      var_sr /= REAL(mv.n);
    
      stats.resize(n_funcs,1);
  
      #if defined DEBUG_EMMUSR
        stats.check(1) = mean_rf;
        stats.check(2) = var_rf;
        stats.check(3) = mean_sr;
        stats.check(4) = var_sr;
      #else
        stats[1] = mean_rf;
        stats[2] = var_rf;
        stats[3] = mean_sr;
        stats[4] = var_sr;
      #endif

      #if !defined MONTHLY_FREQUENCY
        stats[1] /= 12.0;
        stats[2] /= 12.0;
        stats[3] /= 12.0;
        stats[4] /= 12.0;
      #endif
    }

    return success;
  }

  bool prospect_usrmod::support(const realmat& theta) 
  {

    realmat theta_lo(n_parms,1);
    realmat theta_hi(n_parms,1);
  
    
    #if defined MONTHLY_FREQUENCY
      theta_lo[ 1] = -0.02;      // g_c
      theta_lo[ 2] = -0.02;      // g_d   
      theta_lo[ 3] =  0.0;       // sig_c 
      theta_lo[ 4] =  0.0;       // sig_d 
      theta_lo[ 5] = -1.0;       // omega
      theta_lo[ 6] =  0.25;      // gamma 
      theta_lo[ 7] =  0.9;       // rho  
      theta_lo[ 8] =  0.0;       // lambda
      theta_lo[ 9] =  0.0;       // k    
      theta_lo[10] =  0.0;       // b_0  
      theta_lo[11] =  0.25;      // eta  
      theta_lo[12] = -10.0;      // mudc

      theta_hi[ 1] =  0.02;      // g_c
      theta_hi[ 2] =  0.02;      // g_d   
      theta_hi[ 3] =  1.0;       // sig_c 
      theta_hi[ 4] =  3.0;       // sig_d 
      theta_hi[ 5] =  1.0;       // omega
      theta_hi[ 6] = 10.0;       // gamma 
      theta_hi[ 7] =  0.99999;   // rho  
      theta_hi[ 8] = 10.0;       // lambda
      theta_hi[ 9] = 20.0;       // k    
      theta_hi[10] = 10.0;       // b_0  
      theta_hi[11] =  1.0;       // eta  
      theta_hi[12] = 10.0;       // mudc
    #else
      theta_lo[ 1] = -0.24;      // g_c
      theta_lo[ 2] = -0.24;      // g_d   
      theta_lo[ 3] =  0.0;       // sig_c 
      theta_lo[ 4] =  0.0;       // sig_d 
      theta_lo[ 5] = -1.0;       // omega
      theta_lo[ 6] =  0.25;      // gamma 
      theta_lo[ 7] =  0.9;       // rho  
      theta_lo[ 8] =  0.0;       // lambda
      theta_lo[ 9] =  0.0;       // k    
      theta_lo[10] =  0.0;       // b_0  
      theta_lo[11] =  0.25;      // eta  
      theta_lo[12] = -10.0;      // mudc
    
      theta_hi[ 1] =  0.24;      // g_c
      theta_hi[ 2] =  0.24;      // g_d   
      theta_hi[ 3] =  3.46;      // sig_c 
      theta_hi[ 4] = 10.39;      // sig_d 
      theta_hi[ 5] =  1.0;       // omega
      theta_hi[ 6] = 10.0;       // gamma 
      theta_hi[ 7] =  0.9999;    // rho  
      theta_hi[ 8] = 10.0;       // lambda
      theta_hi[ 9] = 20.0;       // k    
      theta_hi[10] = 120.0;      // b_0  
      theta_hi[11] =  1.0;       // eta  
      theta_hi[12] = 10.0;       // mudc
    #endif 

    #if defined DEBUG_EMMUSR
      prospect_usrmod_parameters p;
      p.set_theta(theta);
      REAL term2 = pow(p.gamma*p.sig_c,2) 
                 - 2.0*p.gamma*p.omega*p.sig_c*p.sig_d + pow(p.sig_d,2);
      REAL tvc = log(p.rho) - p.gamma*p.g_c + p.g_d + 0.5*term2;
      if (tvc >= 0.0) warn("Warning, transversality conditions violated");
    #endif

    return ( (theta_lo < theta) && (theta < theta_hi) );
  }

  /*
  Log prior density (returned as a den_val).

  theta -- parameter vector, order g_c, g_d, sig_c, sig_d, omega, gamma,
           rho, lambda, k, b_0, eta, mudc.
  stats -- summary statistics; only stats[1] is used and is compared with
           the bond target below.

  The result is the sum of independent normal log densities:
   - stats[1] ~ N(bond, sbond^2);
   - theta[i] ~ N(mean[i], sdev[i]^2), with sdev[i] = |mean[i]|*0.1/1.96,
     which puts 95% prior mass within +/-10% of mean[i].
  */
  den_val prospect_usrmod::prior(const realmat& theta,const realmat& stats)
  {
    const REAL minus_log_root_two_pi = -9.1893853320467278e-01;  // -log(sqrt(2 pi))
  
    const REAL bond = 0.000743618;             // 0.89% annualized
    const REAL sbond = (1.0/1200.0)/1.96;      // P(within 1% annualized)=0.95

    #if defined MONTHLY_FREQUENCY
      const REAL centre[12] = {
        0.0015333333333334,    // g_c
        0.0015333333333334,    // g_d
        0.0109407876011434,    // sig_c
        0.0346410161513775,    // sig_d
        0.15,                  // omega
        1.0,                   // gamma
        0.99831785744726,      // rho
        2.25,                  // lambda
        10.0,                  // k
        0.1666666666666667,    // b_0
        0.991258389045303,     // eta
        -3.3857                // mudc
      };
    #else
      const REAL centre[12] = {
        0.0184,                // g_c
        0.0184,                // g_d
        0.0379,                // sig_c
        0.12,                  // sig_d
        0.15,                  // omega
        1.0,                   // gamma
        0.98,                  // rho
        2.25,                  // lambda
        10.0,                  // k
        2.0,                   // b_0
        0.9,                   // eta
        -3.3857                // mudc
      };
    #endif

    realmat mean(n_parms,1,0.0);
    realmat sdev(n_parms,1,100.0);
    for (INTEGER k=0; k<12; ++k) mean[k+1] = centre[k];

    for (INTEGER k=1; k<=n_parms; ++k) sdev[k] = fabs(mean[k])*0.1/1.96;

    den_val log_density(true,0.0);

    const REAL z_rf = (stats[1]-bond)/sbond;
    log_density += 
      den_val(true, minus_log_root_two_pi - log(sbond) - 0.5*pow(z_rf,2));

    for (INTEGER k=1; k<=n_parms; ++k) {
      const REAL zk = (theta[k] - mean[k])/sdev[k];
      log_density += 
        den_val(true, minus_log_root_two_pi - log(sdev[k]) - 0.5*pow(zk,2));
    }

    return log_density;
  }

  // Express a parameter vector in annual reporting units.
  //
  // Without MONTHLY_FREQUENCY the parameters are already annual and theta is
  // returned as is.  With MONTHLY_FREQUENCY:
  //   g_c, g_d     -> percent per year         (x 100*12)
  //   sig_c, sig_d -> percent per root-year    (x 100*sqrt(12))
  //   rho, eta     -> twelfth power
  //   b_0          -> x 12
  //   omega, gamma, lambda, k, mudc are copied unchanged.
  // set_theta is invoked in both configurations.
  realmat prospect_usrmod::annualize_parms(const realmat& theta)
  {
    prospect_usrmod_parameters p;
    p.set_theta(theta);

    #if !defined MONTHLY_FREQUENCY
      return theta;
    #else
      const REAL pct_per_year  = 100.0*12.0;
      const REAL pct_root_year = 100.0*sqrt(12.0);

      realmat annual(n_parms,1);

      annual[1]  = pct_per_year*p.g_c;
      annual[2]  = pct_per_year*p.g_d;
      annual[3]  = pct_root_year*p.sig_c;
      annual[4]  = pct_root_year*p.sig_d;
      annual[5]  = p.omega;
      annual[6]  = p.gamma;
      annual[7]  = pow(p.rho,12);
      annual[8]  = p.lambda;
      annual[9]  = p.k;
      annual[10] = p.b_0*12.0;
      annual[11] = pow(p.eta,12);
      annual[12] = p.mudc;

      return annual;
    #endif
  }
  
  // Convert the four simulated statistics to annual percent.
  //   stats[1], stats[3]: bond / stock means   -> x 100*12
  //   stats[2], stats[4]: enter under a square root, so presumably bond /
  //                       stock variances (TODO confirm against make_sim)
  //                       -> 100*sqrt(12*stat), an annual percent sdev
  realmat prospect_usrmod::annualize_funcs(const realmat& stats)
  {
    const REAL percent = 100.0;
    const REAL months  = 12.0;

    realmat out(n_funcs,1);

    out[1] = percent*months*stats[1];         // bond mean
    out[2] = percent*sqrt(months*stats[2]);   // bond sdev
    out[3] = percent*months*stats[3];         // stock mean
    out[4] = percent*sqrt(months*stats[4]);   // stock sdev

    return out;
  }

  // Simulate at the current parameters into sim and discard the statistics.
  // make_sim receives a copy of simulation_seed, so this call cannot advance
  // simulation_seed through the argument.  Returns make_sim's success flag.
  bool prospect_usrmod::gen_sim(realmat& sim)
  {
    realmat discarded_stats;
    INT_32BIT working_seed(simulation_seed);
    const bool ok = make_sim(working_seed, sim, discarded_stats);
    return ok;
  }

  // Simulate at the current parameters, filling both sim and stats.
  // make_sim receives a copy of simulation_seed, so this call cannot advance
  // simulation_seed through the argument.  Returns make_sim's success flag.
  bool prospect_usrmod::gen_sim(realmat& sim, realmat& stats)
  {
    INT_32BIT working_seed(simulation_seed);
    const bool ok = make_sim(working_seed, sim, stats);
    return ok;
  }

  // Fill bs with 2*len_rho()+1 matrices shaped like data.
  //
  // One simulation is drawn with make_sim(bootstrap_seed, sim).  Element k of
  // bs is then the k-th block of data.ncol() consecutive columns of sim,
  // restricted to its first data.nrow() rows.  bs is resized if needed.
  // Returns false (leaving bs resized but otherwise untouched) when make_sim
  // reports failure.  Calls error() when sim is too small to supply all blocks.
  //
  // NOTE(review): bootstrap_seed itself is passed, unlike the copy used in
  // gen_sim, presumably so successive calls draw different simulations.
  // Confirm that make_sim updates its seed argument.
  bool prospect_usrmod::gen_bootstrap(vector<realmat>& bs)
  {
    INTEGER lrho = len_rho();
    vector<realmat>::size_type len_vec = 2*lrho+1;

    if (len_vec != bs.size()) {
      bs.resize(len_vec);
    }

    INTEGER n_dat = data.ncol();
    INTEGER r_dat = data.nrow();
    INTEGER len = len_vec*n_dat;   // total columns of sim consumed below

    realmat sim;
    bool converge =  make_sim(bootstrap_seed, sim);
    if (!converge) return false;

    // The copy loop reads sim(i, s+j) for columns up to len and rows up to
    // r_dat, so sim must be at least that large.  Comparing len with n_dat
    // would be vacuous because len_vec >= 1.
    if (sim.ncol() < len) {
      string msg = "Error, gen_bootstrap, simulation length must be at least";
      msg += fmt('d',7,len).get_ostr();
      error(msg);
    }
    if (sim.nrow() < r_dat) {
      error("Error, gen_bootstrap, simulation has fewer rows than data");
    }

    // dat is reused as a scratch buffer; bs[k] = dat stores a copy.
    realmat dat(r_dat,n_dat);
    INTEGER s = 0;                 // column offset of the current block in sim
    for (vector<realmat>::size_type k=0; k<len_vec; ++k) {
      for (INTEGER j=1; j<=n_dat; ++j) {
        for (INTEGER i=1; i<=r_dat; ++i) {
          dat(i,j) = sim(i,s+j);
        }
      }
      bs[k] = dat;
      s += n_dat;
    }

    return true;
  }

  void  prospect_usrmod::write_usrvar(const char* filename) 
  { 
    ofstream fout(filename);
    fout << starbox("/parameters//") << '\n';
    realmat rho;
    get_rho(rho);
    fout << rho;
    fout << '\n';
    fout << starbox("/mean of model variables//") << '\n';
    realmat vars;
    vector<string> names = mv.get_prospect_usrmod_variables(vars);
    INTEGER c = vars.ncol();
    INTEGER r = vars.nrow();
    realmat ones(1,r,1.0);
    realmat mean = ones*vars/r;
    INTEGER len = 0;
    for (INTEGER i=1; i<=c; ++i) {
      INTEGER sz = names[i-1].size();
      len = len > sz ? len :  sz;
    }
    for (INTEGER i=1; i<=c; ++i) {
      string pad(len- names[i-1].size()+2,' ');
      fout << names[i-1] << pad << mean[i] << '\n';
    }
    fout << '\n';

    fout << starbox("/mean of data//");
    r = data.nrow();
    c = data.ncol();
    ones.resize(c,1,1.0);
    mean = data*ones/c;
    fout << mean;
    fout << '\n';

    fout << starbox("/mean of annual simulation//");
    realmat sim, func;
    bool success = gen_sim(sim,func);
    if (!success) error("Error, write_usrvar, gen_sim failed");
    r = sim.nrow();
    c = sim.ncol();
    ones.resize(c,1,1.0);
    mean = sim*ones/c;
    fout << mean;
    fout << '\n';
  }
}

