require "numru/netcdf"
require "narray_miss"

module NumRu

  class NetCDFVar

    # Reads data through simple_get and masks the invalid elements.
    #
    # The mask comes from the cached missing-data parameters (@vmin, @vmax,
    # @missval), which are interpreted on first use. A valid range (minimum
    # and/or maximum) takes precedence; @missval is compared only when no
    # range bound is set.
    #
    # args:: passed through unchanged to simple_get.
    #
    # Returns NArrayMiss.to_nam(data, mask) when a mask could be built,
    # otherwise the result of simple_get as is.
    def get_with_miss(*args)
      __interpret_missing_params unless defined?(@missval)
      raw = simple_get(*args)
      valid =
        if @vmin && @vmax
          (raw >= @vmin).and(raw <= @vmax)
        elsif @vmin
          raw >= @vmin
        elsif @vmax
          raw <= @vmax
        elsif @missval	# only missing_value is present.
          raw.ne(@missval)
        end
      valid ? NArrayMiss.to_nam(raw, valid) : raw
    end

    # Reads data with both missing-value masking and scale/offset unpacking.
    #
    # This is a thin wrapper: all of the work is delegated to
    # get_with_miss_and_scaling2. The previous inline implementation
    # (read with scaled_get, then build the mask from the scaled values with
    # the same valid-range / missing_value rules as get_with_miss) used to be
    # kept here as disabled code and has been superseded by that method.
    #
    # args:: passed through unchanged to get_with_miss_and_scaling2.
    def get_with_miss_and_scaling(*args)
      get_with_miss_and_scaling2(*args)
    end

    # Writes data through simple_put, filling masked elements first.
    #
    # data:: an NArrayMiss, or anything simple_put accepts directly.
    # args:: passed through unchanged to simple_put.
    #
    # For NArrayMiss input the missing-data parameters are interpreted on
    # first use; the data is converted with to_na(@missval) when @missval is
    # set and with a plain to_na otherwise. Any other input is written as is.
    def put_with_miss(data, *args)
      return simple_put(data, *args) unless data.is_a?( NArrayMiss )

      __interpret_missing_params unless defined?(@missval)
      plain = @missval ? data.to_na(@missval) : data.to_na
      simple_put(plain, *args)
    end


# COARDS $B5,LsMQ$N(B put $B%a%=%C%I$rMxMQ(B.
    # Writes data with scale/offset packing, taking masked elements into
    # account.
    #
    # data:: an NArrayMiss, or anything scaled_put accepts directly.
    # args:: passed through unchanged to the put method that is chosen.
    #
    # For NArrayMiss input the missing-data parameters are interpreted on
    # first use. With @missval set, the NArrayMiss itself is handed to
    # scaled_put_without_missval; without it, the data is converted with a
    # plain to_na and written by scaled_put. Any other input goes straight
    # to scaled_put.
    def put_with_miss_and_scaling(data, *args)
      return scaled_put(data, *args) unless data.is_a?( NArrayMiss )

      __interpret_missing_params unless defined?(@missval)
      if @missval
        # Alternative kept from the original author's notes (translated):
        # "uncomment this one when creating mask data" ...
        #   scaled_put(data.to_na(@missval), *args)
        # ... "and comment the following call out when creating mask data".
        scaled_put_without_missval(data, *args)
      else
        scaled_put(data.to_na, *args)
      end
    end

## ------- defined by daktu32 for dcchart (the only way to change the behavior deep inside GPhys...)

    # Packs and writes data whose masked elements must end up as
    # missing_value in the file.
    #
    # var::  object responding to to_na(fill) (the caller in this file passes
    #        an NArrayMiss).
    # hash:: passed through unchanged to simple_put.
    #
    # Relies on @missval being set already; the caller in this file only
    # gets here after checking it.
    #
    # Without scale_factor and add_offset the masked elements are filled
    # with @missval[0] and written directly. Otherwise they are filled with
    # @missval[0]*scale_factor + add_offset, so that the packing step
    # (value - add_offset) / scale_factor maps them back onto @missval[0].
    #
    # Raises TypeError if scale_factor is a String, NetcdfError if
    # scale_factor is zero, and NetcdfError if add_offset is a String.
    def scaled_put_without_missval(var, hash=nil)
      sf = att('scale_factor')
      ao = att('add_offset')
      if sf.nil? && ao.nil?
        # no scaling --> just fill the masked elements and call put
        return simple_put(var.to_na(@missval[0]), hash)
      end

      if sf
        csf = sf.get
        if csf.is_a?(NArray)       # --> should be a numeric
          csf = csf[0]
        elsif csf.is_a?(String)
          raise TypeError, "scale_factor is not a numeric"
        end
        raise NetcdfError, "zero scale_factor" if csf == 0
      else
        csf = 1.0                  # assume 1 if not defined
      end

      if ao
        cao = ao.get
        if cao.is_a?(NArray)       # --> should be a numeric
          cao = cao[0]
        elsif cao.is_a?(String)    # check the offset itself (was testing csf)
          raise NetcdfError, "add_offset is not a numeric"
        end
      else
        cao = 0.0                  # assume 0 if not defined
      end

      filled = var.to_na(@missval[0] * csf + cao)
      simple_put((filled - cao) / csf, hash)
    end

      
    # Reads unpacked (scaled) data and masks the invalid elements.
    #
    # args:: passed through unchanged to scaled_get (and to simple_get when
    #        the packed values are needed).
    #
    # * If a valid minimum and/or maximum is known, the mask is built from
    #   the scaled values compared against @vmin / @vmax.
    #   NOTE(review): this compares unpacked data with the valid_* attributes;
    #   confirm that this is the intended convention for the files in use.
    # * If only @missval is known, the mask is built from the packed values:
    #   an element is invalid when it lies within a relative tolerance of
    #   1e-6 around the missing value (the packed missing value written by
    #   scaled_put_without_missval may differ by rounding).
    # * Otherwise the scaled data is returned unmasked.
    #
    # Returns NArrayMiss.to_nam(scaled_data, mask) or the plain scaled data.
    def get_with_miss_and_scaling2(*args) # make mask before scaling
      __interpret_missing_params unless defined?(@missval)
      scaled_data = scaled_get(*args)
      if @vmin || @vmax
        if @vmin
          mask = (scaled_data >= @vmin)
          mask = mask.and(scaled_data <= @vmax) if @vmax
        else
          mask = (scaled_data <= @vmax)
        end
        NArrayMiss.to_nam(scaled_data, mask)
      elsif @missval	# only missing_value is present.
        # The packed values are read only here, the one place they are used.
        packed_data = simple_get(*args)
        eps = 1e-6
        missval = @missval[0].to_f
        # Use the magnitude: with a signed tolerance the band inverts for a
        # negative missing value and nothing would ever be masked.
        tol = missval.abs * eps
        mask =
          if tol.zero?
            # A zero missing value has no relative band; compare exactly.
            packed_data.ne(missval)
          else
            (packed_data <= missval - tol).or(packed_data >= missval + tol)
          end
        NArrayMiss.to_nam(scaled_data, mask)
      else
        scaled_data
      end
    end

    ######### private ##########

    # Interprets the specification of missing data,
    # either by valid_range, (valid_min and/or valid_max), or missing_value.
    # (Unlike the User's Guide, missing_value is interpreted, but
    # valid_* has a higher precedence.)
    #
    # Sets @vmin, @vmax and @missval (each nil when it does not apply).
    # valid_range, when present, overrides valid_min / valid_max.
    # When a valid bound exists but missing_value does not, @missval is
    # still set, since it will be used as a fill value for missing data:
    # _FillValue is used if it lies outside the valid bounds, otherwise a
    # value just outside the bound is made up.
    #
    # Raises RuntimeError when missing_value lies on the valid side of the
    # bounds that are defined.
    def __interpret_missing_params
      @vmin = att('valid_min')
      @vmin = @vmin.get if @vmin  # kept in a NArray(size==1) to conserve type
      @vmax = att('valid_max')
      @vmax = @vmax.get if @vmax  # kept in a NArray(size==1) to conserve type
      vrange = att('valid_range')
      vrange = vrange.get if vrange
      if vrange
        vrange.sort!
        @vmin = vrange[0..0]        # kept as a size-1 slice (same reason)
        @vmax = vrange[-1..-1]      # kept as a size-1 slice (same reason)
      end

      @missval = att('missing_value')
      if @missval
        @missval = @missval.get     # kept as read (same reason)
        if @vmin && @vmax
          if @vmin[0] <= @missval[0] && @missval[0] <= @vmax[0]
            raise "missing_value #{@missval[0]} is in the valid range #{@vmin[0]}..#{@vmax[0]}"
          end
        elsif @vmin && @missval[0] >= @vmin[0]
          raise "missing_value #{@missval[0]} >= valid min #{@vmin[0]}"
        elsif @vmax && @missval[0] <= @vmax[0]
          # Only valid_max is defined here, so report it (not @vmin, which
          # is nil in this branch).
          raise "missing_value #{@missval[0]} <= valid max #{@vmax[0]}"
        end
      elsif @vmin || @vmax
        if (fill=att('_FillValue'))   # assignment, not ==
          fill = fill.get
          # Accept _FillValue only if it lies outside the valid bounds.
          if (@vmin && fill[0] < @vmin[0]) || (@vmax && fill[0] > @vmax[0])
            @missval = fill
          end
        end
        if !@missval
          # Make up a value just outside the bound that is available.
          if @vmin
            if @vmin[0] >= 0
              @missval = 0.99*@vmin - 1
            else
              @missval = 1.01*@vmin - 1
            end
          elsif @vmax
            if @vmax[0] >= 0
              @missval = 1.01*@vmax + 1
            else
              @missval = 0.99*@vmax + 1
            end
          end
        end
      end
    end

    private :__interpret_missing_params

  end

end

# Self-test / demonstration, run only when this file is executed directly.
# It creates tmp.nc in the current directory, exercises the missing-value
# aware get/put methods on it, prints the results, and finally shells out
# to the external `ncdump` command to show the file contents.
if $0 == __FILE__
  include NumRu

  # --- Create the test file: define variables and attributes ---
  filename = "tmp.nc"
  print "creating ",filename,"...\n"
  file=NetCDF.create(filename)
  nx = 10
  dimx = file.def_dim("x",nx)
  # xf carries a valid_range attribute; xfn gets no attributes at all.
  xf = file.def_var("xf","sfloat",[dimx])
  xfn = file.def_var("xfn","sfloat",[dimx])
  xf.put_att("valid_range",[-1e12,1e12])
  # Powers of ten with exponents 0, 2, 4, ... (presumably up to 1e18, so
  # the largest values fall outside valid_range -- confirm with NArray docs).
  f = 10 ** (2*NArray.sfloat(nx).indgen!)
  # xr is a short-integer variable with a valid maximum and packing
  # attributes (scale_factor / add_offset).
  xr = file.def_var("xr","sint",[dimx])
  xr.put_att("valid_max",[0.7])
  xr.put_att("scale_factor",1e-4)
  xr.put_att("add_offset",0.5)
  r = NArray.sfloat(nx).random!
  # Definitions are finished; the writes follow.
  file.enddef
  xf.put(f)
  xfn.put(f)
  xr.scaled_put(r)
  file.close

  # --- Reopen the file and read back with the missing-aware methods ---
  file = NetCDF.open(filename,'r+')
  xf = file.var('xf')
  xfn = file.var('xfn')
  # f0: values as returned by the plain get.
  p "f0"
  xf.get.each{|v| print "#{v} "} ; print "\n"
  # f1: the same variable read with masking; the result is kept in nam.
  p( 'f1', nam = xf.get_with_miss )
  # Route get through get_with_miss on this one object, so that the
  # subscript read below presumably goes through the masking path as well.
  def xf.get(*args); get_with_miss(*args); end
  p( 'f12',  xf[2..-3].to_na )
  # xfn has no missing-data attributes; both reads are printed for
  # comparison.
  p( 'fn10', xfn.get_with_miss )
  p( 'fn11', xfn.get_with_miss_and_scaling )
  # Change the mask of nam for elements 0 and 1 (presumably marks them as
  # missing -- see NArrayMiss#invalidation), write it back, and read again.
  nam.invalidation([0,1])
  p 'f2', nam
  xf.put_with_miss(nam)
  p( 'f3', xf.get_with_miss )
  # --- Packed variable: raw values, then masked + unpacked values ---
  xr = file.var('xr')
  p "r0"
  xr.simple_get.each{|v| print "#{v} "} ; print "\n"
  p( 'r1', xr.get_with_miss_and_scaling )
  # Route get/put through the missing-aware, scaling variants on this one
  # object, so that the subscript read and assignment below presumably use
  # them.
  def xr.get(*args); get_with_miss_and_scaling(*args); end
  def xr.put(*args); put_with_miss_and_scaling(*args); end
  xr[0..3] = xr[0..3]*10
  p( 'r2', xr.get_with_miss_and_scaling )
  file.close
  # Dump the resulting file with the external ncdump tool.
  print "** ncdump tmp.nc **\n", `ncdump tmp.nc`
end
