#include "na_fft.h"

#include <string.h>

#include <fftw.h>
#include <rfftw.h>

/* DEBUG is defined with an empty replacement list, so a statement written
   as `DEBUG printf(...)` elsewhere in this file compiles to an ordinary
   printf(...) call.  The trace output is therefore always enabled. */
#define DEBUG

/* Map FFTW's build-time precision onto the matching NArray type codes and
   complex element type: single precision when FFTW_ENABLE_FLOAT is defined
   (presumably by the FFTW headers -- TODO confirm), double otherwise. */
#ifdef FFTW_ENABLE_FLOAT
# define NA_FFTW_REAL NA_SFLOAT
# define NA_FFTW_COMP NA_SCOMPLEX
# define COMP_T scomplex
#else
# define NA_FFTW_REAL NA_DFLOAT
# define NA_FFTW_COMP NA_DCOMPLEX
# define COMP_T dcomplex
#endif

/* Per-dimension flag stored in the arg_v[] array that fft_sub() builds:
   FFT_YES marks a dimension to be transformed, FFT_NO one to be left
   alone.  The values are stored in plain int slots. */
typedef enum {
  FFT_NO,   /* do not transform along this dimension */
  FFT_YES   /* transform along this dimension */
} DOING_FFT_DIM_FLAG;

/**** prototype declaration ****/
/* Ruby-visible methods (registered in Init_fft with arity -1). */
VALUE fft_fftw(int argc, VALUE *argv, VALUE self);
VALUE ffti_fftw(int argc, VALUE *argv, VALUE self);
/* Shared driver: parses the dimension arguments and dispatches on na->type. */
static NA *fft_sub(int argc, VALUE *argv, VALUE self, fftw_direction dir);
/* Type-specific workers; arg_v holds one FFT_YES/FFT_NO flag per dimension. */
static NA *fft_DCOMPLEX(int arg_c, int *arg_v, NA *na, fftw_direction dir);
static NA *fft_SCOMPLEX(int arg_c, int *arg_v, NA *na, fftw_direction dir);
/* Raises ArgumentError when the dimension arguments are out of range. */
static void check_argv(int argc, int *argv, int rank);
/**** end prototype ****/

/*
 * NArray#fft([dim, ...]) -- forward transform.
 *
 * Passes the Ruby arguments straight to fft_sub() with FFTW_FORWARD and
 * wraps the resulting NA struct in a new cNArray object whose free
 * function is na_free.
 *
 * The previous version called ALLOC(NA) and then immediately overwrote the
 * pointer with fft_sub()'s result, leaking one NA struct per call.
 */
VALUE 
fft_fftw(int argc, VALUE *argv, VALUE self)
{
  NA *na = fft_sub(argc, argv, self, FFTW_FORWARD);

  return Data_Wrap_Struct(cNArray, 0, na_free, na);
}

/*
 * NArray#ffti([dim, ...]) -- backward transform.
 *
 * Passes the Ruby arguments straight to fft_sub() with FFTW_BACKWARD and
 * wraps the resulting NA struct in a new cNArray object whose free
 * function is na_free.
 *
 * The previous version called ALLOC(NA) and then immediately overwrote the
 * pointer with fft_sub()'s result, leaking one NA struct per call.
 */
VALUE 
ffti_fftw(int argc, VALUE *argv, VALUE self)
{
  NA *na = fft_sub(argc, argv, self, FFTW_BACKWARD);

  return Data_Wrap_Struct(cNArray, 0, na_free, na);
}

/*
 * Common driver for fft_fftw() and ffti_fftw().
 *
 * argc/argv : the Ruby method arguments; each one is converted with
 *             NUM2INT and interpreted as a dimension index.
 * self      : the receiver, unwrapped with GetNArray.
 * dir       : FFTW_FORWARD or FFTW_BACKWARD, passed through to the worker.
 *
 * Dimension selection:
 *   - no argument, or a single negative argument: every dimension;
 *   - otherwise: exactly the listed dimensions.
 *
 * Returns the NA produced by fft_DCOMPLEX()/fft_SCOMPLEX().  Raises
 * ArgumentError (via check_argv) for bad dimension arguments and TypeError
 * when the receiver is neither NA_DCOMPLEX nor NA_SCOMPLEX.
 */
static NA *
fft_sub(int argc, VALUE *argv, VALUE self, fftw_direction dir)
{
  NA *na;
  int *arg_v;
  int *temp_argv = NULL;   /* stays NULL when no argument is given */
  int n;

  GetNArray(self, na);

  if(argc > 0){
    temp_argv = ALLOCA_N(int, argc);
    for(n=0 ; n<argc ; n++)
      temp_argv[n] = NUM2INT(argv[n]);
  }

  /* Validate the arguments when at least one is given.  An error is
     raised when:
       - there is one argument and its value is >= rank;
       - there are two or more arguments and any of them
           * is negative, or
           * is >= rank.                                              */
  check_argv(argc, temp_argv, na->rank);

  /* VALUE argv[] => FFT_YES/FFT_NO flag per dimension in arg_v[rank] */
  arg_v = ALLOCA_N(int, na->rank);
  for(n=0 ; n<(na->rank) ; n++)
    arg_v[n] = FFT_NO;

  if(argc==0 || (argc==1 && temp_argv[0]<0)){
    /* no argument, or a single negative one: transform every dimension */
    for(n=0 ; n<(na->rank) ; n++)
      arg_v[n] = FFT_YES;
  }else{
    /* one non-negative argument, or two or more arguments */
    for(n=0 ; n<argc ; n++)
      arg_v[ temp_argv[n] ] = FFT_YES;
  }

  /* The workers receive the argument count as arg_c (it used to be passed
     uninitialised). */
  switch(na->type){
  case NA_DCOMPLEX:
    return fft_DCOMPLEX(argc, arg_v, na, dir);

  case NA_SCOMPLEX:
    return fft_SCOMPLEX(argc, arg_v, na, dir);

  default:
    rb_raise(rb_eTypeError, "operand is not valid type");
  }
  return NULL;   /* not reached: rb_raise does not return */
}

NA *
fft_DCOMPLEX(int arg_c, int *arg_v, NA *na, fftw_direction dir)
{
  fftw_plan p;
  int n, m, rank,total, howmany, dim, length;
  
  dcomplex *ptr;
  char *new_ptr;

  int idist, istride;
  NA *ret_na;

  rank    = na->rank;
  total   = na->total;
  ptr     = (dcomplex *)na->ptr;

  new_ptr = ALLOC_N( char , sizeof(fftw_complex) * na->total ); 

  DEBUG printf("1...\n");

  for(dim=0 ; dim<(na->rank) ; dim++){
    if(arg_v[dim]==FFT_YES){
      if(dim == 0){
	p = fftw_create_plan(na->shape[dim], dir, FFTW_ESTIMATE);
	if(p==NULL) rb_raise(rb_eRuntimeError,"cannot allocate FFTW plan");

	idist   = na->shape[0];
	istride = 1;
	howmany = total / na->shape[0];

  DEBUG printf("2...before fftw()\n");
  DEBUG printf("idist = %d\n",idist);
  DEBUG printf("istride = %d\n",istride);
  DEBUG printf("howmany = %d\n",howmany);

	fftw(p, howmany,
	       (fftw_complex *)ptr    , istride, idist,
	       (fftw_complex *)new_ptr, istride, idist);
  DEBUG printf("3...after fftw()\n");

	fftw_destroy_plan(p);
      }else if(dim == rank - 1){
	p = fftw_create_plan(na->shape[dim], dir, FFTW_ESTIMATE);
	if(p==NULL) rb_raise(rb_eRuntimeError,"cannot allocate FFTW plan");

	idist   = 1;
	istride = 
	howmany = total / na->shape[rank - 1];

  DEBUG printf("4...before fftw()\n");
  DEBUG printf("idist = %d\n",idist);
  DEBUG printf("istride = %d\n",istride);
  DEBUG printf("howmany = %d\n",howmany);

  DEBUG printf("rank = %d\n",rank);
  DEBUG printf("total = %d\n",total);
	for(m=0 ; m<rank ; m++)
  DEBUG   printf("shape[%d] = %d\n",m,na->shape[m]);

	fftw(p, howmany,
	       (fftw_complex *)ptr    , istride, idist,
	       (fftw_complex *)new_ptr, istride, idist);
  DEBUG printf("5...after fftw()\n");
	fftw_destroy_plan(p);

      }else{
	idist   = 1;
	istride = 1; /* initialize */
	for(m=0 ; m<dim ; m++)
	  istride = istride * na->shape[m];
	howmany = istride;
	length = istride * na->shape[dim];

  DEBUG printf("6...before fftw()\n");
  DEBUG printf("idist = %d\n",idist);
  DEBUG printf("istride = %d\n",istride);
  DEBUG printf("howmany = %d\n",howmany);

	p = fftw_create_plan(na->shape[dim], dir, FFTW_ESTIMATE);
	if(p==NULL) rb_raise(rb_eRuntimeError,"cannot allocate FFTW plan");

	for(m=0 ; m<total ; m += length){
	  fftw(p, howmany,
		 (fftw_complex *)&ptr[m]    , istride, idist,
		 (fftw_complex *)&new_ptr[m], istride, idist);
  DEBUG   printf("7...after fftw()\n");
	}
	fftw_destroy_plan(p);
      }
    }
  }

  ret_na = na_alloc_struct(NA_FFTW_COMP, rank, na->shape);
  ret_na->ptr = new_ptr;

  return ret_na;
}

/*
 * FFT of a single-precision complex NArray.
 *
 * The input elements are widened into a temporary dcomplex buffer and the
 * work is delegated to fft_DCOMPLEX(); the result therefore has whatever
 * type fft_DCOMPLEX() produces.
 *
 * arg_c, arg_v, dir are passed through unchanged; `na` is not modified.
 *
 * Fixes over the previous version:
 *   - the temporary NA struct (from na_alloc_struct) and the widened
 *     buffer were never released; a stack copy of the descriptor is used
 *     instead and the buffer is freed after the transform;
 *   - unconditional debug printf calls were removed.
 *
 * NOTE(review): if fft_DCOMPLEX() raises, the widened buffer is still
 * leaked; acceptable for an out-of-memory-class failure.
 */
static NA *
fft_SCOMPLEX(int arg_c, int *arg_v, NA *na, fftw_direction dir)
{
  NA wide_na;
  NA *ret_na;
  dcomplex *dcmp_ptr;
  scomplex *scmp_ptr;
  int n;

  scmp_ptr = (scomplex *)na->ptr;
  dcmp_ptr = ALLOC_N(dcomplex, na->total);

  for(n=0 ; n<(na->total) ; n++){
    dcmp_ptr[n].r = scmp_ptr[n].r;
    dcmp_ptr[n].i = scmp_ptr[n].i;
  }

  /* Shallow copy: same rank/shape/total as the input, but double-complex
     data.  fft_DCOMPLEX() only reads through this descriptor. */
  wide_na      = *na;
  wide_na.type = NA_DCOMPLEX;
  wide_na.ptr  = (char *)dcmp_ptr;

  ret_na = fft_DCOMPLEX(arg_c, arg_v, &wide_na, dir);

  xfree(dcmp_ptr);
  return ret_na;
}

/*
 * Validate the dimension indices passed to fft/ffti.
 *
 * argc : number of indices in argv
 * argv : the indices, already converted to int
 * rank : rank of the array being transformed
 *
 * Raises ArgumentError when
 *   - more indices than dimensions are given;
 *   - a single index is >= rank (a single negative index is accepted);
 *   - with two or more indices, any index is negative or >= rank.
 * With argc == 0 nothing is inspected.
 */
static void 
check_argv(int argc, int *argv, int rank)
{
  int i;

  if(rank < argc)
    rb_raise(rb_eArgError, "too many arguments");

  if(argc == 1){
    /* lone index: only the upper bound is enforced */
    if(argv[0] >= rank)
      rb_raise(rb_eArgError, "invalid value");
    return;
  }

  /* two or more indices: each must lie in [0, rank) */
  for(i = 0 ; i < argc ; i++){
    if(argv[i] < 0 || argv[i] >= rank)
      rb_raise(rb_eArgError, "invalid value");
  }
}

/*
 * Extension entry point: attaches the FFT methods to cNArray.
 * Both take a variable number of arguments (arity -1), matching the
 * (argc, argv, self) signature of the implementing functions.
 */
void 
Init_fft(void)
{
  /* forward transform (FFTW_FORWARD) */
  rb_define_method(cNArray, "fft", fft_fftw, -1);

  /* backward transform (FFTW_BACKWARD) */
  rb_define_method(cNArray, "ffti", ffti_fftw, -1);

  /* "rfft" (rfft_fftw) and "rffti" (rffti_fftw) are not registered. */
}

