(*   
   Conversion from Abstract_syntax to Typed_syntax.

   Should catch all unknown identifiers, syntax, and typing errors.
   Side-effects in expressions are evaluated left to right.

   Copyright (C) 2011 Antoine Miné
*)

open Datatypes
open Abstract_syntax
open Typed_syntax


(************************************************************************)
(* UTILITIES *)
(************************************************************************)


(* [error x fmt args...] formats a message with the Printf-style format
   [fmt] and aborts the translation by raising
   [Failure "<extent>: error: <message>"], where <extent> is
   [string_of_extent x].
   It never returns, so its result type is fully polymorphic and it can be
   used in any expression position (including as a unit statement). *)
let error x fmt =
  let abort msg =
    let where = string_of_extent x in
    failwith (Printf.sprintf "%s: error: %s" where msg)
  in
  Printf.ksprintf abort fmt


(* relative size of int / float *)

(* Rank of an integer kind: char < short < int < long, and A_INTEGER ranks
   above every other kind. In this file ranks are only compared with each
   other, so only their relative order matters. *)
let int_rank kind =
  match kind with
  | A_CHAR -> 1
  | A_SHORT -> 2
  | A_INT -> 3
  | A_LONG -> 4
  | A_INTEGER -> 99

(* Rank of a float kind: float < double, and A_REAL ranks above both.
   Only the relative order of ranks is used. *)
let float_rank kind =
  match kind with
  | A_FLOAT -> 1
  | A_DOUBLE -> 2
  | A_REAL -> 99


(* Create a variable record with a fresh id from [new_id].
   User variables keep [name] as is. Synthetic (translator-generated)
   variables get a name built from [name] and the fresh id by
   [string_of_id], so that several temporaries created from the same base
   name can be told apart.
   [x] is the declaration extent, [typ] the type, [scope] the scope. *)
let new_var name synthetic x typ scope =
  let var_id = new_id () in
  let var_name = if synthetic then string_of_id name var_id else name in
  { var_name;
    var_id;
    var_extent = x;
    var_scope = scope;
    var_typ = typ;
    var_synthetic = synthetic;
  }


(************************************************************************)
(* ENVIRONMENT *)
(************************************************************************)


(* Translation environment, threaded through the translation of
   expressions, statements and declarations. *)
type env = {
    (* note: synthetic variables are not in env_locals / env_globals *)
    (* variables visible by name; get_var searches env_locals before
       env_globals, so a local shadows a global with the same name *)
    env_locals: var StringMap.t;
    (* variables declared with a scope other than T_LOCAL (globals, inputs,
       volatiles), by name *)
    env_globals: var StringMap.t;

    (* functions defined so far, by name *)
    env_funcs: func StringMap.t;
    (* declared variables, by id; in this file only add_var registers
       variables here (function arguments and synthetic variables are not
       added) *)
    env_vars: var IdMap.t;

    (* where to store the value of a "return" statement;
       None at top level and in functions without a return type *) 
    env_return: var option;
  }

(* Environment at the start of a program: no variable or function declared
   yet, and not inside a function, so "return e" is rejected. *)
let empty_env =
  { env_return = None;
    env_vars = IdMap.empty;
    env_funcs = StringMap.empty;
    env_globals = StringMap.empty;
    env_locals = StringMap.empty;
  }


(* resolve identifiers *)

(* Resolve variable name [v]: locals are searched first, then globals.
   Reports "unknown identifier" at extent [x] if [v] is in neither. *)
let get_var env v x =
  if StringMap.mem v env.env_locals then StringMap.find v env.env_locals
  else if StringMap.mem v env.env_globals then StringMap.find v env.env_globals
  else error x "unknown identifier %s" v

(* Resolve function name [v] among the functions defined so far.
   Reports "unknown function" at extent [x] if it is not defined. *)
let get_func env v x =
  let found =
    try Some (StringMap.find v env.env_funcs) with Not_found -> None
  in
  match found with
  | Some f -> f
  | None -> error x "unknown function %s" v



(************************************************************************)
(* CASTS *)
(************************************************************************)


(* [cast ee t' x] converts the typed expression [ee] to type [t'] by
   wrapping it in an explicit A_cast node located at [x].
   Casting to the expression's own type returns [ee] unchanged; any other
   cast to or from bool is reported as an error at [x]. *)
let cast ((_, t, _) as ee) t' x =
  if t = t' then ee
  else if t = A_BOOL || t' = A_BOOL then
    error x "invalid cast, from %s to %s" (string_of_typ t) (string_of_typ t')
  else (T_unary (A_cast (t', x), ee), t', x)

(* Check that [ee] has boolean type and return it unchanged; otherwise
   report an error at the expression's own extent. *)
let as_bool ((_, t, x) as ee) =
  match t with
  | A_BOOL -> ee
  | _ -> error x "boolean expression expected, found %s" (string_of_typ t)

(* Check that [ee] is an integer expression and apply integer promotion:
   char and short (of either signedness) are cast to signed int, and other
   integer kinds are returned unchanged.
   Float expressions are rejected, like any non-integer type. The [%]
   operator (A_MODULO) relies on this to refuse float operands, since it is
   only defined on integers.
   Reports "integer expression expected" at the expression's extent. *)
let as_int ((_,t,x) as ee) = 
  match t with
  | A_int ((A_CHAR | A_SHORT), _) -> cast ee (A_int (A_INT, A_SIGNED)) x
  | A_int _ -> ee
  | _ -> error x "integer expression expected, found %s" (string_of_typ t)
        
(* Check that [ee] is numeric (integer or float).
   Integer expressions go through [as_int], which promotes char and short
   to signed int; float expressions are returned unchanged. Any other type
   is reported as an error at the expression's extent. *)
let as_num ((_, t, x) as ee) =
  match t with
  | A_int _ -> as_int ee
  | A_float _ -> ee
  | _ ->
      error x "integer or float expression expected, found %s" (string_of_typ t)
 
(* Cast the two operands of a binary operator to a common type and return
   [(ee1', ee2', t)], where [t] is that common type:
   - two booleans stay boolean;
   - two floats: the float kind with the higher rank;
   - a float and an integer: the float operand's type;
   - two integers: after promoting kinds narrower than int to signed int,
     the kind with the higher rank; at equal rank with different
     signedness the result is unsigned (signed for A_INTEGER).
   Any other combination is reported as an error at extent [x]. *)
let promote_compatible ((_, t1, x1) as ee1) ((_, t2, x2) as ee2) x =
  (* kinds narrower than int are promoted to signed int *)
  let promote (kind, sign) =
    if int_rank kind < int_rank A_INT then (A_INT, A_SIGNED) else (kind, sign)
  in
  let common =
    match t1, t2 with
    | A_BOOL, A_BOOL -> A_BOOL
    | A_float f1, A_float f2 ->
        if float_rank f1 > float_rank f2 then t1 else t2
    | A_float _, A_int _ -> t1
    | A_int _, A_float _ -> t2
    | A_int (i1, s1), A_int (i2, s2) ->
        let (k1, sg1) = promote (i1, s1)
        and (k2, sg2) = promote (i2, s2) in
        let r1 = int_rank k1
        and r2 = int_rank k2 in
        if r1 < r2 then A_int (k2, sg2)
        else if r1 > r2 then A_int (k1, sg1)
        else if sg1 = sg2 then A_int (k1, sg1)
        else if k1 = A_INTEGER then A_int (k1, A_SIGNED)
        else A_int (k1, A_UNSIGNED)
    | _ ->
        error x "incompatible operand types, %s and %s"
          (string_of_typ t1) (string_of_typ t2)
  in
  cast ee1 common x1, cast ee2 common x2, common



(************************************************************************)
(* TRANSLATION *)
(************************************************************************)


(* expressions *)
(* *********** *)

(* [pure_expr env pre post (e,x)] translates expression [e] (extent [x])
   into a typed expression free of side-effect, returned as
   [((expr, typ, extent), pre', post')].

   Side-effects are accumulated as statements:
   - effects that must run before the pure expression is evaluated are
     appended to [pre];
   - effects that must run after it are prepended to [post].
   Binary operands are translated left to right, threading [pre] / [post].
   [pre] may create temp variables (the copy of a function's return value),
   which are deleted in [post].
   Fails through [error] on unknown identifiers and typing errors.
 *)
let rec pure_expr env pre post (e,x) = 
  match e with
    
  | A_identifier s ->
      let v = get_var env s x in
      (T_var v, v.var_typ, x), pre, post
        
  | A_float_const s ->
      (* always double type *)
      (* TODO: return a sound interval enclosing the decimal literal *)
      let f = Float.of_string s in
      (T_float_const (f,f), A_float A_DOUBLE, x), pre, post
        
  | A_int_const s ->
      let i = Int.of_string s in
      (* try these types in order: int, long, integer *)
      let t = 
        if Int.fits_int32 i then A_int (A_INT, A_SIGNED) else
        if Int.fits_int64 i then A_int (A_LONG, A_SIGNED) else
        A_int (A_INTEGER, A_SIGNED)
      in
      (T_int_const (Finite i, Finite i), t, x), pre, post

  | A_float_itv ((s1,_),(s2,_)) ->
      (* interval literal: always double type, like float constants *)
      let f1, f2 = Float.of_string s1, Float.of_string s2 in
      (T_float_const (f1,f2), A_float A_DOUBLE, x), pre, post
        
  | A_int_itv ((s1,_),(s2,_)) ->
      let i1, i2 = Int.of_string s1, Int.of_string s2 in
      (* smallest of int, long, integer that holds both bounds *)
      let t = 
        if Int.fits_int32 i1 && Int.fits_int32 i2 
        then A_int (A_INT, A_SIGNED) 
        else if Int.fits_int64 i1 && Int.fits_int64 i2 
        then A_int (A_LONG, A_SIGNED)
        else A_int (A_INTEGER, A_SIGNED)
      in
      (T_int_const (Finite i1, Finite i2), t, x), pre, post

  | A_bool_const b ->
      (T_bool_const (tbool_of_bool b), A_BOOL, x), pre, post
        
  | A_unary (op, e1) ->
      let e1, pre, post = pure_expr env pre post e1 in
      (match op with
        
      | A_UNARY_PLUS ->
          (* no node is built: only the type check and promotion remain *)
          as_num e1, pre, post
            
      | A_UNARY_MINUS -> 
          let (_,t,_) as e1 = as_num e1 in 
          (T_unary (op, e1), t, x), pre, post
            
      | A_NOT ->
          (T_unary (op, as_bool e1), A_BOOL, x), pre, post
            
      | A_cast (t',x) ->
          cast e1 t' x, pre, post
      )
        
  | A_binary (op, e1, e2) ->
      (* e1 is translated first, so its pre-effects come first *)
      let e1, pre, post = pure_expr env pre post e1 in
      let e2, pre, post = pure_expr env pre post e2 in
      (match op with
        
      | A_PLUS | A_MINUS | A_MULTIPLY | A_DIVIDE ->
          let e1, e2, t = promote_compatible (as_num e1) (as_num e2) x in
          (T_binary (op, e1, e2), t, x), pre, post
            
      | A_MODULO ->
          let e1, e2, t = promote_compatible (as_int e1) (as_int e2) x in
          (T_binary (op, e1, e2), t, x), pre, post
            
      | A_LESS | A_LESS_EQUAL | A_GREATER | A_GREATER_EQUAL ->
          let e1, e2, _ = promote_compatible (as_num e1) (as_num e2) x in
          (T_binary (op, e1, e2), A_BOOL, x), pre, post
            
      | A_EQUAL | A_NOT_EQUAL ->
          (* any two compatible types, including two booleans *)
          let e1, e2, _ = promote_compatible e1 e2 x in
          (T_binary (op, e1, e2), A_BOOL, x), pre, post
            
      | A_AND | A_OR ->
          (* TODO: shortcut && and || *)
          (* for now, the side-effects of both operands are always run *)
          let e1, e2 = as_bool e1, as_bool e2 in
          (T_binary (op, e1, e2), A_BOOL, x), pre, post
      )
        
  | A_call ((s,sx), args) ->
      (* in an expression, the called function must return a value *)
      let ee, pre, post = call (s,sx) args env pre post x in
      (match ee with
      | None -> error x "function %s has no return value" s
      | Some ee -> ee, pre, post
      )
        
  | A_increment (l, i, A_PRE) ->
      (* ++x, --x => x+=1, x-=1 *)
      let op = 
        match i with A_INCR -> A_PLUS_ASSIGN | A_DECR -> A_MINUS_ASSIGN in
      let e = A_assign (l, Some op, (A_int_const "1", x)), x in
      pure_expr env pre post e
        
  | A_increment (l, i, A_POST) ->
      (* as ++x, --x, but appends effect after expression evaluation:
         the assignment statements produced for x+=1 / x-=1 are moved from
         pre to the front of post, so the returned expression (a read of
         the variable) sees the value before the update *)
      let op = 
        match i with A_INCR -> A_PLUS_ASSIGN | A_DECR -> A_MINUS_ASSIGN in
      let e = A_assign (l, Some op, (A_int_const "1", x)), x in
      let ee, pre1, post1 = pure_expr env [] [] e in
      ee, pre, pre1@post1@post
                             
  | A_assign ((s,sx), op, (e,ex)) ->
      (* optionally translate into v = v op e *)
      let ve = A_identifier s, sx in
      let e = match op with 
      | None -> e
      | Some A_PLUS_ASSIGN     -> A_binary (A_PLUS, ve, (e,ex))
      | Some A_MINUS_ASSIGN    -> A_binary (A_MINUS, ve, (e,ex))
      | Some A_MULTIPLY_ASSIGN -> A_binary (A_MULTIPLY, ve, (e,ex))
      | Some A_DIVIDE_ASSIGN   -> A_binary (A_DIVIDE, ve, (e,ex))
      | Some A_MODULO_ASSIGN   -> A_binary (A_MODULO, ve, (e,ex))
      in
      let ee, pre1, post1 = pure_expr env [] [] (e,ex) in
      (* inputs and volatiles are read-only *)
      let v = get_var env s x in
      if v.var_scope = T_INPUT then 
        error x "%s is an input and cannot be assigned" s;
      if v.var_scope = T_VOLATILE then 
        error x "%s is a volatile and cannot be assigned" s;
      (* cast back to the type of v *)
      let ee = cast ee v.var_typ x in
      (* pre = assign variable, expr = variable;
         the right-hand side's own post-effects (post1) run right after the
         assignment, still inside pre *)
      (T_var v, v.var_typ, x), 
      pre@pre1@[T_assign ((v,sx), ee), x]@post1,
      post

(* [call (s,sx) args env pre post x] translates a call to function [s] with
   actual arguments [args] (call extent [x]).
   Each actual argument is translated, cast to the type of its formal
   argument and bound to it with T_add_var, left to right; the formal
   arguments are deleted right after the call.
   Returns [(ret, pre', post')], where [ret] is None for a function without
   return variable, otherwise Some expression reading a fresh temporary
   holding the returned value.
   Reports an error if the function is unknown or the arity differs. *)
and call (s,sx) args env pre post x =
  (* resolve identifier *)
  let f = get_func env s sx in
  (* translate & bind actual arguments *)
  (* the arity check also guarantees List.fold_left2 below cannot raise
     Invalid_argument *)
  if List.length args <> List.length f.func_args then 
    error x "the function expects %i arguments, got %i"  
      (List.length f.func_args) (List.length args);
  (* pre' binds the formals, in order; post' deletes them, in reverse
     order (the fold's (pre,post) accumulators shadow the outer ones) *)
  let pre',post' =
    List.fold_left2
      (fun (pre,post) e v ->
        let ee, pre', post' = pure_expr env [] [] e in
        let ee = cast ee v.var_typ x in
        pre@pre'@[T_add_var (v, Some ee), x]@post',
        [T_del_var v, x]@post
      )
      ([],[]) args f.func_args
  in
  (* handle return value *)
  match f.func_return with
  | None -> 
      (* function without return *)
      None, 
      pre@pre'@[T_call (f,sx), x]@post',
      post
  | Some v ->
      (* function with return value *)
      let v1 = new_var "$returned" true x v.var_typ T_LOCAL in
      (* note: all the formal argument and return variables are deleted
         just after the call;
         the return value is copied into a fresh temporary v1 to be used by
         the expression; v1 is only deleted in post. Thus, an expression
         can safely call the same function several times in sequence
         (e.g. f(1) + f(2)) without conflicts in formal argument and return
         variables.
         NOTE(review): this does not hold when a call to f appears inside
         an actual argument of another call to f (e.g. f(1, f(2,3))): the
         inner call's post' deletes the formal arguments already bound by
         the outer call before the outer T_call runs. Confirm whether this
         should be rejected or fixed by binding arguments via temporaries.
       *)
      Some (T_var v1, v.var_typ, x), 
      pre@pre'
      @[T_add_var (v1,None),x; 
        T_add_var (v,None),x;
        T_call (f,sx),x;
        T_assign ((v1,x), (T_var v, v.var_typ, x)),x;
        T_del_var v,x
      ]@post',
      [T_del_var v1, x]@post


(* variable declaration and initialisation *)
(* [add_var env (t,tx) (s,sx) i scope] declares a new variable [s] of type
   [t] with scope [scope] and optional initializer [i].
   The initializer is translated in [env], i.e. before [s] is visible, and
   cast to [t].
   Returns [(env', v, stats)]:
   - [env'] binds [s] in env_locals when [scope] is T_LOCAL and in
     env_globals otherwise, and registers [v] in env_vars;
   - [v] is the new variable;
   - [stats] is the T_add_var declaration surrounded by the initializer's
     side-effects. *)
and add_var env (t,tx) (s,sx) i scope =
  let v = new_var s false sx t scope in
  let i = match i with
  | None -> [T_add_var (v, None), sx]
  | Some (e,x) ->
      let ee, pre, post = pure_expr env [] [] (e,x) in
      let ee = cast ee v.var_typ x in
      pre@[T_add_var (v, Some ee), sx]@post
  in
  let env =
    { env with
      env_locals = 
      if scope = T_LOCAL then StringMap.add s v env.env_locals 
      else env.env_locals;
      env_globals = 
      if scope <> T_LOCAL then StringMap.add s v env.env_globals 
      else env.env_globals;
      env_vars = IdMap.add v.var_id v env.env_vars;
    }
  in
  env, v, i



(* statements *)
(* ********** *)

(* Build a typed block from a statement list, giving each statement a fresh
   unique label.

   [mk_block l locs x]:
   - [l] is the list of (statement, extent) pairs, in execution order;
     each statement is wrapped in a T_stat node whose label pairs a fresh
     id with the start position of the statement's extent;
   - [locs] are local variables to delete at the end of the block: one
     T_del_var per variable runs after the statements of [l], in the order
     of [locs], located at the end position of [x];
   - [x] is the extent of the block; its end position also locates the
     terminating T_empty node.
   The result is a chain of T_stat nodes ending with T_empty. *)
let mk_block l locs x =
  (* aux consumes its list from the head and wraps the accumulator, so the
     head of the list ends up as the LAST statement executed; the list
     passed below is therefore built in reverse execution order *)
  let rec aux l acc = match l with
  | [] -> acc
  | (s,x)::l -> aux l (T_stat ((new_id (), fst x), (s,x), acc))
  in
  (* deletions, already reversed by rev_map *)
  let del = List.rev_map (fun v -> T_del_var v, (snd x, snd x)) locs in
  aux (del@(List.rev l)) (T_empty (new_id(), snd x))



(* Translate a statement.

   [stat env (s,x)] returns [(env', stats, locs)] where:
   - [env'] is the environment after the statement. Local variables
     declared by A_local stay visible to the following statements of the
     enclosing list; compound statements (if, while, block) only propagate
     env_vars.
   - [stats] are the translated statements, unlabelled and in execution
     order.
   - [locs] are the local variables created by the statement itself (only
     A_local creates some), for easy destruction at the end of the
     enclosing block.

   Expressions are translated by [pure_expr] into a pure expression plus
   [pre] statements (run before it is evaluated) and [post] statements (run
   right after). Every path leaving the evaluation point must run [post]
   exactly once, and a condition evaluated several times (loops) must run
   [pre] before each evaluation.
*)
let rec stat env (e,x) = match e with

| A_SKIP -> env, [], []

| A_expr (A_call ((s,sx), args), x) ->
    (* unlike pure_expr, does not fail if there is no return value *)
    let _, pre, post = call (s,sx) args env [] [] x in
    env, pre@post, []
           
| A_expr (e1,x) ->
    let _, pre, post = pure_expr env [] [] (e1,x) in
    env, pre@post, []

| A_if (e1,s1,s2) ->
    let ee, pre, post = pure_expr env [] [] e1 in
    let ee = as_bool ee in
    (* the condition's post runs first in whichever branch is taken *)
    let env', u1, l1 = stat env s1 in
    let b1 = mk_block (post@u1) l1 (snd s1) in
    let env', b2 = match s2 with
    | None ->
        (* no else branch: the implicit else must still run post, otherwise
           the condition's effects (e.g. x++, or deleting a call's return
           temporary) would be skipped when the condition is false;
           with an empty post this is just a T_empty *)
        env', mk_block post [] x
    | Some s2 ->
        let env', u2, l2 = stat env' s2 in
        env', mk_block (post@u2) l2 (snd s2)
    in
    { env with env_vars = env'.env_vars; }, 
    pre@[T_if (ee, b1, b2), x], 
    []
    
| A_while (e1,s1) ->
    let ee, pre, post = pure_expr env [] [] e1 in
    let (_,_,xx) as ee = as_bool ee in
    let env', u1, l1 = stat env s1 in
    (* the condition is evaluated before every iteration:
       - pre runs before the first test, and again at the end of each
         iteration, before the next test;
       - post runs at the start of each iteration, and once after the loop
         exits (condition false) *)
    let b1 = mk_block (post@u1@pre) l1 (snd s1) in
    let lbl = new_id(), fst xx in
    { env with env_vars = env'.env_vars; }, 
    pre@[T_while (lbl, ee, b1), x]@post, 
    []

| A_return None ->
    env, [T_RETURN, x], []

| A_return (Some e1) ->
    let ee, pre, post = pure_expr env [] [] e1 in
    (match env.env_return with
    | None -> error x "function cannot return a value"
    | Some v ->
        let ee = cast ee v.var_typ x in
        (* post must run before T_RETURN, which leaves the function;
           the returned value was already stored by the assignment *)
        env, pre@[T_assign ((v,x), ee), x]@post@[T_RETURN, x], []
    )

| A_block l -> 
    (* locals declared in the block are deleted at its end and do not
       escape; only env_vars is propagated *)
    let env', r, locs' = stat_list env l in
    let del = List.map (fun v -> T_del_var v,x) locs' in
    { env with env_vars = env'.env_vars; }, r@del, []

| A_local (t,l) ->
    (* each declarator sees the variables declared before it *)
    let env, rstats, locs =
      List.fold_left 
        (fun (env,rstats,locs) (n,i) -> 
          let env, v, i = add_var env t n i T_LOCAL in
          env, List.rev_append i rstats, v::locs
        ) 
        (env,[],[]) l
    in
    env, List.rev rstats, locs

| A_assert e ->
    let ee, pre, post = pure_expr env [] [] e in
    let ee = as_bool ee in
    env, pre@[T_assert ee, x]@post, []

| A_print l ->
    let l = List.map (fun (s,x) -> get_var env s x, x) l in
    env, [T_print l, x], []

    
(* Translate a statement list in order, threading the environment.
   Returns the final environment, the concatenated translated statements,
   and the locals declared directly in the list (not in nested blocks). *)
and stat_list env l =
  let env, rstats, rlocs = 
    List.fold_left
      (fun (env,rstats,rlocs) s -> 
        let env, stats, locs = stat env s in
        env, List.rev_append stats rstats, List.rev_append locs rlocs
      )
      (env,[],[]) l
  in
  env, List.rev rstats, List.rev rlocs


(* declarations *)
(* ************ *)

(* [decl env d] translates one top-level declaration and returns
   [(env', stats, funcs)]:
   - [env']: [env] extended with the declared globals or function;
   - [stats]: initialisation statements in execution order (globals only);
   - [funcs]: the function defined by [d], if any (at most one).
*)
let decl env d = 
  match d with
    
  | A_global (((t,l),x),kind) ->
      (* the declaration kind selects the scope of every declared variable *)
      let scope = match kind with 
      | A_VARIABLE -> T_GLOBAL
      | A_INPUT -> T_INPUT
      | A_VOLATILE -> T_VOLATILE
      in
      (* declare each (name, initializer) pair in order; the environment is
         threaded, so later initializers see the earlier variables;
         statements are accumulated reversed, then put back in order *)
      let env, rstats =
        List.fold_left 
          (fun (env,rstats) (n,i) -> 
            let env, v, i = add_var env t n i scope in
            env, List.rev_append i rstats
          ) 
          (env,[]) l
      in
      env, List.rev rstats, []
        
  | A_function ((r,(s,sx),args,body),x) ->
      (* synthetic variable receiving the returned value, if any *)
      let ret = match r with
      | None -> None
      | Some (t,_) -> Some (new_var "$return" true x t T_LOCAL)
      in
      (* the body is translated in env, which does not contain f itself:
         a recursive call is reported as an unknown function *)
      let env_body = { env with env_return = ret; } in
      (* formal arguments: fresh T_LOCAL variables, visible by name in the
         body; folding over the reversed list rebuilds [args] in
         declaration order.
         NOTE(review): unlike variables declared through add_var, the
         arguments and the return variable are not added to env_vars, and
         duplicate argument names are not detected; confirm intended. *)
      let args, env_body =
        List.fold_left
          (fun (args,env) ((s,sx),(t,tx)) ->
            let v = new_var s false sx t T_LOCAL in
            v::args,
            { env with env_locals = StringMap.add s v env.env_locals; }
          )
          ([],env_body) (List.rev args)
      in
      (* locals declared at the top level of the body are deleted at the
         end of the function block *)
      let env', r, locs = stat_list env_body body in
      let b = mk_block r locs x in
      let f =
        { func_name = s;
          func_extent = sx;
          func_id = new_id ();
          func_return = ret;
          func_args = args;
          func_body = b;
        }
      in
      (* only env_vars is kept from the body's environment: the body's
         locals, the arguments and env_return do not escape *)
      { env with 
        env_funcs = StringMap.add s f env.env_funcs; 
        env_vars = env'.env_vars; },
      [],
      [f]



(************************************************************************)
(* ENTRY POINTS *)
(************************************************************************)

        
(* Translation entry point.
   [ps] is a list of declaration lists with their extents (presumably one
   per parsed source file), translated in order with a single environment,
   so a declaration can refer to anything declared earlier.
   Returns [(init, funcs, vars)]:
   - [init]: a block running the initialisation statements of all global
     declarations, in order, located with the extent of the first element
     of [ps];
   - [funcs]: the defined functions, by name;
   - [vars]: env_vars of the final environment.
   NOTE(review): [ps] must be non-empty (List.hd raises Failure "hd").
   When two functions share a name, [funcs] keeps the first definition,
   while calls resolve to the latest definition preceding them; confirm. *)
let translate_program (ps:decl list ext list) : prog =
  let first_extent = snd (List.hd ps) in
  (* translate one declaration; statements and functions are accumulated
     in reverse order *)
  let translate_decl (env, rstats, rfuncs) d =
    let env, stats, funcs = decl env d in
    env, List.rev_append stats rstats, List.rev_append funcs rfuncs
  in
  let env, rstats, rfuncs =
    List.fold_left
      (fun acc (decls, _) -> List.fold_left translate_decl acc decls)
      (empty_env, [], []) ps
  in
  let init = mk_block (List.rev rstats) [] first_extent in
  let funcs =
    List.fold_left
      (fun map f -> StringMap.add f.func_name f map)
      StringMap.empty rfuncs
  in
  (init, funcs, env.env_vars)
