let rec equality e s =
  Trace.msg "rule" "Assert" e Fact.pp_equal;
  let (a, b, _) = Fact.d_equal e in
    if Term.eq a b then s else 
      match a, b with
        | Var _, Var _ -> 
            merge_v e s
        | App(f, _), Var _ -> 
            let i = Th.of_sym f in
              merge_i i e s
        | Var _, App(f, _) -> 
            let i = Th.of_sym f in
              merge_i i e s
        | App(f, _), App(g, _) ->
            let i = Th.of_sym f and j = Th.of_sym g in
              if Th.eq i j && Th.is_fully_interp i then
                merge_i i e s
              else 
                let (s', x') = name i (s, a) in
                let (s'', y') = name j (s', b) in
                  merge_v (Fact.mk_equal x' y' None) s''
        
(** Processing of a variable equality. *)
  
and merge_v e s =
  let propagate e = 
    List.fold_right 
      (fun i s -> 
         if is_empty i s then 
           s 
         else if Th.is_fully_interp i then
           merge_i i e s
         else 
           fuse i e s)
  in
  let (x, y, prf) = Fact.d_equal e in
    match is_equal s x y with
      | Three.Yes -> 
          s
      | Three.No -> 
          raise Exc.Inconsistent
      | Three.X ->
          Trace.msg "rule" "Merge(v)" e Fact.pp_equal;
          let (ch', p') = Partition.merge e s.p in
            s.p <- p';
            let s' = propagate e Th.all s in  (* propagate on rhs. *)  
              close_p ch' s'


(** Processing of an interpreted equality. *)

and merge_i i e s =
  let (x, y, rho) = Fact.d_equal e in
  let a = find i s x
  and b = find i s y in
    if Term.eq a b then s else
      let e' = Fact.mk_equal a b None in
        Trace.msg "rule" "Merge(i)" e' Fact.pp_equal;
        try
          let sl = Th.solve i e' in
            compose i s sl
        with 
            Exc.Incomplete ->
              let (a, b, _) = Fact.d_equal e in
              let (s, x) = name i (s, a) in
              let (s, y) = name i (s, b) in
              let e' = Fact.mk_equal x y None in
                merge_v e' s



and fuse i e s =   
  let (ch', es', eqs') = Solution.fuse i (eqs_of s i) [e] in
    Array.set s.eqs i eqs';
    let s' = Fact.Equalset.fold merge_v es' s in
      close_i i ch' s'

and compose i s r =
  let (ch', es', eqs') = Solution.compose i (eqs_of s i) r in
    Array.set s.eqs i eqs';
    let s' =  Fact.Equalset.fold merge_v es' s in
      close_i i ch' s'

and refine c s = 
  Trace.msg "rule" "Refine" c Fact.pp_cnstrnt;
  let (k, i, _) = Fact.d_cnstrnt c in
  let (ch', p') = Partition.add c s.p in
    s.p <- p';
    close_p ch' s

(** Infer new disequalities from equalities. *)

and infer i e s =
  let (a, b, _) = Fact.d_equal e in
    if Term.eq a b then 
      s
    else if Th.eq i Th.la then
      infer_la e s
    else if Th.eq i Th.bv then
      infer_bv e s
    else if Th.eq i Th.cop then
      infer_cop e s
    else 
      s

(** If x = q and y = p with q, p numerical constraints, then deduce x <> y. *)

and infer_la e s =
  let (x, a, _) = Fact.d_equal e in
    match Arith.d_num a with  
      | Some(q) ->
          Solution.fold
            (fun y (b, _) s ->
               match Arith.d_num b with
                 | Some(p) when not(Q.equal q p) ->
                     let d = Fact.mk_diseq (v s x) (v s y) None in
                        diseq d s
                 | _ -> s)               
            (eqs_of s Th.la) s
      | None ->
         s 

and infer_bv e s =
  let (x, a, _) = Fact.d_equal e in
    match Bitvector.d_const a with  
      | Some(c) ->
          Solution.fold
            (fun y (b, _) s ->
               match Bitvector.d_const b with
                 | Some(d) when not(Bitv.equal c d) ->
                     let d = Fact.mk_diseq (v s x) (v s y) None in
                        diseq d s
                 | _ -> s)               
            (eqs_of s Th.bv) s
      | None ->
         s 

(** If x = a, y = b are in the coproduct solution set, and if a and b are diseq in this theory, then deduce x <> y. *)

and infer_cop e s =
  let (x, a, _) = Fact.d_equal e in
    match a with
     | App(Coproduct(InL | InR), [_]) ->
         Solution.fold
           (fun y (b, _) s ->
              if Coproduct.is_diseq a b then
                let d = Fact.mk_diseq (v s x) (v s y) None in 
                  diseq d s              
              else
                 s)
           (eqs_of s Th.cop) s
     | _ -> s

(** Deduce new constraints from an equality *)

and deduce i e s =
  let (a, b, _) = Fact.d_equal e in
    if Term.eq a b then 
      s
    else if Th.eq i Th.la then
      deduce_la e s
    else if Th.eq i Th.bvarith then
      deduce_bvarith e s
    else if Th.eq i Th.pprod then
      deduce_nonlin e s (* (deduce_nonlin2 e s) *)
    else 
      s

and deduce_bvarith e s = 
  let (x, b, _) = Fact.d_equal e in
    match b with
      | App(Bvarith(Unsigned), [y]) ->   (* [x = unsigned(y)] *)      
          let c = Fact.mk_cnstrnt (v s x) Sign.T None in
            add c s
      | _ -> 
          s

and deduce_nonlin e s =
  Trace.msg "rule" "Deduce(nl)" e Fact.pp_equal;
  let (x, a, _) = Fact.d_equal e in
  let x' = v s x in
  try
    let j = c s a in
      (try
         let i = c s x' in
           if Sign.sub j i then s else 
             add (Fact.mk_cnstrnt x' j None) s
       with
           Not_found -> 
             add (Fact.mk_cnstrnt x' j None) s)
  with
      Not_found -> s

and deduce_nonlin2 e s =
  let partition a = 
    let (bl, dl) = 
      Pp.fold
        (fun y n (bl, dl) ->
           try
             let _ = c s y in
               (bl, (y, n) :: dl)
           with
               Not_found -> ((y, n):: bl, dl))
        a ([], [])
    in
      (Pp.of_list bl, Pp.of_list dl)
  in
  let (x, a, _) = Fact.d_equal e in
  let x' = v s x in
  let (b, d) = partition a in   (* [b] is unconstrained, [d] is constrained. *)
    Trace.msg "foo" "Partition" (b, d) (Pretty.pair Term.pp Term.pp);
    try
      (match c s x', c s d with
         | Sign.PosSign.Pos ->
             add (Fact.mk_cnstrnt b Sign.Pos None) s
         | Sign.NonnegSign.Nonneg -> 
             add (Fact.mk_cnstrnt b Sign.Nonneg None) s
         | _ -> s)
    with
        Not_found -> 
          if not(Term.eq b Pp.mk_one) then s else
            try
              add (Fact.mk_cnstrnt x' (c s d) None) s
            with
                Not_found -> s  (* should not happen *)


(** If [k = R[k''], [k' = R'[k'']] in [A], then
    isolate [k''] in [k' = R'[k'']] to obtain [k'' = R'']
    and plug this solution into [R] to obtain [R'''] as [sigma(R[k'':= R''])].
    deduce new constraints from [k = R'''].  Notice that equality [k' = R'] has
    only to be considered once. *)


and deduce_la e s =
  let s' = deduce_la1 e s in
  let (k, r, _) = Fact.d_equal e in
  let k = v s k in
    if not(is_slack k) then s else
      let visited = ref Term.Set.empty in
        Arith.fold
          (fun _ k'' s ->
             Set.fold
               (fun k' s ->
                  if Term.Set.mem k' !visited then s else 
                    try
                      let r' = apply Th.la s k' in
                      let r'' = Arith.isolate k'' (k', r') in
                      let e' = Fact.mk_equal k (Arith.replace r k'' r'') None in
                        visited := Term.Set.add k' !visited;
                        deduce_la1 e' s
                    with
                        Not_found -> s)
               (use Th.la s k'') 
               s)
          r s'
     
    
(* If C(k) = 0, then we can assert that S(k) = 0.
   Inequality Propagation: If S(k) = R+[k'] - R-[k''],
   then if either C(k) or C[R-[k'']] goes from >= to >,
   then we can assert R+[k']>0. *)
 

and deduce_la1 e s =
  let inconsistent (x, a) =
    try
      (dom s x = Dom.Int&&
      (match Arith.d_num a with
         | Some(q) -> not(Q.is_integer q)
         | None -> false)
    with
        Not_found -> false
  in
  let partition s a =
    let sign_of_monomial = function
      | App(Arith(Num(q)), []) -> Sign.of_q q
      | App(Arith(Multq(q)), [k]) -> Sign.multq q (c s k)
      | k -> cnstrnt s k
    in
    let rec loop posl negl = function
      | [] -> (Arith.mk_addl posl, Arith.mk_addl negl)
      | m :: ml ->
          (match sign_of_monomial m with
             | (Sign.Pos | Sign.Nonneg | Sign.Zero->
                 loop (m :: posl) negl ml
             | (Sign.Neg | Sign.Nonpos-> 
                 loop posl (Arith.mk_neg m :: negl) ml
             | Sign.F -> raise Exc.Inconsistent  (* following should not happen *)
             | Sign.T -> raise Not_found)
    in
      loop [] [] (Arith.monomials a)
  in
  let (x, sk, prf) = Fact.d_equal e in
  let k = v s x in
    if not(is_slack k) then 
      s 
    else if inconsistent (k, sk) then
      raise Exc.Inconsistent
    else  
      begin
        Trace.msg "rule" "Deduce(la)" e Fact.pp_equal;
        try  
          let c_sk = c s sk in  (* Keep invariant that [c s k] is stronger than [c s sk]. *)
          let s = add (Fact.mk_cnstrnt k c_sk None) s in 
          let c_k = c s k in 
            if c_k = Sign.Zero then
              equality (Fact.mk_equal k Arith.mk_zero None) s
            else 
              let (r_plus, r_minus) = partition s (find Th.la s k) in
                if c_k = Sign.Pos || c s r_minus = Sign.Pos then
                  if c s r_plus = Sign.Pos then s else 
                    add (Fact.mk_cnstrnt r_plus Sign.Pos None) s 
                else 
                  s
        with
            Not_found -> s (* should not happen *)
      end 
            

(** Merging variable equality/disequalities/constraints *)

and diseq d s =
  Trace.msg "rule" "Diseq" d Fact.pp_diseq;
  let (x, y, _) = Fact.d_diseq d in
  let (ch', p') = Partition.diseq d s.p in
    s.p <- p';
    close_p ch' s

and add c s =
  Trace.msg "rule" "Add" c Fact.pp_cnstrnt;
  let normalize (a, i) =
    let (a', i') = match i with
      | Sign.Neg -> (Arith.mk_neg a, Sign.Pos)
      | Sign.Nonpos -> (Arith.mk_neg a, Sign.Nonneg)
      | _ -> (a, i)
    in
      (lookup s a', i')   
  in
  let (a, i, prf) = Fact.d_cnstrnt c in
    match Arith.d_num a with
      | Some(q) -> 
          if Sign.disjoint i (Sign.of_q q) then 
            raise Exc.Inconsistent
          else 
            s
      | None ->
          (match i with
             | Sign.F ->
                 raise Exc.Inconsistent
             | Sign.Zero -> 
                 equality (Fact.mk_equal a Arith.mk_zero None) s
             | Sign.T ->
                 s
             | _ ->
                 let (b, i) = normalize (a, i) in
                   match b with
                     | Var _ when is_slack b ->
                         refine (Fact.mk_cnstrnt b i prf) s
                     | App(Arith(Multq(q)), [x]) when is_slack x ->
                         refine (Fact.mk_cnstrnt x (Sign.multq (Q.inv q) i) None) s
                     | _ ->
                         let d = if is_int s a then Some(Dom.Intelse None in
                         let alpha = if i = Sign.Pos then false else true in
                         let k = Term.mk_slack None alpha d in
                           equality (Fact.mk_equal k b None)
                             (refine (Fact.mk_cnstrnt k i None) s))


(** Propagate changes in the variable partitioning. *)
    
and close_i i =
  Set.fold
    (fun x s ->
       try
         let e' = equation i s x in
         let s' =  if Th.eq i Th.la then nonlin_equal e' s else s in
         let s'' =  deduce i e' s' in
         let s''' = infer i e' s'' in
           s'''
       with
           Not_found -> s)

and nonlin_equal e s =
  let rec linearize occs s =
    Set.fold
      (fun x s ->
         try 
           let a = apply Th.pprod s x in
           let b = Sig.map (find Th.la s) a in
             if Term.eq a b then s else 
               let (s', b') = Abstract.term Th.u (s, b) in
               let e' = Fact.mk_equal (v s' x) b' None in
                 merge_v e' s'
         with
             Not_found -> s)
      occs s
  in
  let (x, _, _) = Fact.d_equal e in
    linearize (use Th.pprod s x) s


(** Propagate changes in the variable partitioning. *)
    
and close_p ch s =
  close_v ch.Partition.chv
    (close_c ch.Partition.chc 
       (close_d ch.Partition.chd s))

and close_v chv = 
  Set.fold
    (fun x s ->
       let y = v s x in
         if Term.eq x y then s else 
           let e = Fact.mk_equal x y None in
             Trace.msg "rule" "Close(v)" e Fact.pp_equal;
             let s' =  List.fold_right
                         (fun i s ->
                            let a = find i s x 
                            and b = find i s y in
                              if Term.eq a b || (is_var a && is_var b) then 
                                s 
                              else
                                merge_i i e s)
                         Th.interp s
             in
             let s'' = arrays_equal e s' in
             let s''' = bvarith_equal e s'' in
               s''')
    chv 

and close_c chc = 
  Set.fold
    (fun x s ->
       try
         let i = c s x in
           match i with
             | Sign.F -> 
                 raise Exc.Inconsistent
             | Sign.Zero ->
                 equality (Fact.mk_equal x Arith.mk_zero None) s
             | _ ->
                 s
           with
               Not_found -> s)
    chc

and close_d chd =
  Set.fold
    (fun x s ->
       let yl = d s x in
         List.fold_right
           (fun y s ->
              let d = Fact.mk_diseq x y None in
                arrays_diseq d 
                   (bv_diseq d s))
           yl s)
    chd 

(**
Bitvector propagation
*)


and bv_diseq d' s =
  let add x c =
    try
      let (k, cs) = Term.Map.find x s.diseqs in
        failwith "to do"
    with
        Not_found -> 
          Term.Map.add x (1, Term.Set.singleton c)
  in
  let (x, _, _) = Fact.d_diseq d' in
    try
      let a = apply Th.bv s x in
        if not(Bitvector.is_const a) then s else
          s
    with
        Not_found -> s
  
(**
Array propagation
*)


(** Forward chaining on the array properties
  • a[i:=x][i] = x
  • i <> j implies a[i:=x][j] = a[x]
  • i <> j and i <> k implies a[j:=x][i] = a[k:=y][i]
  • a[j:=x] = b[k:=y], i <> j, i <> k implies a[i] = b[i].
  • a[i:=x] = b[i := y] implies x = y.
*)


and arrays_diseq d s =
  if is_empty Th.arr s then s else 
    arrays_diseq1 d
      (arrays_diseq2 d
         (arrays_diseq3 d s))

(** i <> j implies a[i:=x][j] = a[j]. Thus, look for v = u[j] and u' = a[i := x] with u = u' in s using the use lists. Now, when w = a[j], then infer v = w. *)

and arrays_diseq1 d s =
  let (i, j, prf) = Fact.d_diseq d in
  let diseq  (i, j) s =
    Set.fold
      (fun v s -> 
         try
           let (u, j) = d_select s (tt, tt) v in     (* [v = u[j]] *)
             fold s
               (fun u' s ->                         (* [u = u'] *)
                  try                               (* [u' = a[i:=x]] *)
                    let (a, i, x) = d_update s (tt, tt, tt) u' in
                    let (s, w) = name Th.arr (s, Arr.mk_select Term.is_equal a j) in
                    let e' = Fact.mk_equal v w None in
                      merge_v e' s
                  with
                      Not_found -> s)
               u s
         with
             Not_found -> s)
      (use Th.arr s j) s
  in
    diseq (i, j)
      (diseq (j, i) s)

(** i <> j and i <> k implies a[j:=x][i] = a[k:=y][i]. Thus, for i <> j, select v = u[i] and u' = a[j:=x] with u = u' in s. Now, for all u'' = a[k:=y] with k <> i, add a[j:=x][i] = a[k:=y][i]. *)

and arrays_diseq2 d s =
  let (i, j, _) = Fact.d_diseq d in
  let diseq (i, j) s =
    Set.fold 
      (fun v s ->
         (try
            let (u, _) = d_select s (tt, is_eq s i) v in
              fold s
                (fun u' s ->
                   if not(is_eq s u u') then s else
                     try
                       let (a, _, x) = d_update s (tt, is_eq s j, tt) u' in
                         Set.fold
                           (fun u'' s ->
                              try
                                let (_, k, y) = d_update s (is_eq s a, is_diseq s i, tt) u'' in
                                let (s, u1) = name Th.arr (s, Arr.mk_update Term.is_equal a j x) in
                                let (s, v1) = name Th.arr (s, Arr.mk_select Term.is_equal u1 i) in
                                let (s, u2) = name Th.arr (s, Arr.mk_update Term.is_equal a k y) in
                                let (s, v2) = name Th.arr (s, Arr.mk_select Term.is_equal u2 i) in
                                let e' = Fact.mk_equal v1 v2 None in
                                  merge_v e' s
                              with
                                  Not_found -> s)
                           (use Th.arr s a) s
                     with
                         Not_found -> s)
                u s
          with 
              Not_found -> s))
      (use Th.arr s i) s
  in
    diseq (i, j)
      (diseq (i, j) s)

(** a[j:=x] = b[k:=y], i <> j, i <> k implies a[i] = b[i]. We are propagating disequalities i <> j. If u = a[j := x], then look for all k disequal from i for v = b[k:=y]. Now, assert a[i] = b[i]. *)

and arrays_diseq3 d' s = 
  let (i, j, _) = Fact.d_diseq d' in
  let diseq (i, j) s =
    Set.fold
      (fun u s ->
         try
           let (a, j, _) = d_update s (tt, tt, tt) u in
             List.fold_right
               (fun k s ->
                  Set.fold
                    (fun v s ->
                       try
                         let (b, _, _) = d_update s (tt, is_eq s k, tt) v in
                         let (s, w1) = name Th.arr (s, Arr.mk_select Term.is_equal a i) in
                         let (s, w2) = name Th.arr (s, Arr.mk_select Term.is_equal b i) in
                         let e' = Fact.mk_equal w1 w2 None in
                           merge_v e' s
                       with
                           Not_found -> s)
                    (use Th.arr s k) s)
               (d s i) s
         with
             Not_found -> s)
      (use Th.arr s j) s
  in
    diseq (i, j)
      (diseq (j, i) s)


and arrays_equal e s =
  if is_empty Th.arr s then s else 
    arrays_equal1 e
      (arrays_equal2 e
         (arrays_equal3 e s))
     
(** i = j implies a[i:= x][j] = x. Since i and j are already merged on rhs, it suffices to look for v1 = v2'[i] and v2 = a[i := x] with v2 equal v2' in s. Now, v1 = x. *)

and arrays_equal1 e s =
  let (i, j, _) = Fact.d_equal e in
    Set.fold
      (fun v1 s ->
         try
           let (v2', _) = d_select s (tt, is_eq s i) v1 in
             Set.fold
               (fun v2 s -> 
                  if not(is_eq s v2 v2') then s else 
                    try
                      let (a, _, x) = d_update s (tt, is_eq s i, tt) v2 in
                      let e' = Fact.mk_equal (v s v1) x None in
                        merge_v e' s
                    with
                        Not_found -> s)
               (use Th.arr s (v s j))
               s
          with
              Not_found ->  s)
      (use Th.arr s (v s j))
      s

(** a[i:=x] = b[i := y] implies x = y. Thus, if v = u has been merged, then look for v = a[i:=x] and v' = b[i := y] with v = v' in s, now merge x = y. *)

and arrays_equal2 e s = 
  let (v, _, _) = Fact.d_equal e in   (* [find] of [v] has changed. *)
    try
      let (_, i, x) = d_update s (tt, tt, tt) v in
        fold s
          (fun v' s ->
             if Term.eq v v' then s else
               try
                 let (_, _, y) = d_update s (tt, is_eq s i, tt) v' in
                 let e' = Fact.mk_equal x y None in
                   merge_v e' s
               with
                   Not_found -> s)
          v s
    with
        Not_found -> s

(** a[j:=x] = b[k:=y], i <> j, i <> k implies a[i] = b[i]. Thus, for a merged v = u, look for v = a[j:= x] and v' = b[k := y] with v = v' in s. For all i such that i <> j and i <> k add w1 = w2 for w1 = a[i] and w2 = b[i], possibly extending the solution set. *)

and arrays_equal3 e s = 
  let (v, _, _) = Fact.d_equal e in   (* [find] of [v] has changed. *)
    try
      let (a, j, _) = d_update s (tt, tt, tt) v in
        fold s
          (fun v' s ->
             if Term.eq v v' then s else
               try
                 let (b, k, _) = d_update s (tt, tt, tt) v' in
                   List.fold_right
                     (fun i s ->
                        if not(is_diseq s i k) then s else
                          let (s, w1) = name Th.arr (s, Arr.mk_select Term.is_equal a i) in
                          let (s, w2) = name Th.arr (s, Arr.mk_select Term.is_equal b i) in
                          let e' = Fact.mk_equal w1 w2 None in
                            merge_v e' s)
                     (d s j) s
               with
                   Not_found -> s)
          v s
    with
        Not_found -> s
   
          
and bvarith_equal e s =
  if is_empty Th.bvarith s then s else 
    let (x, bv, prf) = Fact.d_equal e in
      Set.fold
        (fun u s ->
           try
             (match apply bvarith s u with
                | App(Bvarith(Unsigned), [x'])
                    when Term.eq x x' ->
                    let ui = Bvarith.mk_unsigned bv in
                    let (s', a') = Abstract.term la (s, ui) in
                    let e' = Fact.mk_equal (v s' u) a' None in
                      equality e' s'
                | _ ->
                    s )
           with
               Not_found -> s)
        (use bvarith s x)
        s