(* Time-stamp: <modified the 28/07/2022 (at 10:49) by Erwan Jahier> *)

(* Author: Antony Zahran *)

open Printf
open Types

(* Global tables filled in while the YAML document is walked. *)

(* Variables, in the order in which save_variables appends them. *)
let variables = ref []

(* Channel numbers, in the order in which save_channels appends them. *)
let channels = ref []

(* Description of the main node.  It starts empty; save_main overwrites its
   fields one key at a time. *)
let main =
  { name = "";
    ctx_type = "";
    ctx_new = "";
    var_in = [];
    var_out = [];
    ch_in = [];
    ch_out = [] }

(* Nodes, in the order in which save_nodes appends them. *)
let nodes = ref []

(* Instances, in the order in which save_instances appends them. *)
let instances = ref []


(** saves the data from the Yaml parser into references of type Types.t *)

(* Converts a list of `Float YAML values into an int list.
   Each float is truncated with int_of_float.  Any element that is not a
   `Float trips assert false. *)
let (intlist : Yaml.value list -> int list) =
  fun values ->
  List.map
    (function
      | `Float f -> int_of_float f
      | _ -> assert false)
    values

(* Stores one key/value pair of a YAML variable entry into the given
   variable record.
   - id   (`Float):  accepted but ignored
   - name (`String): written to the name field
   - type (`String): written to the var_type field
   Any other key, or a value of another kind, trips assert false. *)
let save_variable_attributes (var : Types.variable) :
  string * Yaml.value -> unit =
  function
  | "id", `Float _ -> ()
  | "name", `String str -> var.name <- str
  | "type", `String str -> var.var_type <- str
  | _ -> assert false

(* Reads one YAML variable entry (an `O object) and appends the resulting
   record at the end of the variables table.  Anything other than an `O
   value trips assert false.
   NOTE(review): the id key is ignored and other functions fetch variables
   by list position, so entries are presumably expected in id order,
   starting at 0 -- confirm against the YAML producer. *)
let save_variables (x : Yaml.value) : unit =
  let (var : Types.variable) = {name = ""; var_type = ""} in
  match x with
  | `O fields ->
    List.iter (save_variable_attributes var) fields;
    variables := !variables @ [var]
  | _ -> assert false


(* Appends one channel number (a `Float, truncated to int) at the end of the
   channels table.  Any other kind of YAML value trips assert false. *)
let save_channels : Yaml.value -> unit = function
  | `Float f ->
    let channel = int_of_float f in
    channels := !channels @ [channel]
  | _ -> assert false


(* Stores one key/value pair of the YAML main section into the main record.
   String keys (name, ctx_type, ctx_new) are copied as is.  var_in and
   var_out hold indices that are resolved to entries of the variables table
   by position.  ch_in and ch_out are kept as lists of channel numbers.
   Any other key, or a value of an unexpected kind, trips assert false.
   List.nth raises if an index is outside the variables table. *)
let save_main : string * Yaml.value -> unit =
  (* Resolves a YAML list of indices to the variables at those positions. *)
  let vars_at indices = List.map (List.nth !variables) (intlist indices) in
  function
  | "name", `String str -> main.name <- str
  | "ctx_type", `String str -> main.ctx_type <- str
  | "ctx_new", `String str -> main.ctx_new <- str
  | "var_in", `A indices -> main.var_in <- vars_at indices
  | "var_out", `A indices -> main.var_out <- vars_at indices
  | "ch_in", `A numbers -> main.ch_in <- intlist numbers
  | "ch_out", `A numbers -> main.ch_out <- intlist numbers
  | _ -> assert false


(* Stores one key/value pair of a YAML node entry into the given node record.
   - id        (`Float):  accepted but ignored
   - file_name (`String): written to file_name
   - fct_name  (`String): written to fct_name
   - ctx       (`Bool):   written to ctx
   - ctx_tab   (`String): written to ctx_tab
   Any other key, or a value of another kind, trips assert false. *)
let save_node_attributes (nd : Types.node) : string * Yaml.value -> unit =
  function
  | "id", `Float _ -> ()
  | "file_name", `String str -> nd.file_name <- str
  | "fct_name", `String str -> nd.fct_name <- str
  | "ctx", `Bool flag -> nd.ctx <- flag
  | "ctx_tab", `String str -> nd.ctx_tab <- str
  | _ -> assert false

(* Reads one YAML node entry (an `O object) and appends the resulting record
   at the end of the nodes table.  Anything other than an `O value trips
   assert false. *)
let save_nodes (x : Yaml.value) : unit =
  let (nd : Types.node) =
    {file_name = ""; fct_name = ""; ctx = false; ctx_tab = ""}
  in
  match x with
  | `O fields ->
    List.iter (save_node_attributes nd) fields;
    nodes := !nodes @ [nd]
  | _ -> assert false


(* Stores one key/value pair of a YAML instance entry into the given
   instance record.
   - id      (`Float): truncated to int and written to id
   - node    (`Float): index into the nodes table, resolved by position
   - var_in, var_out (`A): indices resolved by position in the variables
     table
   - ch_in, ch_out   (`A): kept as lists of channel numbers
   Any other key, or a value of an unexpected kind, trips assert false.
   List.nth raises if an index is outside the table it refers to. *)
let save_instance_attributes (inst : Types.instance) :
  string * Yaml.value -> unit =
  (* Resolves a YAML list of indices to the variables at those positions. *)
  let vars_at indices = List.map (List.nth !variables) (intlist indices) in
  function
  | "id", `Float f -> inst.id <- int_of_float f
  | "node", `Float f -> inst.node <- List.nth !nodes (int_of_float f)
  | "var_in", `A indices -> inst.var_in <- vars_at indices
  | "var_out", `A indices -> inst.var_out <- vars_at indices
  | "ch_in", `A numbers -> inst.ch_in <- intlist numbers
  | "ch_out", `A numbers -> inst.ch_out <- intlist numbers
  | _ -> assert false

(* Reads one YAML instance entry (an `O object) and appends the resulting
   record at the end of the instances table.  Anything other than an `O
   value trips assert false.
   The record starts with the first entry of the nodes table as a
   placeholder node, so the nodes table must already be non-empty when this
   runs: List.nth raises otherwise. *)
let save_instances (x : Yaml.value) : unit =
  let (inst : Types.instance) =
    { id = 0;
      node = List.nth !nodes 0;
      var_in = [];
      var_out = [];
      ch_in = [];
      ch_out = [] }
  in
  match x with
  | `O fields ->
    List.iter (save_instance_attributes inst) fields;
    instances := !instances @ [inst]
  | _ -> assert false


(* Dispatches one top-level section of the YAML document to the function
   that stores it.  Sections are processed in the order they are met, and
   some of them look entries up in tables filled by earlier ones (instances
   read the nodes and variables tables, main reads the variables table).
   Unknown keys, and known keys holding a value of an unexpected kind, are
   silently skipped. *)
let save_data2 : string * Yaml.value -> unit = function
  | "variables", `A entries -> List.iter save_variables entries
  | "channels", `A entries -> List.iter save_channels entries
  | "main", `O fields -> List.iter save_main fields
  | "nodes", `A entries -> List.iter save_nodes entries
  | "instances", `A entries -> List.iter save_instances entries
  | _ -> ()

(* Stores a whole YAML document into the global tables.  The document root
   must be an `O object; each of its sections is handed to save_data2.
   Any other root value trips assert false. *)
let save_data : Yaml.value -> unit = function
  | `O sections -> List.iter save_data2 sections
  | _ -> assert false



(* Path of the YAML input file; set from the anonymous command-line
   argument and empty until then. *)
let yaml_file : string ref = ref ""


(* Program entry point.

   Parses the command line (the anonymous argument is stored in yaml_file),
   loads that YAML file, fills the global tables through save_data, and then
   writes the generated C program to the file whose name is main.name
   followed by _pthread.c.

   Exits with status 2 after printing the usage message when no argument is
   given, or after printing the error when Arg.parse raises.  Exceptions
   raised by Yaml_unix.of_file_exn, by save_data or by the output functions
   are not caught here.

   Inside this function, main still denotes the record defined at the top of
   the file: the binding is not recursive, so the function only shadows the
   record for the code that follows it. *)
let main () =
  if Array.length Sys.argv <= 1 then (
    Arg.usage MainArgs.speclist MainArgs.usage; flush stdout; exit 2
  ) else (
    try Arg.parse MainArgs.speclist (fun s -> yaml_file := s) MainArgs.usage
    with
    | Failure(e) -> print_string e; flush stdout; flush stderr; exit 2
    | e -> print_string (Printexc.to_string e);  flush stdout; exit 2
  );
  let yaml = Yaml_unix.of_file_exn Fpath.(v !yaml_file) in
  save_data yaml;

  (* Names of a list of variables, in order. *)
  let var_names vars = List.map (fun (x : variable) -> x.name) vars in

  (* creates the .c file *)
  let cfile = open_out (main.name ^ "_pthread.c") in

  (* includes *)
  List.iter
    (fun header -> fprintf cfile "#include <%s>\n" header)
    [ "stdio.h"; "stdlib.h"; "string.h"; "stdbool.h";
      "pthread.h"; "semaphore.h"; "errno.h" ];
  fprintf cfile "#include \"%s.h\"\n" main.name;
  fprintf cfile "#include \"%s_loop_io.h\"\n" main.name;
  fprintf cfile "\n";

  (* semaphores macro *)
  fprintf cfile "/* Initialize the semaphore with the given count.  */\n";
  fprintf cfile "#define SEM_INIT(sem, v, max_v) sem_init(&(sem), 0, (v))\n";
  fprintf cfile "/* wait for the semaphore to be active (i.e. equal to 1) and then decrement it by one. */\n";
  fprintf cfile "#define SEM_WAIT(sem) while (sem_wait(&sem) != 0 && errno == EINTR) continue\n";
  fprintf cfile "/* make the semaphore active (i.e. equal to 1) */\n";
  fprintf cfile "#define SEM_SIGNAL(sem) sem_post(&(sem))\n";
  fprintf cfile "\n";

  (* variables declaration *)
  fprintf cfile "/* Declare variables */\n";
  List.iter
    (fun (x : variable) -> fprintf cfile "%s %s;\n" x.var_type x.name)
    !variables;
  fprintf cfile "\n";

  (* semaphores declaration *)
  fprintf cfile "/* Declare semaphores */\n";
  List.iter (fun x -> fprintf cfile "sem_t channel%i;\n" x) !channels;
  fprintf cfile "\n";

  (* ctx declaration *)
  fprintf cfile "/* Declare context */\n";
  fprintf cfile "%s* ctx;\n" main.ctx_type;
  fprintf cfile "\n";

  (* instance loops *)
  (* NOTE(review): the generated loop functions are declared as
     void loop_X() while pthread_create expects a start routine taking and
     returning void*; likewise the generated entry point is void main().
     Both are left as they were -- confirm whether the generated signatures
     should be changed. *)
  fprintf cfile "/* Instance loops */\n";
  let print_instance_loop (inst : Types.instance) =
    fprintf cfile "void loop_%s%i() {\n" inst.node.file_name inst.id;
    fprintf cfile "    while(true) {\n";

    List.iter (fprintf cfile "        SEM_WAIT(channel%i);\n") inst.ch_in;
    (* Inputs are passed by value, outputs by address, then the context slot
       when the node has one.  The list is joined afterwards so that an
       instance without inputs still produces a well-formed call instead of
       failing on the head of an empty list. *)
    let args =
      var_names inst.var_in
      @ List.map (fun name -> "&" ^ name) (var_names inst.var_out)
      @ (if inst.node.ctx
         then [sprintf "&ctx->%s[%i]" inst.node.ctx_tab inst.id]
         else [])
    in
    fprintf cfile "        %s(%s);\n"
      inst.node.fct_name (String.concat ", " args);
    List.iter (fprintf cfile "        SEM_SIGNAL(channel%i);\n") inst.ch_out;

    fprintf cfile "    }\n";
    fprintf cfile "}\n";
    fprintf cfile "\n"
  in
  List.iter print_instance_loop !instances;

  (* main function *)
  fprintf cfile "/* Main function */\n";
  fprintf cfile "void main() {\n  int _s = 0;\n";

  fprintf cfile "    /* Initialize context */\n";
  fprintf cfile "    ctx = %s(NULL);\n" main.ctx_new;
  fprintf cfile "\n";

  fprintf cfile "    /* Initialize semaphores */\n";
  List.iter (fprintf cfile "    SEM_INIT(channel%i, 0, 1);\n") !channels;
  fprintf cfile "    \n";
  fprintf cfile "    /* Declare pthreads */\n";
  List.iter
    (fun (x : Types.instance) ->
       fprintf cfile "    pthread_t pt_%s%i;\n" x.node.file_name x.id)
    !instances;
  fprintf cfile "    \n";
  fprintf cfile "    /* Initialize pthreads */\n";
  List.iter
    (fun (x : Types.instance) ->
       fprintf cfile "    pthread_create(&pt_%s%i, NULL, loop_%s%i, NULL);\n"
         x.node.file_name x.id x.node.file_name x.id)
    !instances;
  fprintf cfile "    \n";

  (* main loop *)
  fprintf cfile "print_rif_declaration();\n";
  output_string cfile "    /* Main loop */\n";
  output_string cfile " while(true) {\n";
  (* output_string does no format processing, so the C conversion is written
     as a plain %d; a backslash in front of it would end up in the generated
     C string literal as an invalid escape sequence. *)
  output_string cfile "        if (ISATTY) printf(\"#step %d \\n\", _s+1);\n";
  output_string cfile "        else if(_s) printf(\"\\n\");\n";
  output_string cfile "        fflush(stdout);\n";
  output_string cfile "        ++_s;\n";
  fprintf cfile "        get_inputs(ctx";
  List.iter (fun (x : variable) -> fprintf cfile ", &%s" x.name) main.var_in;
  fprintf cfile ");\n";

  List.iter (fprintf cfile "        SEM_SIGNAL(channel%i);\n") main.ch_in;
  List.iter (fprintf cfile "        SEM_WAIT(channel%i);\n") main.ch_out;

  (* Joined rather than head/tail so that an empty output list does not
     raise. *)
  fprintf cfile "        print_outputs(%s);\n"
    (String.concat ", " (var_names main.var_out));

  fprintf cfile "    }\n";
  fprintf cfile "    \n";

  (* pthread_join takes the thread handle by value, not its address. *)
  List.iter
    (fun (x : Types.instance) ->
       fprintf cfile "    pthread_join(pt_%s%i, NULL);\n" x.node.file_name x.id)
    !instances;
  fprintf cfile "}\n";
  fprintf cfile "\n";

  (* Flush and release the output file explicitly. *)
  close_out cfile
;;

(* Run the generator when the program starts. *)
let () = main ()