#include "libscl.h"
#include "svmod.h"
using namespace scl; 
using namespace std;

int main(int argc, char** argp, char** envp)
{
  INTEGER n = 100;
  INTEGER N = 5000;
  INT_32BIT seed = 780695;
  
  svmod m;

  sample s = m.draw_sample(n,seed);

  vector<realmat> smooth(N);
  vector<realmat> filter(N);
  vector<realmat> draws(N);
  REAL weights[N];
  REAL log_likelihood = 0.0;

  // Initialization

  realmat y(n+1,1);
  y[1] = 0.0;
  for (INTEGER t=2; t<=n+1; t++) {
    y[t] = s.y[t-1];
  }

  REAL sum = 0.0;
  realmat x(n+1,1);
  for (INTEGER i=0; i<N; ++i) {
    x[1] = m.draw_x0(seed);
    smooth[i] = filter[i] = draws[i] = x;
  }

  for (INTEGER t=2; t<=n+1; ++t) {

    // Importance sampling step

    sum = 0.0;
    for (INTEGER i=0; i<N; ++i) {
      draws[i][t] = m.draw_xt(draws[i][t-1],seed);
      sum += weights[i] = m.prob_yt(y[t],draws[i][t]);
    }
    log_likelihood += log(sum/REAL(N));
  
    for (INTEGER i=1; i<N; ++i)  weights[i] += weights[i-1];
    for (INTEGER i=0; i<N; ++i)  weights[i] /= sum;
    weights[N-1] = 1.0;

    // Selection step

    for (INTEGER i=0; i<N; ++i) {
      REAL u = ran(seed);
      INTEGER j = 0;
      while(weights[j] <= u) ++j;
      smooth[i] = draws[j];
      filter[i][t] = draws[j][t];
    }

    draws = smooth;
  }  
    
  realmat mean(n+1,1,0.0);
  for (INTEGER i=0; i<N; ++i) {
    mean += smooth[i];
  }
  mean = mean/N;

  realmat sdev(n+1,1,0.0);
  for (INTEGER i=0; i<N; ++i) {
    realmat z = smooth[i] - mean;
    for (INTEGER t=1; t<=n+1; ++t) sdev[t] += z[t]*z[t];
  }
  for (INTEGER t=1; t<=n+1; ++t) sdev[t] = sqrt(sdev[t]/REAL(N-1));
  
  ofstream fout;
  fout.open("smooth.csv");
  if (!fout) error("Error, smooth, cannot open fout");
  
  fout << "mean, sdev, x, y" << '\n';
  fout << mean[1] <<','<< sdev[1] <<','<< s.x0 <<','<< 0 <<'\n';
  for (INTEGER t=2; t<=n+1; ++t) {
    fout << mean[t] <<','<< sdev[t] <<','<< s.x[t-1] <<','<< s.y[t-1] <<'\n';
  }

  fout.clear(); fout.close();

  fill(mean,0.0);
  for (INTEGER i=0; i<N; ++i) {
    mean += filter[i];
  }
  mean = mean/N;

  fill(sdev,0.0);
  for (INTEGER i=0; i<N; ++i) {
    realmat z = filter[i] - mean;
    for (INTEGER t=1; t<=n+1; ++t) sdev[t] += z[t]*z[t];
  }
  for (INTEGER t=1; t<=n+1; ++t) sdev[t] = sqrt(sdev[t]/REAL(N-1));
  
  fout.open("filter.csv");
  if (!fout) error("Error, filter, cannot open fout");
  
  fout << "mean, sdev, x, y" << '\n';
  fout << mean[1] <<','<< sdev[1] <<','<< s.x0 <<','<< 0 <<'\n';
  for (INTEGER t=2; t<=n+1; ++t) {
    fout << mean[t] <<','<< sdev[t] <<','<< s.x[t-1] <<','<< s.y[t-1] <<'\n';
  }

  cout << "The log likelihood is " << log_likelihood << '\n';

  return 0;
}
