solver.ml 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
(*-----------------------------------------------------------------------
** Copyright (C) 2001 - Verimag.
** This file may only be copied under the terms of the GNU Library General
** Public License 
**-----------------------------------------------------------------------
**
** File: solver.ml
** Main author: jahier@imag.fr
*)

open Formula
12
open Env_state
13

14 15
(****************************************************************************)
(****************************************************************************)
16 17 18



19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
(****************************************************************************)

type sol_nb = (float * int)
(** This type is used to represent the number of solutions of a
   bdd. The idea is to approximate this number of solutions by
   [a.2**n] where [a] is either zero or a float in the interval [[1, 2[],
   and [n] is a non-negative integer. When [a] is zero the value stands
   for "no solution" and [n] is irrelevant. The reason for this
   logarithmic encoding is that the number of solutions of a formula is
   exponential in the number of variables, so storing it directly is
   likely to trigger overflow errors.

   In order to compute the [sol_nb] of a bdd from the ones of its
   sub-branches, we take advantage of the following equality, where
   [p >= 0] (i.e. the first term is the one with the larger exponent):

     (1) [a.2**(n+p) + b.2**n = 2**(-k).(a+b.2**(-p)).2**(n+p+k)]

   where [k] is the smallest integer such that:
     [2**(-k).(a+b.2**(-p)) < 2].
*)

let (add_sol_nb: float * int -> float * int -> float * int) =
  fun (a, n) (b, m) ->
    (* Adds two [sol_nb] (i.e. this is [sol_nb -> sol_nb -> sol_nb]) using
       formula (1): the operand with the larger exponent plays the role of
       [a.2**(n+p)], the other one the role of [b.2**n], so that [p >= 0].
       A zero mantissa means "no solution" and is neutral for the sum.
    *)
    if a = 0. then (b, m)
    else if b = 0. then (a, n)
    else begin
      assert ((a >= 1.) && (a < 2.) && (b >= 1.) && (b < 2.));
      (* Formula (1) requires the first term to carry the larger exponent. *)
      let ((hi_cst, hi_exp), (lo_cst, lo_exp)) =
        if n >= m then ((a, n), (b, m)) else ((b, m), (a, n))
      in
      (* [ldexp] scales by a power of two exactly (and just underflows to
         0. when the exponents are very far apart); [temp] is in [[1, 4[]. *)
      let temp = hi_cst +. ldexp lo_cst (lo_exp - hi_exp) in
      (* [frexp temp = (x, e)] with [temp = x.2**e] and [x] in [[0.5, 1[],
         hence [temp = (2x).2**(e-1)] with [2x] in [[1, 2[]. This is an
         exact renormalisation, unlike rounding [log temp /. log 2.]. *)
      let (x, e) = frexp temp in
      (2. *. x, hi_exp + e - 1)
    end

let () = assert ((add_sol_nb (1., 0) (1., 0)) = (1., 1))
let () = assert ((add_sol_nb (1., 1) (1., 1)) = (1., 2))
let () = assert ((add_sol_nb (1., 5) (1., 5)) = (1., 6))

(* 2^2 + 2^3 = 12 = 1.5 * 2^3, whatever the order of the operands *)
let () = assert ((add_sol_nb (1., 2) (1., 3)) = (1.5, 3))
let () = assert ((add_sol_nb (1., 3) (1., 2)) = (1.5, 3))

let () = assert ((add_sol_nb (1., 0) (0., 1)) = (1., 0))
let () = assert ((add_sol_nb (1.453, 45) (0., 1)) = (1.453, 45))

(****************************************************************************)
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110

(** A bdd variable, represented by its integer index. *)
type var = int

(** In what follows, a "comb" is the bdd of a conjunction of
    literals (vars). Combs give the order in which literals appear
    in the bdds we manipulate.
*)


let rec (split_comb: Bdd.t -> var -> var list * var list) =
  fun comb v ->
    (* Splits the variables of [comb] around [v]: the first list holds
       the variables appearing before [v] in [comb], the second one the
       variables appearing after it. [v] must belong to the support of
       [comb]; this is checked by an assertion at every step.
    *)
    assert (List.mem v (Bdd.int_of_support comb));
    let head = Bdd.topvar comb in
    let rest = Bdd.dthen comb in
    if head <> v then
      let (before, after) = split_comb rest v in
      (head :: before, after)
    else
      ([], Bdd.int_of_support rest)
	  

let (get_vars_between: var -> var -> Bdd.t -> var list) =
  fun v1 v2 support ->
    (* Returns the variables of the comb [support] lying strictly between
       [v1] and [v2]; [v1] ought to come strictly before [v2].
       NOTE(review): the ordering assertion maps [v1]/[v2] through
       [Dd.var_of_level] although they are variable indices -- confirm
       that this is the intended conversion.
    *)
    assert ((Dd.var_of_level v1) < (Dd.var_of_level v2));
    let after_v1 = snd (split_comb support v1) in
    let before_v2 = fst (split_comb support v2) in
    Util.list_intersec after_v1 before_v2

let (get_remaining_vars: var -> Bdd.t -> var list) =
  fun v support ->
    (* Returns the variables that come after [v] in the comb [support]. *)
    snd (split_comb support v)
(****************************************************************************)

(*
** Maps a bdd to the number of solutions of its lhs (then) branch and
** of its rhs (else) branch, in that order.
*)
type sol_nb_table = (Bdd.t, (sol_nb * sol_nb)) Hashtbl.t

(* The solution number of a bdd with no solution: a zero mantissa (whose exponent add_sol_nb ignores). *)
let zero_sol : float * int = (0., 1)

121 122 123


(* XXX Redo this by walking the comb in parallel, as is done for the
   draw. *)
let rec (build_sol_nb_table: Bdd.t -> sol_nb_table -> Bdd.t -> sol_nb * sol_nb) =
  fun bdd snt support ->
    (* Returns the number of solutions of [bdd] and the one of its
       negation, counted over the topvar of [bdd] and the variables of the
       comb [support] coming after it (a constant counts over no variable:
       true has 1 solution, false none).

       Also records in [snt], for [bdd], for its negation and recursively
       for all their sub-bdds, the solution numbers of the then and else
       branches, counted over the variables strictly after the topvar.

       Each node is processed only once: bdds share sub-bdds, and a plain
       recursion that re-traverses them can take time exponential in the
       number of variables.
    *)
    if Bdd.is_cst bdd then
      (if Bdd.is_true bdd then ((1.0, 0), zero_sol) else (zero_sol, (1.0, 0)))
    else if Hashtbl.mem snt bdd then
      (* Already visited (entries are always recorded for a bdd and its
         negation together): rebuild the totals from the table. *)
      let (sol_t, sol_e) = Hashtbl.find snt bdd in
      let (not_sol_t, not_sol_e) = Hashtbl.find snt (Bdd.dnot bdd) in
      (add_sol_nb sol_t sol_e, add_sol_nb not_sol_t not_sol_e)
    else
      let var = Bdd.topvar bdd in
      let bddt = Bdd.dthen bdd in
      let bdde = Bdd.delse bdd in
      let ((a, n), (not_a, not_n)) = build_sol_nb_table bddt snt support in
      let ((b, m), (not_b, not_m)) = build_sol_nb_table bdde snt support in
      (* [n] and [m] are relative to the topvar of each child. They must be
         multiplied by 2 to the power of the number of unconstrained
         variables between the topvar of [bdd] and the topvar of the child
         (or all remaining variables if the child is a constant). Ditto for
         the negation of [bdd]. *)
      let free_vars_before child =
        if Bdd.is_cst child then
          List.length (get_remaining_vars var support)
        else
          List.length (get_vars_between var (Bdd.topvar child) support)
      in
      let nt = free_vars_before bddt in
      let ne = free_vars_before bdde in
      let sol_t = (a, n + nt) in
      let sol_e = (b, m + ne) in
      let not_sol_t = (not_a, not_n + nt) in
      let not_sol_e = (not_b, not_m + ne) in
      Hashtbl.replace snt bdd (sol_t, sol_e);
      (* The then (resp. else) branch of [not bdd] is the negation of the
         then (resp. else) branch of [bdd]: same then/else order. *)
      Hashtbl.replace snt (Bdd.dnot bdd) (not_sol_t, not_sol_e);
      (add_sol_nb sol_t sol_e, add_sol_nb not_sol_t not_sol_e)
	  
(****************************************************************************)
	  
let (formula_list_to_conj: formula list -> formula) =
  fun fl ->
    (* Transforms a list of formulas into their conjunction, nested to the
       left: [[f1; f2; f3]] gives [And(And(f1, f2), f3)] and [[f]] gives
       [f]. The empty list gives [True], the neutral element of [And]
       (it used to be rejected with [assert false]).
       The nesting shape matters: the result is used as a key of
       [env_state.bdd_tbl].
    *)
    match fl with
      | [] -> True
      | f :: tail -> List.fold_left (fun acc g -> And (acc, g)) f tail

      
let rec (formula_to_bdd : formula -> Bdd.t) =
  fun f ->
    (* Translates the formula [f] into a bdd, memoised in the [bdd_tbl]
       field of [env_state] because the translation is very expensive.
       Each fresh translation also records the bdd of the negation of [f]
       (for [Not g], the bdd of [g]) so that it comes for free later.
       Numeric atoms ([Eq], [Ge], [G]) are not handled yet and fail with
       an assertion.
    *)
    (* XXX Should only toplevel formulas be stored in this table?
       ==> I should time profile both!
    *)
    try Hashtbl.find env_state.bdd_tbl f
    with Not_found ->
      let bdd =
        match f with
          | True -> Bdd.dtrue ()
          | False -> Bdd.dfalse ()
          | Bvar vn -> Bdd.var (Env_state.vn_to_index vn)
          | Not g -> Bdd.dnot (formula_to_bdd g)
          | And (g1, g2) -> Bdd.dand (formula_to_bdd g1) (formula_to_bdd g2)
          | Or (g1, g2) -> Bdd.dor (formula_to_bdd g1) (formula_to_bdd g2)
          | Eq _ | Ge _ | G _ -> assert false (* XXX FIX US !!! *)
      in
      let negation = match f with Not g -> g | _ -> Not f in
      Hashtbl.add env_state.bdd_tbl negation (Bdd.dnot bdd);
      Hashtbl.add env_state.bdd_tbl f bdd;
      bdd
    

(****************************************************************************)
(****************************************************************************)

let (toss_up_one_var: var -> var * bool) =
  fun var ->
    (* Pairs [var] with a fair coin flip: [true] iff [Random.float 1.]
       falls below 0.5. *)
    let heads = Random.float 1. < 0.5 in
    (var, heads)


let rec (draw_in_bdd: Bdd.t -> sol_nb_table -> Bdd.t -> (var * bool) list) =
  fun bdd snt support ->
    (* Returns a draw of every variable of the comb [support], from its
       topvar to its end (in the ordering of the support), as a list of
       (variable, value) pairs.
       - A variable that [bdd] tests is drawn with a probability
         proportional to the number of solutions of each branch, read from
         [snt] (as filled by [build_sol_nb_table]).
       - A variable that [bdd] does not constrain at this point is tossed
         up with [toss_up_one_var].
       [bdd] must not be the false constant (assertion failure). The true
       constant is accepted: all the variables are then tossed up.
    *)
    if Bdd.is_cst bdd then
      if Bdd.is_false bdd then
        assert false (* a branch with no solution should not have been drawn ! *)
      else
        (* [bdd] is true: every remaining variable is unconstrained. *)
        List.map toss_up_one_var (Bdd.int_of_support support)
    else
      let bddvar = Bdd.topvar bdd in
      let suppvar = Bdd.topvar support in
      if bddvar <> suppvar then
        (* [bdd] skips [suppvar]: that is the unconstrained variable to
           toss up, while [bddvar] is still to be drawn further down. *)
        (toss_up_one_var suppvar) :: (draw_in_bdd bdd snt (Bdd.dthen support))
      else
        let ((a, n), (b, m)) = Hashtbl.find snt bdd in
        (* [true] is drawn with probability [a.2**n / (a.2**n + b.2**m)],
           i.e. [a / t] with [t = a + b.2**(m-n)]. *)
        let t = a +. b *. 2. ** (float_of_int (m - n)) in
        let (value, newbdd) =
          if b = 0. then (true, Bdd.dthen bdd)
          else if a = 0. || t = infinity then (false, Bdd.delse bdd)
          else if Random.float 1. < a /. t then (true, Bdd.dthen bdd)
          else (false, Bdd.delse bdd)
        in
        (bddvar, value) :: (draw_in_bdd newbdd snt (Bdd.dthen support))
      
(****************************************************************************)
(****************************************************************************)


(* Exported *)
let (is_satisfiable: formula list -> bool) =
  fun fl ->
    (* Returns [true] iff the conjunction of [fl] has at least one solution.
       As a side effect, [formula_to_bdd] records the bdd of that
       conjunction in the [bdd_tbl] field of [env_state], so that the
       expensive formula-to-bdd translation can be reused in
       [solve_formula]. [formula_to_bdd] already looks the table up and
       tabulates its result, so no separate (duplicate) insertion is done
       here.
    *)
    let f = formula_list_to_conj fl in
    not (Bdd.is_false (formula_to_bdd f))
   
	
    
(* Exported *)
let (solve_formula: int -> formula list -> var_name list -> (subst list * subst list) list) =
  fun p fl vars ->
    (* Draws solutions of the conjunction of [fl] over the variables [vars].
       Each draw gives a boolean to every variable of [vars] and is split
       into the substitutions of the output variables (those returned by
       [Env_state.out_env_unsorted ()]) and those of the other (local)
       variables. The draws are generated by [Util.unfold] with [p];
       presumably [p] is the number of draws -- TODO confirm against
       [Util.unfold]. [fl] should be satisfiable (see [is_satisfiable]),
       otherwise drawing fails on an assertion.
    *)
    let support =
      List.fold_left
        (fun acc vn -> Bdd.dand (Bdd.var (vn_to_index vn)) acc)
        (Bdd.dtrue ())
        vars
    in
    let (draw_and_split : Bdd.t * sol_nb_table -> subst list * subst list) =
      fun (bdd, snt) ->
        (* Draw values in the bdd *)
        let var_index_bool_l = draw_in_bdd bdd snt support in
        assert ((List.length var_index_bool_l) = (List.length vars));
        (* Replace the indexes by the corresponding var names,
           and booleans by atomic_expr. *)
        let (translate : int * bool -> subst) =
          fun (i, b) -> ((Env_state.index_to_vn i), Lurette_stub.Bool(b))
        in
        let subst_l = List.map translate var_index_bool_l in
        (* Remove the types from the list of var names and types of output vars. *)
        let (out_vars, _) = List.split (Env_state.out_env_unsorted ()) in
        (* Split output and local vars. *)
        List.partition (fun (vn, _) -> List.mem vn out_vars) subst_l
    in
    let f = formula_list_to_conj fl in
    (* [formula_to_bdd] reuses the bdd tabulated by [is_satisfiable] when
       there is one, and builds it otherwise. A bare [Hashtbl.find] would
       raise [Not_found] if [is_satisfiable] had not been called first. *)
    let bdd = formula_to_bdd f in
    let snt = Hashtbl.create 1 in
    ignore (build_sol_nb_table bdd snt support : sol_nb * sol_nb);
    Util.unfold draw_and_split (bdd, snt) p