open Ppatteries
open Subcommand
open Guppy_cmdobjs

(* The outcome of comparing one pair of samples, as built by pair_core. *)
type result =
  {
    (* the distance calculated for the pair *)
    distance : float;
    (* the p-value for that distance; None when no significance calculation
     * was requested (see get_p_value, which fails in that case) *)
    p_value : float option;
  }

(* Uniformly permute the elements of the array a, in place, using the Knuth
 * shuffle:
 * http://en.wikipedia.org/wiki/Random_permutation
 * http://rosettacode.org/wiki/Knuth_shuffle#OCaml
 *
 * Walking i down from the last index to 1, the element at i is exchanged with
 * the element at the index drawn by Gsl_rng.uniform_int rng (i+1).
 *)
let shuffle rng a =
  for i = Array.length a - 1 downto 1 do
    let j = Gsl_rng.uniform_int rng (i+1) in
    let held = a.(i) in
    a.(i) <- a.(j);
    a.(j) <- held
  done

(* Divide one integer by another as floats. *)
let int_div x y = float_of_int x /. float_of_int y

(* The fraction of the elements of l which are greater than or equal to x.
 * That is the probability that something of value x or greater was drawn from
 * the distribution of l. Every element is simply counted, so l does not need
 * to be sorted. *)
let list_onesided_pvalue l x =
  let at_least_x = List.filter (fun elt -> elt >= x) l in
  int_div (List.length at_least_x) (List.length l)

(* Field accessors for result records. get_p_value fails with "no p-value!"
 * when the record does not carry a p-value. *)
let get_distance { distance = d; _ } = d
let get_p_value = function
  | { p_value = Some p; _ } -> p
  | { p_value = None; _ } -> failwith "no p-value!"

(* Apply f to n_shuffles pairs generated by shuffling.
 *
 * pre1 and pre2 are pooled into a single array. For each requested shuffle
 * the pool is shuffled in place (the shuffles accumulate; the pool is never
 * reset) and then cut into a first piece as long as pre1 and a second piece
 * as long as pre2, which are handed to f. The results are returned as a list
 * of length n_shuffles. *)
let map_shuffled_pres f rng n_shuffles pre1 pre2 =
  let n1 = List.length pre1 in
  let n2 = List.length pre2 in
  let pool = Array.of_list (pre1 @ pre2) in
  let slice start len = Array.to_list (Array.sub pool start len) in
  let one_shuffle _ =
    shuffle rng pool;
    f (slice 0 n1) (slice n1 n2)
  in
  List.init n_shuffles one_shuffle

(* Raised by write_uptril when the uptris it is given differ in dimension. *)
exception Uptri_dim_mismatch

(* Write the uptris in ul to the channel ch.
 *
 * namea holds the sample names and fun_namel the name to print for each uptri
 * in ul. Nothing is written when ul is empty. Every uptri must have the same
 * dimension as the first one (otherwise Uptri_dim_mismatch is raised), and a
 * dimension of zero is rejected with a Failure.
 *
 * With list_output, a single padded table is written: a header row followed
 * by one row per (i, j) visited by Uptri.iterij, holding the two sample names
 * and the "%g"-formatted entry of every uptri. Otherwise each uptri is written
 * in turn, preceded by a "<name> distances:" line. *)
let write_uptril list_output namea fun_namel ul ch =
  match ul with
  | [] -> ()
  | first :: rest ->
    let dim = Uptri.get_dim first in
    List.iter
      (fun u ->
        if dim = Uptri.get_dim u then () else raise Uptri_dim_mismatch)
      rest;
    if dim = 0 then
      failwith "can't do anything interesting with fewer than two place files";
    if list_output then begin
      let header = Array.of_list ("sample_1" :: "sample_2" :: fun_namel) in
      let row i j =
        let entries =
          List.map (fun u -> Printf.sprintf "%g" (Uptri.get u i j)) ul
        in
        Array.of_list (namea.(i) :: namea.(j) :: entries)
      in
      (* rows are accumulated in reverse and flipped once at the end *)
      let rev_rows = ref [] in
      Uptri.iterij (fun i j _ -> rev_rows := row i j :: !rev_rows) first;
      String_matrix.write_padded ch
        (Array.of_list (header :: List.rev !rev_rows))
    end
    else begin
      let write_one fun_name u =
        Printf.fprintf ch "%s distances:\n" fun_name;
        Mokaphy_common.write_named_float_uptri ch namea u
      in
      List.iter2 write_one fun_namel ul
    end


(* The kr subcommand.
 *
 * pair_core is run for each pair of placeruns visited by Uptri.init, giving a
 * distance and, when -s is positive, a p-value for that pair. The per-pair
 * results are collected into uptris and printed with write_uptril.
 *)
class cmd () =
object (self)
  inherit subcommand () as super
  inherit mass_cmd () as super_mass
  inherit refpkg_cmd ~required:false as super_refpkg
  inherit output_cmd () as super_output
  inherit kr_cmd () as super_kr
  inherit normalization_cmd () as super_normalization
  inherit rng_cmd () as super_rng
  inherit placefile_cmd () as super_placefile

  (* Flags declared by this subcommand itself. p_exp, which is read below, is
   * not declared here; presumably it is inherited (from kr_cmd?) -- TODO
   * confirm. *)
  val list_output = flag "--list-out"
    (Plain (false, "Output the KR results as a list rather than a matrix."))
  val density = flag "--density"
    (Plain (false, "Make density plots showing the distribution of randomized \
        values with the calculated values"))
  val n_samples = flag "-s"
    (Formatted (0, "Set how many samples to use for significance calculation (0 means \
        calculate distance only). Default is %d."))
  val gaussian = flag "--gaussian"
    (Plain (false, "Use the Gaussian process approximation for p-value \
        estimation"))

  (* The option specifications: those of the inherited command objects
   * followed by the four flags declared above. *)
  method specl =
    super_mass#specl
    @ super_refpkg#specl
    @ super_output#specl
    @ super_kr#specl
    @ super_normalization#specl
    @ super_rng#specl
    @ [
      toggle_flag list_output;
      toggle_flag density;
      int_flag n_samples;
      toggle_flag gaussian;
    ]

  method desc =
"calculates the Kantorovich-Rubinstein distance and corresponding p-values"
  method usage = "usage: kr [options] placefiles"


  (* Build the result record for one pair of samples (name1, pre1) and
   * (name2, pre2) on the tree t.
   *
   * The distance is always calculated, from pre1 and pre2 as given. When
   * n_samples is positive, a list of null distances is also generated from the
   * unitized pres -- by Gaussian_approx.pair_approx if --gaussian is set, and
   * by map_shuffled_pres otherwise -- and the p-value is
   * list_onesided_pvalue of the original distance against that list. With
   * --density the null distances are additionally handed to
   * R_plots.write_density. When n_samples is not positive the p-value is None.
   *
   * Note that we don't call self#rng to avoid re-seeding the rng; the caller
   * passes the rng in instead. *)
  method private pair_core rng n_samples t name1 pre1 name2 pre2 =
  let p = fv p_exp
  and normalization = self#get_normalization t
  in
  let calc_dist =
    Kr_distance.scaled_dist_of_pres ~normalization p t in
  let original_dist = calc_dist pre1 pre2 in
  (* label passed on to R_plots.write_density *)
  let type_str = if fv gaussian then "gaussian" else "density"
  in
  {
    distance = original_dist;
    p_value =
      if 0 < n_samples then begin
        (* We must have unitized masses so shuffling works properly. Otherwise
         * the amount of mass per read will depend on its origin. *)
        let upre1 = Mass_map.Pre.unitize_mass pre1
        and upre2 = Mass_map.Pre.unitize_mass pre2
        in
        let null_dists =
          if fv gaussian then
            Gaussian_approx.pair_approx
              ~normalization rng n_samples p t upre1 upre2
          else
            map_shuffled_pres calc_dist rng n_samples upre1 upre2
        in
        if fv density then
          R_plots.write_density p type_str name1 name2 original_dist null_dists;
        Some (list_onesided_pvalue null_dists original_dist)
      end
      else None;
  }

  (* Run the comparison for the placeruns in prl and write the results to the
   * output channel.
   *
   * One uptri of results is made per (tree, pre-making function) pair: always
   * one for the tree obtained from Mokaphy_common.list_get_same_tree, plus two
   * taxonomic ones (Tax_gtree.unit_bl and Tax_gtree.inverse) when a reference
   * package is available and Refpkg.tax_equipped holds for it. When n_samples
   * is positive, each uptri contributes a distance column and a "_p_value"
   * column; otherwise only distances are written. *)
  method private nontrivial_placefile_action prl =
    let n_samples = fv n_samples
    and pra = Array.of_list prl
    and p = fv p_exp
    and weighting, criterion = self#mass_opts
    (* the reference package, kept only if it is taxonomically equipped *)
    and tax_refpkgo = match self#get_rpo with
      | None -> None
      | Some rp ->
        if Refpkg.tax_equipped rp then Some rp
        else None
    and ch = self#out_channel
    and rng = self#rng
    in
    (* in the next section, pre_f is a function which takes a pr and makes a pre,
     * and t is a gtree *)
    let uptri_of_t_pre_f (t, pre_f) =
      let prea = Array.map pre_f pra
      and namea = Array.map Placerun.get_name pra
      in
      Uptri.init
        (Array.length prea)
        (fun i j ->
          (* used to say which pair was being compared if pair_core fails *)
          let context = Printf.sprintf "comparing %s with %s" namea.(i) namea.(j) in
          try
            self#pair_core rng n_samples t namea.(i) prea.(i) namea.(j) prea.(j)
          with
          | Kr_distance.Invalid_place_loc a ->
              invalid_arg
              (Printf.sprintf
                 "%g is not a valid placement location when %s" a context)
          | Kr_distance.Total_kr_not_zero tkr ->
              failwith
                 ("total kr_vect not zero for "^context^": "^
                    (string_of_float tkr)))
    (* here we make one of these pairs from a function which tells us how to
     * assign a branch length to a tax rank *)
    and t_pre_f_of_bl_of_rank rp bl_of_rank =
      let (taxt, ti_imap) = Tax_gtree.of_refpkg_gen bl_of_rank rp in
      (Decor_gtree.to_newick_gtree taxt,
      Mokaphy_common.make_tax_pre taxt weighting criterion ti_imap)
    and gt = Mokaphy_common.list_get_same_tree prl |> Newick_gtree.add_zero_root_bl in
    (* here we make a list of uptris, which are to get printed *)
    let uptris =
      List.map
        uptri_of_t_pre_f
        ([gt, Mass_map.Pre.of_placerun weighting criterion] @
        (match tax_refpkgo with
        | None -> []
        | Some rp ->
            List.map (t_pre_f_of_bl_of_rank rp)
                     [Tax_gtree.unit_bl; Tax_gtree.inverse]))
    (* here are a list of function names to go with those uptris; each is a
     * prefix followed by the value of p *)
    and fun_names =
      List.map
        (fun s -> Printf.sprintf "%s%g" s p)
        (["Z_"] @
        (match tax_refpkgo with
        | Some _ -> ["unit_tax_Z_"; "inv_tax_Z_"]
        | None -> []))
    (* the names of the placeruns *)
    and names = Array.map Placerun.get_name pra
    and print_pvalues = n_samples > 0
    (* map f over l and concatenate the resulting lists, so that the items
     * made from one element stay next to each other *)
    and neighborly f l = List.flatten (List.map f l)
    in
    write_uptril
      (fv list_output)
      names
      (if print_pvalues then neighborly (fun s -> [s;s^"_p_value"]) fun_names
      else fun_names)
      (if print_pvalues then
        neighborly (fun u -> [Uptri.map get_distance u; Uptri.map get_p_value u]) uptris
      else (List.map (Uptri.map get_distance) uptris))
      ch

  (* Fail when fewer than two placeruns were given, as there is then no pair
   * to compare; otherwise hand over to nontrivial_placefile_action. *)
  method private placefile_action prl =
    match List.length prl with
      | (0 | 1) as n ->
        Printf.sprintf "kr requires two or more placefiles (%d given)" n
        |> failwith

      | _ -> self#nontrivial_placefile_action prl

end
