Require Export header_extensible.

From MetaCoq.Template Require Export All.
From MetaCoq.Checker Require Export Checker uGraph.

Local Set Implicit Arguments.
Require Export String List Omega.
Export ListNotations.
Local Unset Strict Implicit.

MetaCoq Commands


Class subtermC A := subterm_rel : A -> A -> Prop.
Class InRelC B := {in_subtype : Type ; in_rel : in_subtype -> B -> Prop }.

Fixpoint destArity Γ (t : term) :=
  match t with
  | tProd na t b => destArity (Γ ,, vass na t) b
  | tLetIn na b b_ty b' => destArity (Γ ,, vdef na b b_ty) b'
  | s => (Γ, s)
  end.

Derive Subterm for nat.

Fixpoint replace_term (c : term) d (t : term) :=
  if @eq_term config.type_in_type init_graph c t then d else
    match t with
  | tRel i => tRel i
  | tEvar ev args => tEvar ev (List.map (replace_term c d) args)
  | tLambda na T M => tLambda na (replace_term c d T) (replace_term c d M)
  | tApp u v => tApp (replace_term c d u) (List.map (replace_term c d) v)
  | tProd na A B => tProd na (replace_term c d A) (replace_term c d B)
  | tCast C kind t => tCast (replace_term c d C) kind (replace_term c d t)
  | tLetIn na b t b' => tLetIn na (replace_term c d b) (replace_term c d t) (replace_term c d b')
  | tCase ind p C brs =>
    let brs' := List.map (on_snd (replace_term c d)) brs in
    tCase ind (replace_term c d p) (replace_term c d C) brs'
  | tProj p C => tProj p (replace_term c d C)
  | tFix mfix idx =>
    let mfix' := List.map (map_def (replace_term c d) (replace_term c d)) mfix in
    tFix mfix' idx
  | tCoFix mfix idx =>
    let mfix' := List.map (map_def (replace_term c d) (replace_term c d)) mfix in
    tCoFix mfix' idx
  | tConst name u => tConst name u
  | x => x
  end.
Require Import Ascii.

Fixpoint name_after_dot' (s : string) (r : string) :=
  match s with
  | EmptyString => r
  | String "#" xs => name_after_dot' xs xs (* see Coq_name in a section *)
  | String ("."%char) xs => name_after_dot' xs xs
  | String _ xs => name_after_dot' xs r
  end.

Definition name_after_dot s := name_after_dot' s s.

Fixpoint fixNames (t : term) :=
  match t with
  | tRel i => tRel i
  | tEvar ev args => tEvar ev (List.map (fixNames) args)
  | tLambda na T M => tLambda na (fixNames T) (fixNames M)
  | tApp u v => tApp (fixNames u) (List.map (fixNames) v)
  | tProd na A B => tProd na (fixNames A) (fixNames B)
  | tCast C kind t => tCast (fixNames C) kind (fixNames t)
  | tLetIn na b t b' => tLetIn na (fixNames b) (fixNames t) (fixNames b')
  | tCase ind p C brs =>
    let brs' := List.map (on_snd (fixNames)) brs in
    tCase ind (fixNames p) (fixNames C) brs'
  | tProj p C => tProj p (fixNames C)
  | tFix mfix idx =>
    let mfix' := List.map (map_def (fixNames) (fixNames)) mfix in
    tFix mfix' idx
  | tCoFix mfix idx =>
    let mfix' := List.map (map_def (fixNames) (fixNames)) mfix in
    tCoFix mfix' idx
  | tConst name u => tConst (name_after_dot name) u
  | tInd (mkInd name i) u => tInd (mkInd (name_after_dot name) i) u
  | x => x
  end.

Fixpoint replace_const (c : kername) d (t : term) :=
  match t with
  | tRel i => tRel i
  | tEvar ev args => tEvar ev (List.map (replace_const c d) args)
  | tLambda na T M => tLambda na (replace_const c d T) (replace_const c d M)
  | tApp u v => tApp (replace_const c d u) (List.map (replace_const c d) v)
  | tProd na A B => tProd na (replace_const c d A) (replace_const c d B)
  | tCast C kind t => tCast (replace_const c d C) kind (replace_const c d t)
  | tLetIn na b t b' => tLetIn na (replace_const c d b) (replace_const c d t) (replace_const c d b')
  | tCase ind p C brs =>
    let brs' := List.map (on_snd (replace_const c d)) brs in
    tCase ind (replace_const c d p) (replace_const c d C) brs'
  | tProj p C => tProj p (replace_const c d C)
  | tFix mfix idx =>
    let mfix' := List.map (map_def (replace_const c d) (replace_const c d)) mfix in
    tFix mfix' idx
  | tCoFix mfix idx =>
    let mfix' := List.map (map_def (replace_const c d) (replace_const c d)) mfix in
    tCoFix mfix' idx
  | tConst name u => if eq_string name c then d else tConst name u
  | x => x
  end.

Fixpoint remove_injs (k : nat) (u : term) {struct u} : term :=
  match u with
  | tRel n => tRel n
  | tEvar ev args => tEvar ev (map (remove_injs k) args)
  | tCast c kind ty => tCast (remove_injs k c) kind (remove_injs k ty)
  | tProd na A B => tProd na (remove_injs k A) (remove_injs (S k) B)
  | tLambda na T M => tLambda na (remove_injs k T) (remove_injs (S k) M)
  | tLetIn na b ty b' => tLetIn na (remove_injs k b) (remove_injs k ty) (remove_injs (S k) b')
  | tApp u0 v => let L := remove_injs k u0 in
                let R := map (remove_injs k) v in
                match L, R with tConst "header_extensible.retract_I" _ , [_;_;_;tRel k'] => (* if k =? k' then *) tRel k' (* else mkApps L R *)
                           | _,_ => mkApps L R end
  | tCase ind p c brs => let brs' := map (on_snd (remove_injs k)) brs in tCase ind (remove_injs k p) (remove_injs k c) brs'
  | tProj p c => tProj p (remove_injs k c)
  | tFix mfix idx => let k' := #|mfix| + k in let mfix' := map (map_def (remove_injs k) (remove_injs k')) mfix in tFix mfix' idx
  | tCoFix mfix idx =>
      let k' := #|mfix| + k in let mfix' := map (map_def (remove_injs k) (remove_injs k')) mfix in tCoFix mfix' idx
  | _ => u
  end.

Fixpoint replace_terms Repl t :=
  match Repl with
  | ((c, s) :: Repl) => replace_terms Repl (replace_term c s t)
  | _ => t
  end.

Fixpoint replace_consts Repl t :=
  match Repl with
  | ((tConst na u, s) :: Repl) => replace_consts Repl (replace_const na s t)
  | _ => t
  end.

Fixpoint replace_ext T T_ext t :=
  match t with
  | tProd na argT retT => if @eq_term config.default_checker_flags init_graph argT T_ext
                         then tProd na T (remove_injs 0 (replace_ext T T_ext retT))
                         else tProd na argT (replace_ext T T_ext retT)
  | s => s
  end.

(* todo here  *)
Definition genIH retT decl_list LT (T T_ext : term) Repl :=
  match LT with
    Some LT =>
    let IH := replace_terms Repl retT in
    let IH := (it_mkProd_or_LetIn decl_list ((tProd nAnon (tApp LT [tRel 0; tRel 1]) (lift 1 0 IH)))) in
    let IH := replace_ext T T_ext IH in
    IH
  | None =>
    let IH := replace_terms Repl retT in
    let IH := (it_mkProd_or_LetIn decl_list (IH)) in
    let IH := replace_ext T T_ext IH in
    IH
  end.

Fixpoint genStatement LT decl_list (T T_ext : term) Repl n B :=
  match n, B with
  | 0, retT => let IH := genIH retT (decl_list) (Some LT) T T_ext Repl in
                            it_mkProd_or_LetIn (decl_list) ((tProd nAnon IH (lift 1 0 retT)))
  | S n, tProd na argT retT => (genStatement LT (vass na argT :: decl_list) T T_ext Repl n retT)
  | _, _ => tVar "no"
  end.

Fixpoint genStatement_no_lt decl_list (T T_ext : term) Repl n B :=
  match n, B with
  | 0, retT => let IH := genIH retT (decl_list) None T T_ext Repl in
                            (IH, it_mkProd_or_LetIn (decl_list) ((tProd nAnon IH (lift 1 0 retT))))
  | S n, tProd na argT retT => (genStatement_no_lt (vass na argT :: decl_list) T T_ext Repl n retT)
  | _, _ => (tVar "no", tVar "no")
  end.

Fixpoint remove_suffix (a1 a2 s : string) : string :=
  match s with
  | EmptyString => a1
  | String "_"%char s =>
    remove_suffix (append a1 a2) (String "_"%char EmptyString) s
  | String c s => remove_suffix a1 (append a2 (String c EmptyString)) s
  end.

Definition mkLemma name (t : term) : TemplateMonad unit :=
  tmBind (tmUnquoteTyped Prop (fixNames t)) (fun t => tmBind (tmLemma name t) (fun _ => tmReturn tt)).

Definition mkVariable name (t : term) : TemplateMonad unit :=
  A <- tmAbout name ;;
  match A with None =>
               tmBind (tmUnquoteTyped Type (fixNames t)) (fun t => tmBind (tmVariable name t) (fun _ => tmReturn tt))
          | Some s => tmPrint "variable already exists, not defining"
  end.

Definition Forall' (name : ident) (n : nat) T T_ext Repl (A_P : term) (P : term) : TemplateMonad term :=
  match destArity nil P with (Gamma, _)=>
  match nth_error (rev Gamma) (Nat.pred n) with Some T' => let T' := fixNames (decl_type T') in
         T' <- tmUnquoteTyped Type T' ;;
         O <- tmInferInstance None (InRelC T') ;;
         match O with
         | Some Cl => let lt := @in_rel _ Cl in
                     LT <- tmQuote lt ;;
                     let St := genStatement LT [] T T_ext Repl n P in ret St
         | None => O <- tmInferInstance None (subtermC T') ;;
                    match O with
                    | Some Cl => let lt := @subterm_rel _ Cl in
                                LT <- tmQuote lt ;;
                                let St := genStatement LT [] T T_ext Repl n P in ret St
                    | None => let (IH, St) := genStatement_no_lt [] T T_ext Repl n P in
                             newname <- tmEval cbv (remove_suffix EmptyString EmptyString name) ;;
                             newname <- tmFreshName newname ;;
                             mkVariable newname (fixNames IH);;
                             ret St
                    end
         end
  | _ => tmFail "not enough arguments"
  end end.

Inductive genList : Type := genNil | genCons (A B : Type) (a : A) (b : B) : genList -> genList.

Fixpoint quote_list L1 : TemplateMonad (list (term * term)) :=
  match L1 with
  | genNil => ret []
  | genCons _ _ a b L => t <- tmQuote a ;; t' <- tmQuote b ;; L <- quote_list L ;; ret ((t,t') :: L)
  end.

Definition Forall (name : ident) (T : Type) (T_ext : Type) (n : nat) Repl {A_P : Type} (P : A_P) :=
  T <- tmQuote T ;;
  T_ext <- tmQuote T_ext ;;
  A_P <- tmQuote A_P ;;
  P <- tmQuote P ;;
  Repl <- quote_list Repl ;;
  St <- Forall' name n T T_ext Repl A_P P ;;
  (* (tmEval cbn St >>= tmPrint) ;; *)
  mkLemma name St.

Definition getName X (x : X) :=
  x <- tmEval cbv x;;
  t <- tmQuote x ;; match t with tLambda (nNamed na) _ _ => ret na | _ => ret "" end.

Definition ForallN (name : nat -> nat) (T : Type) (T_ext : Type) (n : nat) Repl {A_P : Type} (P : A_P) :=
  name <- getName name ;;
  T <- tmQuote T ;;
  T_ext <- tmQuote T_ext ;;
  A_P <- tmQuote A_P ;;
  P <- tmQuote P ;;
  Repl <- quote_list Repl ;;
  St <- Forall' name n T T_ext Repl A_P P ;;
  mkLemma name St.

Hint Extern 0 (subterm_rel _ _) => hnf.

Global Obligation Tactic := idtac.

Definition cns {X Y} '(x,y) := @genCons X Y x y.

Notation "x ~> y" := (x,y) (only parsing, at level 60).

Notation "[~ x ~]" := (cns x genNil) : list_scope.
Notation "[~ x ; y ; .. ; z ~]" := (cns x (cns y .. (cns z genNil) ..)) : list_scope.

Notation "'Modular' 'Lemma' na 'where' T_ext 'extends' T 'at' n 'with' C ':' P" := (ForallN (fun na : nat => na) T T_ext n C P) (at level 1, n at next level, C at next level, P at next level).
Notation "'Modular' 'Lemma' na 'where' T_ext 'extends' T 'with' C ':' P" := (ForallN (fun na : nat => na) T T_ext 0 C P) (at level 1, C at next level, P at next level).

Instance nat_subterm' : subtermC nat := lt.

Ltac inv H2 := inversion H2; subst; clear H2.

Definition tmMkDefinition name (t : term) : TemplateMonad unit :=
  tmBind (tmUnquote (fixNames t)) (fun t => tmBind (tmDefinitionRed name (Some hnf) t) (fun _ => tmReturn tt)).

Fixpoint split_forall_impl decls T :=
  match T with
  | tProd nAnon H1 H2 => (decls, H1, H2)
  | tProd na A B => split_forall_impl (vass na A :: decls) B
  | _ => ([], tVar "no", tVar "no")
  end.

Definition buildImp args H1 H2 :=
  it_mkProd_or_LetIn args (tApp (tConst "Imp" Instance.empty) [ H1; H2 ]).

Inductive GenList : Type := GenNil | GenCons (A : Type) (a : A) : GenList -> GenList.

Ltac apply_one L :=
  match constr:(L) with
  | GenNil => fail 1
  | GenCons ?a ?L => (now eapply a; eauto) || apply_one L
  end.

Class has_features (X : string) := features : list string.
Definition get_features X {H : has_features X} := features.

Definition tmTryInferInstance (red : option reductionStrategy) A :=
  I <- tmInferInstance red A ;; match I with Some x => ret x | _ => tmFail "no instance found" end.

Definition get_name_of (t : term) :=
  match t with
  | tConst c u => c
  | tInd c u => inductive_mind c
  | tApp (tConst c u) _ => c
  | tApp (tInd c u) _ => inductive_mind c
  | _ => ""
  end.

Fixpoint get_name (n : nat) (P : term) : string :=
  match n, P with
  | 0, tProd na A B => get_name_of A
  | S n, tProd na A B => get_name n B
  | _ , _ => "no name found"
  end.

Definition get_lemmas (X : string) (name : string) : TemplateMonad GenList :=
  I <- tmTryInferInstance None (has_features X) ;;
  feats <- tmEval hnf (@get_features X I) ;;
  @monad_fold_left _ TemplateMonad_Monad GenList string (fun L feature => A <- tmUnquote (tConst (append (append name "_") feature) Instance.empty) ;;
                                                                       A' <- tmEval cbn (my_projT2 A);;
                                 ret (GenCons A' L)) feats GenNil.

Definition get_lemmas_and_name (na : nat -> nat) (n : nat) (P : Type) :=
  na <- getName na;;
  P_syn <- tmQuote P ;;
  N <- tmEval cbv (name_after_dot (get_name n P_syn)) ;;
  get_lemmas N na.

Definition compose_fixpoint (na : nat -> nat) (P : Type) (body : P) :=
  na <- getName na;;
  tmDefinition na body.

Ltac int_dest n :=
  match constr:(n) with
  | 0 => let s := fresh "s" in intros s; destruct s
  | S ?n => let s := fresh "s" in intros s; int_dest n
  end.

Ltac int_inv n :=
  match constr:(n) with
  | 0 => let s := fresh "s" in intros s; inversion s
  | S ?n => let s := fresh "s" in intros s; int_inv n
  end.

Ltac int_ind ind n :=
  match constr:(n) with
  | 0 => let s := fresh "s" in intros s; induction s using ind; cbn
  | S ?n => let s := fresh "s" in intros s; int_ind ind n
  end.

Ltac fix_nat f n := let f := fresh "f" in
  match constr:(n) with
  | 0 => fix f 0
  | 1 => fix f 1
  | 2 => fix f 2
  | 3 => fix f 3
  | 4 => fix f 4
  | 5 => fix f 5
  | 6 => fix f 6
  | 7 => fix f 7
  | _ => fail
  end.

Notation "'Compose' 'Fixpoint' nm 'on' n ':' P" :=
  (let na := fun nm => nm in @compose_fixpoint na (P%type) ((ltac:(let k x := (fix_nat f (S n); int_dest n; apply_one x) in run_template_program (get_lemmas_and_name na n P) k)))) (at level 1, n at next level, P at next level).
Notation "'Compose' 'Lemma' nm 'on' n ':' P" :=
  (let na := fun nm => nm in @compose_fixpoint na (P%type) ((ltac:(let k x := (fix_nat f (S n); int_dest n; apply_one x) in run_template_program (get_lemmas_and_name na n P) k)))) (at level 1, n at next level, P at next level).
Notation "'Compose' 'Lemma' nm 'on' n 'by' 'inversion' ':' P" :=
  (let na := fun nm => nm in @compose_fixpoint na (P%type) ((ltac:(let k x := (fix_nat f (S n); int_inv n; apply_one x) in run_template_program (get_lemmas_and_name na n P) k)))) (at level 1, n at next level, P at next level).
Notation "'Compose' 'Lemma' nm 'on' n 'using' ind ':' P" :=
  (let na := fun nm => nm in @compose_fixpoint na (P%type) ((ltac:(let k x := (fix_nat f (S n); int_ind ind n; apply_one x) in run_template_program (get_lemmas_and_name na n P) k)))) (at level 1, n at next level, ind at next level, P at next level).

(* Modular Fixpoints *)

Definition genStatement_Fix (T T_ext : term) retT :=
  let IH := genIH retT [] None T T_ext [] in
    (IH, retT).

Definition mkDefinitionType name (T : Type) (t : term) : TemplateMonad unit :=
  tmBind (tmUnquoteTyped T (fixNames t)) (fun t => tmBind (tmDefinition name t) (fun _ => tmReturn tt)).

Definition ModularFixpointN (name : nat -> nat) (T : Type) (T_ext : Type) (P: Type) (A : Type) (body : A -> P) :=
  name <- getName name ;;
  T <- tmQuote T ;;
  T_ext <- tmQuote T_ext ;;
  P_syn <- tmQuote P ;;
  body_syn <- tmQuote body ;;
  let IH := genIH P_syn [] None T T_ext [] in
  newname <- tmEval cbv (remove_suffix EmptyString EmptyString name) ;;
  newname <- tmFreshName newname ;;
  mkVariable newname (fixNames IH);;
  mkDefinitionType name P (tApp body_syn [tVar newname]).

Definition ModularFixpoint2 (name : nat -> nat) (T : Type) (T_ext : Type) (P: Type) (body : P) :=
  name <- getName name ;;
  T <- tmQuote T ;;
  T_ext <- tmQuote T_ext ;;
  P_syn <- tmQuote P ;;
  body_syn <- tmQuote body ;;
  let IH := genIH P_syn [] None T T_ext [] in
  newname <- tmEval cbv (remove_suffix EmptyString EmptyString name) ;;
  newname <- tmFreshName newname ;;
  mkVariable newname (fixNames IH);;
  mkDefinitionType name P body_syn.

Notation "'Modular' 'Fixpoint' na 'where' T_ext 'extends' T 'with' c ':=' p" := (@ModularFixpointN (fun na : nat => na) T T_ext _ _ (fun c => p)) (at level 1, T_ext at next level, T at next level, c at next level, p at next level).

Notation "'Modular' 'Fixpoint' na 'where' T_ext 'extends' T ':=' p" := (@ModularFixpoint2 (fun na : nat => na) T T_ext _ p) (at level 1, T_ext at next level, T at next level, p at next level).

Hint Extern 0 => reflexivity.

(* Definition tmkAddInjection (name : ident) T := *)
(*   T_syn <- tmQuote T ;; *)
(*   let '((args, H1), H2) := split_forall_impl  T_syn in *)
(*   mkVariable name T;;  *)
(*   let I_syn := buildImp args H1 H2 in *)
(*   I <- tmUnquoteTyped Type I_syn;; *)
(*   i <- tmUnquoteTyped I (tConst name Instance.empty)  ;; *)
(*   (tmBind (tmDefinition (append name "_inst") i) (fun _ => tmReturn tt)). *)