(* Theory Code_Binary_Nat *)

(*  Title:      HOL/Library/Code_Binary_Nat.thy
    Author:     Florian Haftmann, TU Muenchen
*)

section ‹Implementation of natural numbers as binary numerals›

theory Code_Binary_Nat
imports Code_Abstract_Nat
begin

text ‹
  When generating code for functions on natural numbers, the
  canonical representation using \<^term>‹0::nat› and
  \<^term>‹Suc› is unsuitable for computations involving large
  numbers.  This theory refines the representation of
  natural numbers for code generation to use binary
  numerals, whose size grows logarithmically rather than linearly.
›

subsection ‹Representation›

(* Declare 0 and nat_of_num as the code constructors of nat, replacing the
   0/Suc representation; the code equations below pattern-match on them. *)
code_datatype "0::nat" nat_of_num

(* Code equations for num_of_nat over the binary constructors.  num has no
   zero (its constructors are One/Bit0/Bit1), so 0 is mapped to Num.One. *)
lemma [code]:
  "num_of_nat 0 = Num.One"
  "num_of_nat (nat_of_num k) = k"
  by (simp_all add: nat_of_num_inverse)

(* Code equation: implement the constant 1::nat by the binary numeral
   Numeral1. *)
lemma [code]:
  "(1::nat) = Numeral1"
  by simp

(* The reverse direction, declared as code_abbrev so the code generator can
   present Numeral1 as 1 (see the code generator manual on code_abbrev). *)
lemma [code_abbrev]: "Numeral1 = (1::nat)"
  by simp

(* Suc is no longer a code constructor (see code_datatype above), so it is
   implemented as addition of one. *)
lemma [code]:
  "Suc n = n + 1"
  by simp


subsection ‹Basic arithmetic›

context
begin

(* Replace the default code equations for addition on nat by equations over
   the binary constructors; the nonzero case delegates to addition on num. *)
declare [[code drop: "plus :: nat ⇒ _"]]

lemma plus_nat_code [code]:
  "nat_of_num k + nat_of_num l = nat_of_num (k + l)"
  "m + 0 = (m::nat)"
  "0 + n = (n::nat)"
  by (simp_all add: nat_of_num_numeral)

text ‹Bounded subtraction needs some auxiliary definitions.›

(* Doubling of a natural number; on the binary representation this is a
   shift, cf. its code equations dup_code. *)
qualified definition dup :: "nat ⇒ nat" where
  "dup n = n + n"

(* Code equations for dup: doubling a binary numeral appends a zero bit
   (Num.Bit0), and doubling 0 is 0. *)
lemma dup_code [code]:
  "dup 0 = 0"
  "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)"
  by (simp_all add: dup_def numeral_Bit0)

(* Subtraction on num with an explicit underflow signal: None if k < l,
   otherwise Some of the exact (non-negative) difference as a nat. *)
qualified definition sub :: "num ⇒ num ⇒ nat option" where
  "sub k l = (if k ≥ l then Some (numeral k - numeral l) else None)"

(* Recursive code equations for sub, by case analysis on the lowest bits.
   In the Bit0/Bit1 case a borrow is needed: 2m - (2n+1) = 2(m-n) - 1,
   which underflows when m = n, hence the q = 0 test. *)
lemma sub_code [code]:
  "sub Num.One Num.One = Some 0"
  "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))"
  "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))"
  "sub Num.One (Num.Bit0 n) = None"
  "sub Num.One (Num.Bit1 n) = None"
  "sub (Num.Bit0 m) (Num.Bit0 n) = map_option dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit1 n) = map_option dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit0 n) = map_option (λq. dup q + 1) (sub m n)"
  "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None ⇒ None
     | Some q ⇒ if q = 0 then None else Some (dup q - 1))"
  apply (auto simp add: nat_of_num_numeral
    Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def
    Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def)
  apply (simp_all add: sub_non_positive)
  apply (simp_all add: sub_non_negative [symmetric, where ?'a = int])
  done

(* Truncated subtraction on nat via sub: an underflow (None) yields 0. *)
declare [[code drop: "minus :: nat ⇒ _"]]

lemma minus_nat_code [code]:
  "nat_of_num k - nat_of_num l = (case sub k l of None ⇒ 0 | Some j ⇒ j)"
  "m - 0 = (m::nat)"
  "0 - n = (0::nat)"
  by (simp_all add: nat_of_num_numeral sub_non_positive sub_def)

(* Multiplication on nat over the binary constructors; the nonzero case
   delegates to multiplication on num. *)
declare [[code drop: "times :: nat ⇒ _"]]

lemma times_nat_code [code]:
  "nat_of_num k * nat_of_num l = nat_of_num (k * l)"
  "m * 0 = (0::nat)"
  "0 * n = (0::nat)"
  by (simp_all add: nat_of_num_numeral)

(* Equality on nat: 0 differs from every nat_of_num k, and two binary
   numerals are compared as num values. *)
declare [[code drop: "HOL.equal :: nat ⇒ _"]]

lemma equal_nat_code [code]:
  "HOL.equal 0 (0::nat) ⟷ True"
  "HOL.equal 0 (nat_of_num l) ⟷ False"
  "HOL.equal (nat_of_num k) 0 ⟷ False"
  "HOL.equal (nat_of_num k) (nat_of_num l) ⟷ HOL.equal k l"
  by (simp_all add: nat_of_num_numeral equal)

(* Reflexivity equation declared only for normalization by evaluation
   ([code nbe]); it matches any n, so nbe can decide HOL.equal n n without
   knowing the constructor of n. *)
lemma equal_nat_refl [code nbe]:
  "HOL.equal (n::nat) n ⟷ True"
  by (rule equal_refl)

(* Non-strict order on nat over the binary constructors; comparison of two
   numerals is delegated to the order on num. *)
declare [[code drop: "less_eq :: nat ⇒ _"]]

lemma less_eq_nat_code [code]:
  "0 ≤ (n::nat) ⟷ True"
  "nat_of_num k ≤ 0 ⟷ False"
  "nat_of_num k ≤ nat_of_num l ⟷ k ≤ l"
  by (simp_all add: nat_of_num_numeral)

(* Strict order on nat over the binary constructors; comparison of two
   numerals is delegated to the order on num. *)
declare [[code drop: "less :: nat ⇒ _"]]

lemma less_nat_code [code]:
  "(m::nat) < 0 ⟷ False"
  "0 < nat_of_num l ⟷ True"
  "nat_of_num k < nat_of_num l ⟷ k < l"
  by (simp_all add: nat_of_num_numeral)

(* Division with remainder on nat: the case of two nonzero numerals is
   delegated to divmod on num; dividing by 0 yields (0, m), and dividing 0
   yields (0, 0). *)
declare [[code drop: Divides.divmod_nat]]
  
lemma divmod_nat_code [code]:
  "Divides.divmod_nat (nat_of_num k) (nat_of_num l) = divmod k l"
  "Divides.divmod_nat m 0 = (0, m)"
  "Divides.divmod_nat 0 n = (0, 0)"
  by (simp_all add: prod_eq_iff nat_of_num_numeral)

end


subsection ‹Conversions›

(* Replace the default code equations for of_nat: 0 maps to 0 and a binary
   numeral nat_of_num k maps to numeral k in the target type. *)
declare [[code drop: of_nat]]

lemma of_nat_code [code]:
  "of_nat 0 = 0"
  "of_nat (nat_of_num k) = numeral k"
  by (simp_all add: nat_of_num_numeral)


(* Place generated code from this theory into a target module named Arith
   for SML, OCaml and Haskell. *)
code_identifier
  code_module Code_Binary_Nat ⇀
    (SML) Arith and (OCaml) Arith and (Haskell) Arith

end