diff --git a/ml/lib/algebra.ml b/ml/lib/algebra.ml index f4d18a8a1e9c3a50d2944456f2ef1cb835856bd5..a90c29bb66427b7bda340e4707fe0e291afbb11c 100644 --- a/ml/lib/algebra.ml +++ b/ml/lib/algebra.ml @@ -1,64 +1,50 @@ open Utils -module type Mappable = sig type 'a t val map: ('a -> 'b) -> 'a t -> 'b t end +module type Mappable = sig + type 'a t -module A(F: Mappable) = struct + val map : ('a -> 'b) -> 'a t -> 'b t +end +module A (F : Mappable) = struct type 'a fix = Fix of 'a fix F.t let out = function Fix fix -> fix - let in_ t = Fix t (* Generalization of fold_right *) - let rec cata fn = - out - %> F.map (fun x -> cata fn x) - %> fn - + let rec cata fn = out %> F.map (fun x -> cata fn x) %> fn let map fn x = cata (in_ %> fn) x (* Generalization of unfold_right *) - let rec ana fn = - fn - %> F.map (fun x -> ana fn x) - %> in_ - - let hylo f g = - (cata g) % (ana f) + let rec ana fn = fn %> F.map (fun x -> ana fn x) %> in_ + let hylo f g = cata g % ana f let rec para ralg = let fanout t = (t, para ralg t) in - out - %> F.map (fun x -> fanout x) - %> ralg - - let rec para' ralg t = - out t - |> F.map (fun x -> para' ralg x) - |> ralg t + out %> F.map (fun x -> fanout x) %> ralg + let rec para' ralg t = out t |> F.map (fun x -> para' ralg x) |> ralg t let rec apo rcoalg = - let fanin t = let open Either in - match t with Left x -> x | Right y -> apo rcoalg y in - rcoalg - %> F.map (fun x -> fanin x) - %> in_ - - type 'a attr = {attribute: 'a; hole: 'a attr F.t} + let fanin t = + let open Either in + match t with Left x -> x | Right y -> apo rcoalg y + in + rcoalg %> F.map (fun x -> fanin x) %> in_ - let mk_attr attribute hole = {attribute; hole} - let attribute {attribute;_} = attribute - let hole {hole;_} = hole + type 'a attr = { attribute : 'a; hole : 'a attr F.t } - let (&&&) f g x = f x, g x + let mk_attr attribute hole = { attribute; hole } + let attribute { attribute; _ } = attribute + let hole { hole; _ } = hole + let ( &&& ) f g x = (f x, g x) let histo h = let rec worker t = - out t |> F.map (fun x -> worker x) |> (h &&& Fun.id) |> uncurry mk_attr in - worker - %> attribute + out t |> F.map (fun x -> worker x) |> (h &&& Fun.id) |> uncurry mk_attr + in + worker %> attribute end (** A few examples of application @@ -83,24 +69,24 @@ end * let cata' f = para' (Fun.const f) *) - module type Algebra_t = sig type 'a base type 'a fix = Fix of 'a fix base - val in_: 'a fix base -> 'a fix - val out: 'a fix -> 'a fix base - val map: ('a fix -> 'a fix) -> 'b fix -> 'a fix - val cata: ('a base -> 'a) -> 'b fix -> 'a - val ana: ('a -> 'a base) -> 'a -> 'b fix - val hylo: ('a -> 'a base) -> ('b base -> 'b) ->'a -> 'b - val para: (('a fix * 'b) base -> 'b) -> 'a fix -> 'b - val para': ('a fix -> 'b base -> 'b) -> 'a fix -> 'b - val apo: ('a -> ('b fix, 'a) either base) -> 'a -> 'b fix + val in_ : 'a fix base -> 'a fix + val out : 'a fix -> 'a fix base + val map : ('a fix -> 'a fix) -> 'b fix -> 'a fix + val cata : ('a base -> 'a) -> 'b fix -> 'a + val ana : ('a -> 'a base) -> 'a -> 'b fix + val hylo : ('a -> 'a base) -> ('b base -> 'b) -> 'a -> 'b + val para : (('a fix * 'b) base -> 'b) -> 'a fix -> 'b + val para' : ('a fix -> 'b base -> 'b) -> 'a fix -> 'b + val apo : ('a -> ('b fix, 'a) either base) -> 'a -> 'b fix type 'a attr - val attribute: 'a attr -> 'a - val hole: 'a attr -> 'a attr base - val mk_attr: 'a -> 'a attr base -> 'a attr - val histo: ('a attr base -> 'a) -> 'b fix -> 'a + + val attribute : 'a attr -> 'a + val hole : 'a attr -> 'a attr base + val mk_attr : 'a -> 'a attr base -> 'a attr + val histo : ('a attr base -> 'a) -> 'b fix -> 'a end diff --git a/ml/lib/arch.ml b/ml/lib/arch.ml index 3ce405445a655667511c05d4b148e3185c53150f..13564de397d35a80f81c9decc9e8ad9831449930 100644 --- a/ml/lib/arch.ml +++ b/ml/lib/arch.ml @@ -1,27 +1,46 @@ module type Vec_arch_t = sig - val base_type_name: string + val base_type_name : string val vec_size : int - val vec_type_name: string - val transpose_func: string - val gen_load: (string -> string, unit, string, string, string, string) format6 - val gen_store: (string -> string -> string, unit, string, string, string, string) format6 - val gen_add: (string -> string -> string, unit, string, string, string, string) format6 - val gen_mul: (string -> string -> string, unit, string, string, string, string) format6 - val gen_sub: (string -> string -> string, unit, string, string, string, string) format6 - val gen_fma: (string -> string -> string -> string, unit, string, string, string, string) format6 - val gen_broadcast: (string -> string, unit, string, string, string, string) format6 + val vec_type_name : string + val transpose_func : string + + val gen_load : + (string -> string, unit, string, string, string, string) format6 + + val gen_store : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_add : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_mul : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_sub : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_fma : + ( string -> string -> string -> string, + unit, + string, + string, + string, + string ) + format6 + + val gen_broadcast : + (string -> string, unit, string, string, string, string) format6 end -module SSE: Vec_arch_t = struct +module SSE : Vec_arch_t = struct let base_type_name = "float" let vec_size = 4 - let vec_type_name = "__m128" - let transpose_func = "transpose_base_stride_sse" let gen_add = Stdlib.format_of_string "_mm_add_ps(%s, %s)" let gen_sub = Stdlib.format_of_string "_mm_sub_ps(%s, %s)" let gen_mul = Stdlib.format_of_string "_mm_mul_ps(%s, %s)" + (* gen_fma a b c = a * b + c *) let gen_fma = Stdlib.format_of_string "_mm_add_ps(_mm_mul_ps(%s, %s), %s)" let gen_broadcast = Stdlib.format_of_string "_mm_set1_ps(%s)" @@ -29,16 +48,15 @@ module SSE: Vec_arch_t = struct let gen_store = Stdlib.format_of_string "_mm_store_ps(&%s, %s)" end -module AVX2: Vec_arch_t = struct +module AVX2 : Vec_arch_t = struct let base_type_name = "float" let vec_size = 8 - let vec_type_name = "__m256" - let transpose_func = "transpose_base_stride_avx2" let gen_add = Stdlib.format_of_string "_mm256_add_ps(%s, %s)" let gen_sub = Stdlib.format_of_string "_mm256_sub_ps(%s, %s)" let gen_mul = Stdlib.format_of_string "_mm256_mul_ps(%s, %s)" + (* gen_fma a b c = a * b + c *) let gen_fma = Stdlib.format_of_string "_mm256_fmadd_ps(%s, %s, %s)" let gen_broadcast = Stdlib.format_of_string "_mm256_set1_ps(%s)" @@ -46,16 +64,15 @@ module AVX2: Vec_arch_t = struct let gen_store = Stdlib.format_of_string "_mm256_store_ps(&%s, %s)" end -module AVX2_unaligned: Vec_arch_t = struct +module AVX2_unaligned : Vec_arch_t = struct let base_type_name = "float" let vec_size = 8 - let vec_type_name = "__m256" - let transpose_func = "transpose_base_stride_avx2" let gen_add = Stdlib.format_of_string "_mm256_add_ps(%s, %s)" let gen_sub = Stdlib.format_of_string "_mm256_sub_ps(%s, %s)" let gen_mul = Stdlib.format_of_string "_mm256_mul_ps(%s, %s)" + (* gen_fma a b c = a * b + c *) let gen_fma = Stdlib.format_of_string "_mm256_fmadd_ps(%s, %s, %s)" let gen_broadcast = Stdlib.format_of_string "_mm256_set1_ps(%s)" @@ -63,30 +80,27 @@ module AVX2_unaligned: Vec_arch_t = struct let gen_store = Stdlib.format_of_string "_mm256_storeu_ps(&%s, %s)" end -module AVX2_INT32: Vec_arch_t = struct +module AVX2_INT32 : Vec_arch_t = struct let base_type_name = "int32_t" let vec_size = 8 - let vec_type_name = "__m256i" - let transpose_func = "unimplemented" let gen_add = Stdlib.format_of_string "_mm256_add_epi32(%s, %s)" let gen_sub = Stdlib.format_of_string "_mm256_sub_epi32(%s, %s)" let gen_mul = Stdlib.format_of_string "_mm256_mul_epi32(%s, %s)" - let gen_fma = Stdlib.format_of_string - "_mm256_add_epi32(_mm256_mul_epi32(%s, %s), %s)" + + let gen_fma = + Stdlib.format_of_string "_mm256_add_epi32(_mm256_mul_epi32(%s, %s), %s)" + let gen_broadcast = Stdlib.format_of_string "_mm256_set1_epi32(%s)" let gen_load = Stdlib.format_of_string "_mm256_load_epi32((void*)&%s)" let gen_store = Stdlib.format_of_string "_mm256_store_epi32((void*)&%s, %s)" end - -module AVX512: Vec_arch_t = struct +module AVX512 : Vec_arch_t = struct let base_type_name = "float" let vec_size = 16 - let vec_type_name = "__m512" - let transpose_func = "transpose_base_stride_avx512" let gen_add = Stdlib.format_of_string "_mm512_add_ps(%s, %s)" let gen_sub = Stdlib.format_of_string "_mm512_sub_ps(%s, %s)" @@ -97,12 +111,10 @@ module AVX512: Vec_arch_t = struct let gen_store = Stdlib.format_of_string "_mm512_store_ps(&%s, %s)" end -module AVX512_unaligned: Vec_arch_t = struct +module AVX512_unaligned : Vec_arch_t = struct let base_type_name = "float" let vec_size = 16 - let vec_type_name = "__m512" - let transpose_func = "transpose_base_stride_avx512" let gen_add = Stdlib.format_of_string "_mm512_add_ps(%s, %s)" let gen_sub = Stdlib.format_of_string "_mm512_sub_ps(%s, %s)" diff --git a/ml/lib/arch.mli b/ml/lib/arch.mli index ce5d30149241574b93907ed9e07306a059f3dc70..c695829b93a0135966a1cc7b3e44f8bcd2b762db 100644 --- a/ml/lib/arch.mli +++ b/ml/lib/arch.mli @@ -1,20 +1,40 @@ module type Vec_arch_t = sig - val base_type_name: string + val base_type_name : string val vec_size : int - val vec_type_name: string - val transpose_func: string - val gen_load: (string -> string, unit, string, string, string, string) format6 - val gen_store: (string -> string -> string, unit, string, string, string, string) format6 - val gen_add: (string -> string -> string, unit, string, string, string, string) format6 - val gen_mul: (string -> string -> string, unit, string, string, string, string) format6 - val gen_sub: (string -> string -> string, unit, string, string, string, string) format6 - val gen_fma: (string -> string -> string -> string, unit, string, string, string, string) format6 - val gen_broadcast: (string -> string, unit, string, string, string, string) format6 + val vec_type_name : string + val transpose_func : string + + val gen_load : + (string -> string, unit, string, string, string, string) format6 + + val gen_store : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_add : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_mul : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_sub : + (string -> string -> string, unit, string, string, string, string) format6 + + val gen_fma : + ( string -> string -> string -> string, + unit, + string, + string, + string, + string ) + format6 + + val gen_broadcast : + (string -> string, unit, string, string, string, string) format6 end -module SSE: Vec_arch_t -module AVX2: Vec_arch_t -module AVX2_unaligned: Vec_arch_t -module AVX2_INT32: Vec_arch_t -module AVX512: Vec_arch_t -module AVX512_unaligned: Vec_arch_t +module SSE : Vec_arch_t +module AVX2 : Vec_arch_t +module AVX2_unaligned : Vec_arch_t +module AVX2_INT32 : Vec_arch_t +module AVX512 : Vec_arch_t +module AVX512_unaligned : Vec_arch_t diff --git a/ml/lib/dim_info.ml b/ml/lib/dim_info.ml index e0c53cadb5aa8be6a4360b4a6f8670a8a0ecba43..11bf17705ad305cf263bbe774a426fe4c8fd39d9 100644 --- a/ml/lib/dim_info.ml +++ b/ml/lib/dim_info.ml @@ -6,43 +6,68 @@ open Loop_nest type div_constraint = Flexible of int | Div [@@deriving show] -module DI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct +module DI (Inst : Inst_t) (LN : Loopnest_t with module Inst := Inst) = struct + module DimMap = Make_map (struct + type t = Dim.t - module DimMap = Make_map(struct type t = Dim.t let compare = Dim.compare end) + let compare = Dim.compare + end) type loop = U of int | V | S of int | A - type bumped = (Index.t * Index.t) option - type t = {dim: Dim.t; incr: Expr.t; last_tile_size: Expr.t option; par: bool; - loop_scheme: loop list; is_vectorized: bool;is_closed:bool; - div_constraint: div_constraint; - index_clone: Index.gen; bump_info: bumped; next_index: Index.t; - current_index: Index.t; indexes_list: (Index.t * Index.gen * (Index.t * Expr.t) list) list } + type t = { + dim : Dim.t; + incr : Expr.t; + last_tile_size : Expr.t option; + par : bool; + loop_scheme : loop list; + is_vectorized : bool; + is_closed : bool; + div_constraint : div_constraint; + index_clone : Index.gen; + bump_info : bumped; + next_index : Index.t; + current_index : Index.t; + indexes_list : (Index.t * Index.gen * (Index.t * Expr.t) list) list; + } [@@deriving fields] (* A frozen snapshot of a dim state at some instant *) module Frozen = struct - module T_arg = struct - type t = {dim: Dim.t; incr: Expr.t; is_vectorized: bool; - current_pack: Index.t; - current_tile: Index.t; next_index: Index.t; - current_index: Index.t; indexes_list: (Index.t * (Index.t * Expr.t) list) list } + type t = { + dim : Dim.t; + incr : Expr.t; + is_vectorized : bool; + current_pack : Index.t; + current_tile : Index.t; + next_index : Index.t; + current_index : Index.t; + indexes_list : (Index.t * (Index.t * Expr.t) list) list; + } [@@deriving eq, show] end module C_arg = struct - type t = {dim: Dim.t; is_vectorized: bool; incr: Expr.t; last_index: Index.t; - last_tile: Index.t option; indexes_list: (Index.t * (Index.t * Expr.t) list) list } + type t = { + dim : Dim.t; + is_vectorized : bool; + incr : Expr.t; + last_index : Index.t; + last_tile : Index.t option; + indexes_list : (Index.t * (Index.t * Expr.t) list) list; + } [@@deriving eq, show] end - type t = Nothing of Dim.t (* This dimension did not appear in the scheme yet + type t = + | Nothing of Dim.t + (* This dimension did not appear in the scheme yet *) - | VecUnr of Dim.t * Expr.t * bool - | Tile of T_arg.t - | Close of C_arg.t (* This dimension is currently closed *) + | VecUnr of Dim.t * Expr.t * bool + | Tile of T_arg.t + | Close of C_arg.t (* This dimension is currently closed *) let dim_incr = function | Nothing _ -> Expr.const 0 @@ -53,85 +78,77 @@ module DI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct let filter_indexes big_tile_index small_tile_index = let rec aux = function | [], l -> Some l - | (ind1, _)::t1, (ind2, _)::t2 -> - if Index.equal ind1 ind2 then aux (t1, t2) - else None - | _ -> None in + | (ind1, _) :: t1, (ind2, _) :: t2 -> + if Index.equal ind1 ind2 then aux (t1, t2) else None + | _ -> None + in (* indexes are sorted from bigger to smaller, reverse that so we can filter *) - aux ((List.rev big_tile_index), (List.rev small_tile_index)) + aux (List.rev big_tile_index, List.rev small_tile_index) |> Option.map List.rev (* This type allows us to have exhaustiveness check when we match on a pair that we assume to be well-formed * VU means vector/unroll *) - type ordered_pair = Close_close of C_arg.t - | Close_tile of Index.t * C_arg.t * T_arg.t - | Close_vu of C_arg.t * Expr.t * bool - | Close_nothing of C_arg.t - | Tile_tile of Index.t option * T_arg.t * T_arg.t - | Same_tile of T_arg.t - | Tile_vu of T_arg.t * Expr.t * bool - | Tile_nothing of T_arg.t - | VU_VU of Dim.t * Expr.t * Expr.t * bool - | Same_VU of Dim.t * Expr.t * bool - | VU_nothing of Dim.t * Expr.t * bool - | Nothing_nothing of Dim.t [@@deriving show] + type ordered_pair = + | Close_close of C_arg.t + | Close_tile of Index.t * C_arg.t * T_arg.t + | Close_vu of C_arg.t * Expr.t * bool + | Close_nothing of C_arg.t + | Tile_tile of Index.t option * T_arg.t * T_arg.t + | Same_tile of T_arg.t + | Tile_vu of T_arg.t * Expr.t * bool + | Tile_nothing of T_arg.t + | VU_VU of Dim.t * Expr.t * Expr.t * bool + | Same_VU of Dim.t * Expr.t * bool + | VU_nothing of Dim.t * Expr.t * bool + | Nothing_nothing of Dim.t + [@@deriving show] (* returns Some ordered_pair if pairs are indeed well ordered, None elsewise *) - let are_well_ordered t1 t2 = match t1, t2 with - | Nothing d1, Nothing d2 -> - Option.from_cond_val (Dim.equal d1 d2) - (Nothing_nothing d1) - | VecUnr (d1, incr, is_vectorized), - Nothing d2 -> - Option.from_cond_val (Dim.equal d1 d2) - (VU_nothing (d1, incr, is_vectorized)) - | VecUnr (d1, incr1, is_v1), - VecUnr (d2, incr2, is_v2) -> - if (Expr.equal incr1 incr2 && is_v1 = is_v2) - then Some (Same_VU (d1, incr1, is_v1)) - else Option.from_cond_val (Dim.equal d1 d2 && (is_v1 = is_v2)) - (VU_VU (d1, incr1, incr2, is_v1)) - | Tile (T_arg.({dim =d2; _}) as t_arg), - Nothing d1 -> - Option.from_cond_val (Dim.equal d1 d2) - (Tile_nothing t_arg) - | Tile (T_arg.({dim =d2; _}) as t_arg), - VecUnr (d1, incr, is_vec) -> - Option.from_cond_val (Dim.equal d1 d2) - (Tile_vu (t_arg, incr, is_vec)) - | T_arg.(Tile ({dim =d1; indexes_list = big_indxs; _} as t_big), - Tile ({dim =d2;indexes_list = small_indxs;_} as t_small)) -> - Option.bind (filter_indexes big_indxs small_indxs) - (function - ((ind, _)::_) -> - Option.from_cond_val (Dim.equal d1 d2) - (Tile_tile (Some ind, t_big, t_small)) + let are_well_ordered t1 t2 = + match (t1, t2) with + | Nothing d1, Nothing d2 -> + Option.from_cond_val (Dim.equal d1 d2) (Nothing_nothing d1) + | VecUnr (d1, incr, is_vectorized), Nothing d2 -> + Option.from_cond_val (Dim.equal d1 d2) + (VU_nothing (d1, incr, is_vectorized)) + | VecUnr (d1, incr1, is_v1), VecUnr (d2, incr2, is_v2) -> + if Expr.equal incr1 incr2 && is_v1 = is_v2 then + Some (Same_VU (d1, incr1, is_v1)) + else + Option.from_cond_val + (Dim.equal d1 d2 && is_v1 = is_v2) + (VU_VU (d1, incr1, incr2, is_v1)) + | Tile (T_arg.{ dim = d2; _ } as t_arg), Nothing d1 -> + Option.from_cond_val (Dim.equal d1 d2) (Tile_nothing t_arg) + | Tile (T_arg.{ dim = d2; _ } as t_arg), VecUnr (d1, incr, is_vec) -> + Option.from_cond_val (Dim.equal d1 d2) (Tile_vu (t_arg, incr, is_vec)) + | T_arg.( + ( Tile ({ dim = d1; indexes_list = big_indxs; _ } as t_big), + Tile ({ dim = d2; indexes_list = small_indxs; _ } as t_small) )) -> + Option.bind (filter_indexes big_indxs small_indxs) (function + | (ind, _) :: _ -> + Option.from_cond_val (Dim.equal d1 d2) + (Tile_tile (Some ind, t_big, t_small)) | [] -> - if (T_arg.equal t_big t_small) - then (Some (Same_tile t_big)) - else (Option.from_cond_val (Dim.equal d1 d2) - (Tile_tile (None, t_big, t_small))) - ) - | Close (C_arg.({dim =d2; _}) as c_arg), - Nothing d1 -> - Option.from_cond_val (Dim.equal d1 d2) - (Close_nothing c_arg) - | Close (C_arg.({dim =d2; _}) as c_arg), - VecUnr (d1, incr, is_vec) -> - Option.from_cond_val (Dim.equal d1 d2) - (Close_vu (c_arg, incr, is_vec)) - | C_arg.(Close ({dim=d1; indexes_list= small_indxs;_} as c_big), - Tile ({dim=d2; indexes_list= big_indxs;_} as t_small)) -> - Option.bind (filter_indexes big_indxs small_indxs) - (function - ((ind, _)::_) -> - Option.from_cond_val (Dim.equal d1 d2) - (Close_tile (ind, c_big, t_small)) - | [] -> None - ) - | Close c1, Close c2 -> - Option.from_cond_val (C_arg.equal c1 c2) - (Close_close c1) + if T_arg.equal t_big t_small then Some (Same_tile t_big) + else + Option.from_cond_val (Dim.equal d1 d2) + (Tile_tile (None, t_big, t_small))) + | Close (C_arg.{ dim = d2; _ } as c_arg), Nothing d1 -> + Option.from_cond_val (Dim.equal d1 d2) (Close_nothing c_arg) + | Close (C_arg.{ dim = d2; _ } as c_arg), VecUnr (d1, incr, is_vec) -> + Option.from_cond_val (Dim.equal d1 d2) + (Close_vu (c_arg, incr, is_vec)) + | C_arg.( + ( Close ({ dim = d1; indexes_list = small_indxs; _ } as c_big), + Tile ({ dim = d2; indexes_list = big_indxs; _ } as t_small) )) -> + Option.bind (filter_indexes big_indxs small_indxs) (function + | (ind, _) :: _ -> + Option.from_cond_val (Dim.equal d1 d2) + (Close_tile (ind, c_big, t_small)) + | [] -> None) + | Close c1, Close c2 -> + Option.from_cond_val (C_arg.equal c1 c2) (Close_close c1) | _ -> None exception Ill_ordered_tiles @@ -139,415 +156,541 @@ module DI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct module Zip = LN.Zipper let gen_tile_loop to_vectorize = - let get_incr incr = match (incr, to_vectorize) with - | (Expr.Const i, true) + let get_incr incr = + match (incr, to_vectorize) with + | Expr.Const i, true when i mod Inst.A.vec_size = 0 && i >= Inst.A.vec_size -> - Expr.const Inst.A.vec_size - | (Expr.SizeVar _, true) -> Expr.const Inst.A.vec_size - | _ -> Expr.one in + Expr.const Inst.A.vec_size + | Expr.SizeVar _, true -> Expr.const Inst.A.vec_size + | _ -> Expr.one + in function - | Nothing_nothing dim - | VU_nothing (dim, _, _) -> - (* We pass here for weird reason, even if there is a "T w" over us *) - let index = Index.from_dim dim in - [], Expr.index index, Expr.zero, Fun.id + | Nothing_nothing dim | VU_nothing (dim, _, _) -> + (* We pass here for weird reason, even if there is a "T w" over us *) + let index = Index.from_dim dim in + ([], Expr.index index, Expr.zero, Fun.id) | Same_VU (dim, incr, _) -> - let increment = get_incr incr in - let index = Index.from_dim dim in - let global_index = Index.map_prefix index (fun s -> s ^ "g") in - let local_index = Index.map_prefix index (fun s -> s ^ "l") in - (*let tile_index = Index.map_prefix index (fun s -> s ^ "t") in*) - let start = Expr.index index in - let halt = Expr.(I.(start + incr )) in - let aux = [local_index, Expr.zero] in - [global_index; local_index], Expr.index global_index, Expr.index local_index, - Zip.new_seq ~aux global_index start halt increment + let increment = get_incr incr in + let index = Index.from_dim dim in + let global_index = Index.map_prefix index (fun s -> s ^ "g") in + let local_index = Index.map_prefix index (fun s -> s ^ "l") in + (*let tile_index = Index.map_prefix index (fun s -> s ^ "t") in*) + let start = Expr.index index in + let halt = Expr.(I.(start + incr)) in + let aux = [ (local_index, Expr.zero) ] in + ( [ global_index; local_index ], + Expr.index global_index, + Expr.index local_index, + Zip.new_seq ~aux global_index start halt increment ) | VU_VU _ -> - (* TODO handle this case *) - assert false - | Tile_nothing T_arg.({current_pack;_}) -> - [], Expr.index current_pack, Expr.zero, Fun.id - | Tile_vu (T_arg.({dim; current_pack; _}), incr, _) -> - let increment = get_incr incr in - let index = Index.from_dim dim in - let global_index = Index.map_prefix index (fun s -> s ^ "g") in - let local_index = Index.map_prefix index (fun s -> s ^ "l") in - let aux = [local_index, Expr.zero] in - let start = Expr.index current_pack in - let halt = Expr.(I.(start + incr )) in - [global_index; local_index], Expr.index global_index, Expr.index local_index, - Zip.new_seq ~aux global_index start halt increment - | Close_nothing C_arg.({dim;_}) -> - [], Expr.index @@ Index.from_dim dim, Expr.zero, - Fun.id - | Close_vu (C_arg.({dim; _}), incr, _) -> - let increment = get_incr incr in - let index = Index.from_dim dim in - let all_index = Index.map_prefix index (fun s -> s ^ "all") in - let local_index = Index.map_prefix index (fun s -> s ^ "l") in - let aux = [local_index, Expr.zero] in - let start = Expr.index index in - let halt = Expr.(I.(start + incr )) in - [all_index; local_index], Expr.index all_index, Expr.index local_index, - Zip.new_seq ~aux all_index start halt increment - | Tile_tile (Some tile_index, _, T_arg.({ incr; current_index; _})) -> - let increment = get_incr incr in - let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in - let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.index tile_index in - let halt = Expr.(I.(start + incr )) in - [glob_index; local_index], Expr.index glob_index, Expr.index local_index, - Zip.new_seq ~aux glob_index start halt increment - | Tile_tile (None, _, T_arg.({ incr; next_index; current_index; _})) -> - let increment = get_incr incr in - let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in - let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.index next_index in - let halt = Expr.(I.(start + incr )) in - [glob_index; local_index], Expr.index glob_index, Expr.index local_index, - Zip.new_seq ~aux glob_index start halt increment - | Same_tile T_arg.{ current_index; incr; _} -> - let increment = get_incr incr in - let _tile_index = Index.map_prefix current_index (fun s -> s ^ "tile") in - let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in - let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in - let start = Expr.index current_index in - let aux = [local_index, Expr.zero] in - let halt = Expr.(I.(start + incr )) in - [glob_index;local_index], Expr.index glob_index, Expr.index local_index, - Zip.new_seq ~aux glob_index start halt increment - | Close_tile (_, _, T_arg.({ incr; current_index; _})) -> - let increment = get_incr incr in - let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in - let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.index current_index in - let halt = Expr.(I.(start + incr )) in - [local_index; glob_index], Expr.index glob_index, Expr.index local_index, - Zip.new_seq ~aux glob_index start halt increment - | Close_close C_arg.{last_index; incr; dim;_} -> - let increment = get_incr incr in - let index = Index.map_prefix last_index (fun s -> s ^ "all") in - let start = Expr.zero in - let halt = Expr.(size @@ Dim.size_id dim) in - [index], Expr.index index, Expr.index index, Zip.new_seq index start halt increment + (* TODO handle this case *) + assert false + | Tile_nothing T_arg.{ current_pack; _ } -> + ([], Expr.index current_pack, Expr.zero, Fun.id) + | Tile_vu (T_arg.{ dim; current_pack; _ }, incr, _) -> + let increment = get_incr incr in + let index = Index.from_dim dim in + let global_index = Index.map_prefix index (fun s -> s ^ "g") in + let local_index = Index.map_prefix index (fun s -> s ^ "l") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index current_pack in + let halt = Expr.(I.(start + incr)) in + ( [ global_index; local_index ], + Expr.index global_index, + Expr.index local_index, + Zip.new_seq ~aux global_index start halt increment ) + | Close_nothing C_arg.{ dim; _ } -> + ([], Expr.index @@ Index.from_dim dim, Expr.zero, Fun.id) + | Close_vu (C_arg.{ dim; _ }, incr, _) -> + let increment = get_incr incr in + let index = Index.from_dim dim in + let all_index = Index.map_prefix index (fun s -> s ^ "all") in + let local_index = Index.map_prefix index (fun s -> s ^ "l") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index index in + let halt = Expr.(I.(start + incr)) in + ( [ all_index; local_index ], + Expr.index all_index, + Expr.index local_index, + Zip.new_seq ~aux all_index start halt increment ) + | Tile_tile (Some tile_index, _, T_arg.{ incr; current_index; _ }) -> + let increment = get_incr incr in + let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in + let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index tile_index in + let halt = Expr.(I.(start + incr)) in + ( [ glob_index; local_index ], + Expr.index glob_index, + Expr.index local_index, + Zip.new_seq ~aux glob_index start halt increment ) + | Tile_tile (None, _, T_arg.{ incr; next_index; current_index; _ }) -> + let increment = get_incr incr in + let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in + let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index next_index in + let halt = Expr.(I.(start + incr)) in + ( [ glob_index; local_index ], + Expr.index glob_index, + Expr.index local_index, + Zip.new_seq ~aux glob_index start halt increment ) + | Same_tile T_arg.{ current_index; incr; _ } -> + let increment = get_incr incr in + let _tile_index = + Index.map_prefix current_index (fun s -> s ^ "tile") + in + let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in + let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in + let start = Expr.index current_index in + let aux = [ (local_index, Expr.zero) ] in + let halt = Expr.(I.(start + incr)) in + ( [ glob_index; local_index ], + Expr.index glob_index, + Expr.index local_index, + Zip.new_seq ~aux glob_index start halt increment ) + | Close_tile (_, _, T_arg.{ incr; current_index; _ }) -> + let increment = get_incr incr in + let local_index = Index.map_prefix current_index (fun s -> s ^ "l") in + let glob_index = Index.map_prefix current_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index current_index in + let halt = Expr.(I.(start + incr)) in + ( [ local_index; glob_index ], + Expr.index glob_index, + Expr.index local_index, + Zip.new_seq ~aux glob_index start halt increment ) + | Close_close C_arg.{ last_index; incr; dim; _ } -> + let increment = get_incr incr in + let index = Index.map_prefix last_index (fun s -> s ^ "all") in + let start = Expr.zero in + let halt = Expr.(size @@ Dim.size_id dim) in + ( [ index ], + Expr.index index, + Expr.index index, + Zip.new_seq index start halt increment ) (* Special case for join dimension, we have to match on both dim state *) - let gen_fused_join_loop to_vectorize iter_dim main_dim aux_dim main_diff aux_diff = + let gen_fused_join_loop to_vectorize iter_dim main_dim aux_dim main_diff + aux_diff = let size_of dim = Expr.size @@ Dim.size_id dim in - let get_incr incr = match (incr, to_vectorize) with - | (Expr.Const i, true) + let get_incr incr = + match (incr, to_vectorize) with + | Expr.Const i, true when i mod Inst.A.vec_size = 0 && i >= Inst.A.vec_size -> - Expr.const Inst.A.vec_size - | (Expr.SizeVar _, true) -> Expr.const Inst.A.vec_size - | _ -> Expr.one in - match aux_diff, main_diff with - Nothing_nothing _, _ -> - let indexes, local_index, global_index, loop = - gen_tile_loop to_vectorize main_diff in - let aux_index = Index.from_dim aux_dim in - aux_index::indexes, [Dim.id aux_dim, Expr.index aux_index; Dim.id main_dim, global_index], - [Dim.id aux_dim, Expr.zero; Dim.id main_dim, local_index], loop - | Close_nothing _, (Nothing_nothing _ | VU_nothing _) -> - (* Never vectorize on small dimension *) - let indexes, local_index, _, loop = - gen_tile_loop false aux_diff in - let main_index = Index.from_dim main_dim in - main_index::indexes, [Dim.id main_dim, Expr.index main_index; - Dim.id aux_dim, Expr.index @@ Index.from_dim aux_dim], - [Dim.id main_dim, Expr.zero; Dim.id aux_dim, local_index], loop - | Close_nothing _, (Close_nothing _) -> - let indexes, _, _global_index, loop = - gen_tile_loop false aux_diff in - let main_index = Index.from_dim main_dim in - main_index::indexes, [Dim.id main_dim, Expr.index main_index; Dim.id - aux_dim, Expr.index @@ Index.from_dim aux_dim], - [Dim.id main_dim, Expr.zero; Dim.id aux_dim, Expr.zero], loop - | Close_nothing _, Close_vu (_, incr, _) -> - let increment = get_incr incr in - let main_index = Index.from_dim main_dim - and aux_index = Index.from_dim aux_dim in - let global_index = Index.map_prefix main_index (fun s -> s ^ "g") in - let local_index = Index.map_prefix main_index (fun s -> s ^ "l") in - let aux = [local_index, Expr.zero] in - let start = Expr.index main_index in - let halt = Expr.(I.(start + incr)) in - [global_index; local_index], - [Dim.id main_dim, Expr.index global_index; - Dim.id aux_dim, Expr.index aux_index], - [Dim.id main_dim, Expr.index local_index; - Dim.id aux_dim, Expr.zero], - Zip.new_seq ~aux global_index start halt increment - | Close_nothing _, Close_tile (main_ind, _, T_arg.({incr;_})) -> - let increment = get_incr incr in - (* Find better names. main_ind is the next index, main_index is ranging - * from main_ind to main_ind + incr *) - let main_index = Index.from_dim main_dim - and aux_index = Index.from_dim aux_dim in - let global_index = Index.map_prefix main_index (fun s -> s ^ "g") in - let local_index = Index.map_prefix main_index (fun s -> s ^ "l") in - let aux = [local_index, Expr.zero] in - let start = Expr.index main_ind in - let halt = Expr.(I.(start + incr)) in - [global_index; local_index], - [Dim.id main_dim, Expr.index global_index; - Dim.id aux_dim, Expr.index aux_index], - [Dim.id main_dim, Expr.index local_index; - Dim.id aux_dim, Expr.zero], - Zip.new_seq ~aux global_index start halt increment - | Close_nothing _, Close_close C_arg.({incr;_}) -> - let increment = get_incr incr in - let main_index = Index.from_dim main_dim in - let aux_index = Index.from_dim aux_dim in - let global_index = Index.map_prefix main_index (fun s -> s ^ "g") in - let local_index = Index.map_prefix main_index (fun s -> s ^ "l") in - let aux = [local_index, Expr.zero] in - let start = Expr.index aux_index in - let halt = Expr.(I.(start + size_of main_dim)) in - [global_index; local_index], - [Dim.id main_dim, Expr.index global_index; - Dim.id aux_dim, Expr.index @@ Index.from_dim aux_dim], - [Dim.id main_dim, Expr.index local_index; - Dim.id aux_dim, Expr.zero], - Zip.new_seq ~aux global_index start halt increment + Expr.const Inst.A.vec_size + | Expr.SizeVar _, true -> Expr.const Inst.A.vec_size + | _ -> Expr.one + in + match (aux_diff, main_diff) with + | Nothing_nothing _, _ -> + let indexes, local_index, global_index, loop = + gen_tile_loop to_vectorize main_diff + in + let aux_index = Index.from_dim aux_dim in + ( aux_index :: indexes, + [ + (Dim.id aux_dim, Expr.index aux_index); + (Dim.id main_dim, global_index); + ], + [ (Dim.id aux_dim, Expr.zero); (Dim.id main_dim, local_index) ], + loop ) + | Close_nothing _, (Nothing_nothing _ | VU_nothing _) -> + (* Never vectorize on small dimension *) + let indexes, local_index, _, loop = gen_tile_loop false aux_diff in + let main_index = Index.from_dim main_dim in + ( main_index :: indexes, + [ + (Dim.id main_dim, Expr.index main_index); + (Dim.id aux_dim, Expr.index @@ Index.from_dim aux_dim); + ], + [ (Dim.id main_dim, Expr.zero); (Dim.id aux_dim, local_index) ], + loop ) + | Close_nothing _, Close_nothing _ -> + let indexes, _, _global_index, loop = gen_tile_loop false aux_diff in + let main_index = Index.from_dim main_dim in + ( main_index :: indexes, + [ + (Dim.id main_dim, Expr.index main_index); + (Dim.id aux_dim, Expr.index @@ Index.from_dim aux_dim); + ], + [ (Dim.id main_dim, Expr.zero); (Dim.id aux_dim, Expr.zero) ], + loop ) + | Close_nothing _, Close_vu (_, incr, _) -> + let increment = get_incr incr in + let main_index = Index.from_dim main_dim + and aux_index = Index.from_dim aux_dim in + let global_index = Index.map_prefix main_index (fun s -> s ^ "g") in + let local_index = Index.map_prefix main_index (fun s -> s ^ "l") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index main_index in + let halt = Expr.(I.(start + incr)) in + ( [ global_index; local_index ], + [ + (Dim.id main_dim, Expr.index global_index); + (Dim.id aux_dim, Expr.index aux_index); + ], + [ + (Dim.id main_dim, Expr.index local_index); + (Dim.id aux_dim, Expr.zero); + ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_nothing _, Close_tile (main_ind, _, T_arg.{ incr; _ }) -> + let increment = get_incr incr in + (* Find better names. main_ind is the next index, main_index is ranging + * from main_ind to main_ind + incr *) + let main_index = Index.from_dim main_dim + and aux_index = Index.from_dim aux_dim in + let global_index = Index.map_prefix main_index (fun s -> s ^ "g") in + let local_index = Index.map_prefix main_index (fun s -> s ^ "l") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index main_ind in + let halt = Expr.(I.(start + incr)) in + ( [ global_index; local_index ], + [ + (Dim.id main_dim, Expr.index global_index); + (Dim.id aux_dim, Expr.index aux_index); + ], + [ + (Dim.id main_dim, Expr.index local_index); + (Dim.id aux_dim, Expr.zero); + ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_nothing _, Close_close C_arg.{ incr; _ } -> + let increment = get_incr incr in + let main_index = Index.from_dim main_dim in + let aux_index = Index.from_dim aux_dim in + let global_index = Index.map_prefix main_index (fun s -> s ^ "g") in + let local_index = Index.map_prefix main_index (fun s -> s ^ "l") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index aux_index in + let halt = Expr.(I.(start + size_of main_dim)) in + ( [ global_index; local_index ], + [ + (Dim.id main_dim, Expr.index global_index); + (Dim.id aux_dim, Expr.index @@ Index.from_dim aux_dim); + ], + [ + (Dim.id main_dim, Expr.index local_index); + (Dim.id aux_dim, Expr.zero); + ], + Zip.new_seq ~aux global_index start halt increment ) | Close_nothing _, _ -> - print_endline @@ show_ordered_pair main_diff; - print_endline @@ show_ordered_pair aux_diff; - assert false + print_endline @@ show_ordered_pair main_diff; + print_endline @@ show_ordered_pair aux_diff; + assert false | Close_vu _, (Nothing_nothing _ | VU_nothing _) -> - (* Never vectorize on small dimension *) - let indexes, local_index, global_index, loop = - gen_tile_loop false aux_diff in - let main_index = Index.from_dim main_dim in - main_index::indexes, [Dim.id main_dim, Expr.index main_index; Dim.id aux_dim, global_index], - [Dim.id main_dim, Expr.zero; Dim.id aux_dim, local_index], loop - | Close_vu (_, aux_incr, _), Close_tile (main_index, _, T_arg.({incr;_})) -> - let increment = get_incr incr in - let aux_index = Index.from_dim aux_dim in - let global_index = Index.from_dim iter_dim in - let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.(I.(index main_index + index aux_index)) in - let halt = Expr.(I.( - start - + aux_incr - + incr - - const 1 - )) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment - | Close_vu (_,aux_incr,_), Close_close C_arg.({incr;_}) -> - let increment = get_incr incr in - let aux_index = Index.from_dim aux_dim in - let global_index = Index.from_dim iter_dim in - let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.index aux_index in - let halt = Expr.(I.( start - + size_of main_dim - + aux_incr - - const 1)) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment - | Close_vu (_, aux_incr, _), Close_nothing C_arg.({incr;_}) -> - let increment = get_incr incr in - let global_index = Index.from_dim iter_dim in - let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let main_index = Index.from_dim main_dim - and aux_index = Index.from_dim aux_dim in - let start = Expr.(I.(index main_index + index aux_index)) in - let halt = Expr.(I.(start + aux_incr - const 1)) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment + (* Never vectorize on small dimension *) + let indexes, local_index, global_index, loop = + gen_tile_loop false aux_diff + in + let main_index = Index.from_dim main_dim in + ( main_index :: indexes, + [ + (Dim.id main_dim, Expr.index main_index); + (Dim.id aux_dim, global_index); + ], + [ (Dim.id main_dim, Expr.zero); (Dim.id aux_dim, local_index) ], + loop ) + | Close_vu (_, aux_incr, _), Close_tile (main_index, _, T_arg.{ incr; _ }) + -> + let increment = get_incr incr in + let aux_index = Index.from_dim aux_dim in + let global_index = Index.from_dim iter_dim in + let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.(I.(index main_index + index aux_index)) in + let halt = Expr.(I.(start + aux_incr + incr - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_vu (_, aux_incr, _), Close_close C_arg.{ incr; _ } -> + let increment = get_incr incr in + let aux_index = Index.from_dim aux_dim in + let global_index = Index.from_dim iter_dim in + let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index aux_index in + let halt = Expr.(I.(start + size_of main_dim + aux_incr - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_vu (_, aux_incr, _), Close_nothing C_arg.{ incr; _ } -> + let increment = get_incr incr in + let global_index = Index.from_dim iter_dim in + let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let main_index = Index.from_dim main_dim + and aux_index = Index.from_dim aux_dim in + let start = Expr.(I.(index main_index + index aux_index)) in + let halt = Expr.(I.(start + aux_incr - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) | Close_vu _, _ -> - print_endline @@ show_ordered_pair main_diff; - print_endline @@ show_ordered_pair aux_diff; - assert false + print_endline @@ show_ordered_pair main_diff; + print_endline @@ show_ordered_pair aux_diff; + assert false | Close_close _aux_cargs, (Nothing_nothing _ | VU_nothing _) -> (* Never vectorize on small dimension *) let indexes, local_index, global_index, loop = - gen_tile_loop false aux_diff in + gen_tile_loop false aux_diff + in let main_index = Index.from_dim main_dim in - main_index::indexes, - [Dim.id main_dim, Expr.index main_index; Dim.id aux_dim, global_index], - [Dim.id main_dim, Expr.zero; Dim.id aux_dim, local_index], - loop + ( main_index :: indexes, + [ + (Dim.id main_dim, Expr.index main_index); + (Dim.id aux_dim, global_index); + ], + [ (Dim.id main_dim, Expr.zero); (Dim.id aux_dim, local_index) ], + loop ) | Close_close _, Close_vu (_, main_incr, _) -> - let increment = get_incr main_incr in - let iter_index = Index.from_dim iter_dim in - let global_index = Index.map_prefix iter_index (fun s -> s ^ "_g") in - let local_index = Index.map_prefix iter_index (fun s -> s ^ "_l") in - let main_index = Index.from_dim main_dim in - let aux = [local_index, Expr.zero] in - let start = Expr.index main_index in - let halt = Expr.(I.( start - + main_incr - + size_of aux_dim - - const 1)) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment - | Close_close _, Close_tile (ind, _, T_arg.({incr;_})) -> - let increment = get_incr incr in - let global_index = Index.from_dim iter_dim in - let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.index ind in - let halt = Expr.(I.( start - + incr - + size_of aux_dim - - const 1 - ) - ) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment - | Close_close _, Close_close C_arg.({incr;_}) -> - let increment = get_incr incr in - let global_index = Index.from_dim iter_dim in - let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let start = Expr.zero in - let halt = Expr.(I.(size_of main_dim - + size_of aux_dim - - const 1 - )) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment - | Close_close _, Close_nothing C_arg.({incr;_}) -> - let increment = get_incr incr in - let global_index = Index.from_dim iter_dim in - let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in - let aux = [local_index, Expr.zero] in - let main_index = Index.from_dim main_dim in - let start = Expr.index main_index in - let halt = Expr.(I.(start + size_of aux_dim - const 1)) in - [global_index; local_index], - [Dim.id iter_dim, Expr.index global_index], - [Dim.id iter_dim, Expr.index local_index], - Zip.new_seq ~aux global_index start halt increment + let increment = get_incr main_incr in + let iter_index = Index.from_dim iter_dim in + let global_index = Index.map_prefix iter_index (fun s -> s ^ "_g") in + let local_index = Index.map_prefix iter_index (fun s -> s ^ "_l") in + let main_index = Index.from_dim main_dim in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index main_index in + let halt = Expr.(I.(start + main_incr + size_of aux_dim - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_close _, Close_tile (ind, _, T_arg.{ incr; _ }) -> + let increment = get_incr incr in + let global_index = Index.from_dim iter_dim in + let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.index ind in + let halt = Expr.(I.(start + incr + size_of aux_dim - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_close _, Close_close C_arg.{ incr; _ } -> + let increment = get_incr incr in + let global_index = Index.from_dim iter_dim in + let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let start = Expr.zero in + let halt = Expr.(I.(size_of main_dim + size_of aux_dim - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) + | Close_close _, Close_nothing C_arg.{ incr; _ } -> + let increment = get_incr incr in + let global_index = Index.from_dim iter_dim in + let local_index = Index.map_prefix global_index (fun s -> s ^ "g") in + let aux = [ (local_index, Expr.zero) ] in + let main_index = Index.from_dim main_dim in + let start = Expr.index main_index in + let halt = Expr.(I.(start + size_of aux_dim - const 1)) in + ( [ global_index; local_index ], + [ (Dim.id iter_dim, Expr.index global_index) ], + [ (Dim.id iter_dim, Expr.index local_index) ], + Zip.new_seq ~aux global_index start halt increment ) | _ -> - print_endline @@ show_ordered_pair main_diff; - print_endline @@ show_ordered_pair aux_diff; - assert false - + print_endline @@ show_ordered_pair main_diff; + print_endline @@ show_ordered_pair aux_diff; + assert false let interval_loop_tile big_tile small_tile to_vectorize = - let ordered_pair = Option.get_exn (are_well_ordered big_tile small_tile) Ill_ordered_tiles in + let ordered_pair = + Option.get_exn (are_well_ordered big_tile small_tile) Ill_ordered_tiles + in gen_tile_loop to_vectorize ordered_pair - let interval_loop_tile_join iter_dim main_dim aux_dim - main_big_tile main_small_tile - aux_big_tile aux_small_tile - to_vectorize = - let main_ordered_pair = Option.get_exn (are_well_ordered main_big_tile - main_small_tile) Ill_ordered_tiles in - let aux_ordered_pair = Option.get_exn (are_well_ordered aux_big_tile aux_small_tile) Ill_ordered_tiles in + let interval_loop_tile_join iter_dim main_dim aux_dim main_big_tile + main_small_tile aux_big_tile aux_small_tile to_vectorize = + let main_ordered_pair = + Option.get_exn + (are_well_ordered main_big_tile main_small_tile) + Ill_ordered_tiles + in + let aux_ordered_pair = + Option.get_exn + (are_well_ordered aux_big_tile aux_small_tile) + Ill_ordered_tiles + in gen_fused_join_loop to_vectorize iter_dim main_dim aux_dim main_ordered_pair aux_ordered_pair end (* Get a frozen snapshot of the state of this dimension at a given moment *) let freeze dim = function - Some {dim; incr; is_vectorized; is_closed; bump_info; - current_index; next_index; indexes_list;_} -> - let open Frozen in - let indexes_list = List.map (fun (ind, _, expr_list) -> ind, expr_list) indexes_list in - begin match bump_info with + | Some + { + dim; + incr; + is_vectorized; + is_closed; + bump_info; + current_index; + next_index; + indexes_list; + _; + } -> ( + let open Frozen in + let indexes_list = + List.map (fun (ind, _, expr_list) -> (ind, expr_list)) indexes_list + in + match bump_info with | Some (current_tile, current_pack) -> - if is_closed - then - Close {dim; incr; is_vectorized; last_tile = Some current_tile; - last_index = current_index; indexes_list} - else - Tile {dim; incr; is_vectorized; current_index; next_index; - current_pack; current_tile; indexes_list} + if is_closed then + Close + { + dim; + incr; + is_vectorized; + last_tile = Some current_tile; + last_index = current_index; + indexes_list; + } + else + Tile + { + dim; + incr; + is_vectorized; + current_index; + next_index; + current_pack; + current_tile; + indexes_list; + } | None -> - if is_closed - then - Close {dim; incr; is_vectorized; last_tile = None; - last_index = current_index; indexes_list} - else VecUnr (dim, incr, is_vectorized) - end - | None -> Nothing dim - - let current_tile {bump_info;_} = Option.map fst bump_info - let current_pack {bump_info;_} = Option.map snd bump_info - - let bump_index ({next_index; index_clone; bump_info; current_index; indexes_list; _} as elem) = + if is_closed then + Close + { + dim; + incr; + is_vectorized; + last_tile = None; + last_index = current_index; + indexes_list; + } + else VecUnr (dim, incr, is_vectorized)) + | None -> Nothing dim + + let current_tile { bump_info; _ } = Option.map fst bump_info + let current_pack { bump_info; _ } = Option.map snd bump_info + + let bump_index + ({ next_index; index_clone; bump_info; current_index; indexes_list; _ } as + elem) = let new_next = index_clone () in match bump_info with | None -> - let current_pack, new_gen = Index.fresh_clone_mark current_index "p" () in - let indexes_list = [current_index, new_gen, [current_pack, Expr.zero]] in - {elem with next_index = new_next; bump_info = Some (current_pack, current_pack); - indexes_list; current_index = next_index} + let current_pack, new_gen = + Index.fresh_clone_mark current_index "p" () + in + let indexes_list = + [ (current_index, new_gen, [ (current_pack, Expr.zero) ]) ] + in + { + elem with + next_index = new_next; + bump_info = Some (current_pack, current_pack); + indexes_list; + current_index = next_index; + } | Some _ -> - let fold_aux (new_cloned_index, acc) (ind, ind_clone, ind_list) = - let new_clone = ind_clone () in - let ind_list = (new_clone, Expr.index new_cloned_index)::ind_list in - (new_clone, (ind, ind_clone, ind_list)::acc) in - let base_new_index, new_gen = Index.fresh_clone_mark current_index "p" () in - let last_cloned, folded_list = List.fold_left fold_aux (base_new_index, []) indexes_list in - let indexes_list = (current_index, new_gen, [base_new_index, Expr.zero])::(List.rev folded_list) in - {elem with next_index = new_next; bump_info= Some (base_new_index, last_cloned); - indexes_list; current_index = next_index} - - let is_closed {is_closed;_} = is_closed - - let vectorize elem = assert (not (elem.is_vectorized)); - {elem with incr = Expr.const Inst.A.vec_size; - loop_scheme = V::elem.loop_scheme; - div_constraint = Div; - indexes_list = []; is_vectorized = true} - - + let fold_aux (new_cloned_index, acc) (ind, ind_clone, ind_list) = + let new_clone = ind_clone () in + let ind_list = (new_clone, Expr.index new_cloned_index) :: ind_list in + (new_clone, (ind, ind_clone, ind_list) :: acc) + in + let base_new_index, new_gen = + Index.fresh_clone_mark current_index "p" () + in + let last_cloned, folded_list = + List.fold_left fold_aux (base_new_index, []) indexes_list + in + let indexes_list = + (current_index, new_gen, [ (base_new_index, Expr.zero) ]) + :: List.rev folded_list + in + { + elem with + next_index = new_next; + bump_info = Some (base_new_index, last_cloned); + indexes_list; + current_index = next_index; + } + + let is_closed { is_closed; _ } = is_closed + + let vectorize elem = + assert (not elem.is_vectorized); + { + elem with + incr = Expr.const Inst.A.vec_size; + loop_scheme = V :: elem.loop_scheme; + div_constraint = Div; + indexes_list = []; + is_vectorized = true; + } - let close ({is_closed; dim; incr; bump_info; indexes_list; current_index;_}as elem) = if is_closed - then (Printf.printf "Dim %s is already closed\n" (Dim.show dim); assert false) + let close + ({ is_closed; dim; incr; bump_info; indexes_list; current_index; _ } as + elem) = + if is_closed then ( + Printf.printf "Dim %s is already closed\n" (Dim.show dim); + assert false) else let clone = Index.clone current_index in - let indexes_list = (current_index, clone, [])::indexes_list in + let indexes_list = (current_index, clone, []) :: indexes_list in let last_tile_size = Some incr in let incr = Expr.size @@ Dim.size_id dim in (* TODO what is the right thing to do here ? *) - let bump_info = match bump_info with + let bump_info = + match bump_info with | None -> Some (Index.from_dim dim, Index.from_dim dim) - | _ -> bump_info in - {elem with indexes_list; bump_info; last_tile_size; incr; is_closed = true} + | _ -> bump_info + in + { + elem with + indexes_list; + bump_info; + last_tile_size; + incr; + is_closed = true; + } (* Apply a function to incr *) - let map_incr f elem = {elem with incr = f elem.incr} - (* Apply a function to incr *) - let set_par elem = {elem with par = true} + let map_incr f elem = { elem with incr = f elem.incr } - let set_div_constraint div_constraint elem = {elem with div_constraint} + (* Apply a function to incr *) + let set_par elem = { elem with par = true } + let set_div_constraint div_constraint elem = { elem with div_constraint } - let cur_bound_var {current_index;_} = "N_" ^ Index.show_id_of_t current_index + let cur_bound_var { current_index; _ } = + "N_" ^ Index.show_id_of_t current_index - let next_bound_var {next_index;_} = "N_" ^ Index.show_id_of_t next_index + let next_bound_var { next_index; _ } = "N_" ^ Index.show_id_of_t next_index let default dim = let current_index = Index.from_dim dim in let index_clone = Index.clone current_index in let next_index = index_clone () in - {dim; par = false; incr = Expr.one; - div_constraint = Flexible 1; - last_tile_size = None; current_index; - next_index ; loop_scheme = []; index_clone; - bump_info = None; is_closed = false; is_vectorized = false; indexes_list = [] ; + { + dim; + par = false; + incr = Expr.one; + div_constraint = Flexible 1; + last_tile_size = None; + current_index; + next_index; + loop_scheme = []; + index_clone; + bump_info = None; + is_closed = false; + is_vectorized = false; + indexes_list = []; } end diff --git a/ml/lib/exprs.ml b/ml/lib/exprs.ml index 5eec9025550906261deb8d9ec674c7b33ac9c854..cc560aacb6d24a1c0a3dea08ddfbd4db3e7a549a 100644 --- a/ml/lib/exprs.ml +++ b/ml/lib/exprs.ml @@ -2,7 +2,6 @@ open Utils open Ids module Expr = struct - type t = | Const of int | One @@ -17,33 +16,29 @@ module Expr = struct | IndVar of Index.t let is_atomic = function - | Const _ | One | Zero | IndVar _ | PlaceHolder _ | Var _ | SizeVar _ -> true + | Const _ | One | Zero | IndVar _ | PlaceHolder _ | Var _ | SizeVar _ -> + true | _ -> false let rec is_constant = function | IndVar _ | Var _ -> false - | Zero | One | Const _ | SizeVar _ | PlaceHolder _ -> true + | Zero | One | Const _ | SizeVar _ | PlaceHolder _ -> true | Add (e1, e2) | Sub (e1, e2) | Mul (e1, e2) | Min (e1, e2) -> - is_constant e1 && is_constant e2 - + is_constant e1 && is_constant e2 let rec show = - let show_paren e = if is_atomic e - then show e - else Format.sprintf "(%s)" (show e) in - function + let show_paren e = + if is_atomic e then show e else Format.sprintf "(%s)" (show e) + in + function | Const i -> Format.sprintf "%d" i | Zero -> "0" | One -> "1" | Var name -> name - | Add (e1, e2) -> Format.sprintf "%s + %s" (show_paren e1) - (show_paren e2) - | Sub (e1, e2) -> Format.sprintf "%s - %s" (show_paren e1) (show_paren e2) - | Mul (e1, e2) -> - Format.sprintf "%s * %s" (show_paren e1) (show_paren e2) - | Min (e1, e2) -> - Format.sprintf "MIN(%s, %s)" (show e1) - (show e2) + | Add (e1, e2) -> Format.sprintf "%s + %s" (show_paren e1) (show_paren e2) + | Sub (e1, e2) -> Format.sprintf "%s - %s" (show_paren e1) (show_paren e2) + | Mul (e1, e2) -> Format.sprintf "%s * %s" (show_paren e1) (show_paren e2) + | Min (e1, e2) -> Format.sprintf "MIN(%s, %s)" (show e1) (show e2) | PlaceHolder s -> s | SizeVar ind_id -> Format.sprintf "%s" (Size.show_id ind_id) | IndVar index -> Format.sprintf "%s" (Index.show_id (Index.id index)) @@ -51,46 +46,42 @@ module Expr = struct let pp fmt expr = Format.fprintf fmt "%s" (show expr) let rec recursor f = function - | (Const _) | One | Zero as expr -> f expr - | Add (e1, e2) -> begin - match (recursor f e1, recursor f e2) with - e1, Zero -> e1 + | (Const _ | One | Zero) as expr -> f expr + | Add (e1, e2) -> ( + match (recursor f e1, recursor f e2) with + | e1, Zero -> e1 | Zero, e2 -> e2 | Const c1, Const c2 -> Const (c1 + c2) - | e1, e2 -> Add (e1, e2) - end - | Sub (e1, e2) -> begin - match (recursor f e1, recursor f e2) with - e1, Zero -> e1 + | e1, e2 -> Add (e1, e2)) + | Sub (e1, e2) -> ( + match (recursor f e1, recursor f e2) with + | e1, Zero -> e1 | Const c1, Const c2 -> Const (c1 - c2) - | e1, e2 -> Sub (e1, e2) - end - | Mul (e1, e2) -> - begin - match (recursor f e1, recursor f e2) with - _, Zero -> Zero + | e1, e2 -> Sub (e1, e2)) + | Mul (e1, e2) -> ( + match (recursor f e1, recursor f e2) with + | _, Zero -> Zero | Zero, _ -> Zero | e1, One -> e1 | One, e2 -> e2 | Const c1, Const c2 -> Const (c1 * c2) - | e1, e2 -> Mul (e1, e2) - end - | Min (e1, e2) -> begin + | e1, e2 -> Mul (e1, e2)) + | Min (e1, e2) -> ( match (recursor f e1, recursor f e2) with | _, Zero | Zero, _ -> Zero | One, Const c | Const c, One -> if c < 1 then Zero else One | Const c1, Const c2 -> Const (min c1 c2) - | e1, e2 -> Min (e1, e2) - end + | e1, e2 -> Min (e1, e2)) | SizeVar _ as sv -> f sv | PlaceHolder _ as ph -> f ph | IndVar _ as e -> f e | Var _ as e -> f e - let rec equal e1 e2 = match e1, e2 with + let rec equal e1 e2 = + match (e1, e2) with | Const c1, Const c2 -> Int.equal c1 c2 - | One, One -> true - | Zero, Zero -> true + | One, One -> true + | Zero, Zero -> true | Add (e11, e12), Add (e21, e22) -> equal e11 e21 && equal e12 e22 | Sub (e11, e12), Sub (e21, e22) -> equal e11 e21 && equal e12 e22 | Mul (e11, e12), Mul (e21, e22) -> equal e11 e21 && equal e12 e22 @@ -98,190 +89,162 @@ module Expr = struct | SizeVar sid1, SizeVar sid2 -> Size.equal_id sid1 sid2 | PlaceHolder s1, PlaceHolder s2 -> String.equal s1 s2 | IndVar vid1, IndVar vid2 -> Index.equal vid1 vid2 - | Var n1, Var n2 -> String.equal n1 n2 + | Var n1, Var n2 -> String.equal n1 n2 | _ -> false - - let one = One - let zero = Zero + let const i = match i with 0 -> Zero | 1 -> One | _ -> Const i - let const i = match i with - | 0 -> Zero - | 1 -> One - | _ -> Const i - - let add e1 e2 = match e1, e2 with + let add e1 e2 = + match (e1, e2) with | Zero, e | e, Zero -> e | Const c1, Const c2 -> Const (c1 + c2) | _ -> Add (e1, e2) - let sub e1 e2 = match e1, e2 with + let sub e1 e2 = + match (e1, e2) with | e, Zero -> e | Const c1, Const c2 -> Const (c1 - c2) | _ -> Sub (e1, e2) - let mul e1 e2 = match e1, e2 with + let mul e1 e2 = + match (e1, e2) with | One, e | e, One -> e | Zero, _ | _, Zero -> Zero | Const c1, Const c2 -> Const (c1 * c2) | _ -> Mul (e1, e2) let min e1 e2 = - match e1, e2 with + match (e1, e2) with | _, Zero | Zero, _ -> Zero | One, Const c | Const c, One -> if c < 1 then Zero else One | Const c1, Const c2 -> Const (min c1 c2) | e1, e2 -> Min (e1, e2) let rec to_int_opt sizes_indexes = - let open Opt_syntax in function + let open Opt_syntax in + function | SizeVar id -> List.assoc_opt id sizes_indexes | PlaceHolder _ -> None | Const i -> Some i | One -> Some 1 | Zero -> Some 0 | Add (x, y) -> - let+ x' = to_int_opt sizes_indexes x - and+ y' = to_int_opt sizes_indexes y in - x' + y' + let+ x' = to_int_opt sizes_indexes x + and+ y' = to_int_opt sizes_indexes y in + x' + y' | Sub (x, y) -> - let+ x' = to_int_opt sizes_indexes x - and+ y' = to_int_opt sizes_indexes y in - x' - y' + let+ x' = to_int_opt sizes_indexes x + and+ y' = to_int_opt sizes_indexes y in + x' - y' | Mul (x, y) -> - let+ x' = to_int_opt sizes_indexes x - and+ y' = to_int_opt sizes_indexes y in - x' * y' + let+ x' = to_int_opt sizes_indexes x + and+ y' = to_int_opt sizes_indexes y in + x' * y' | Min (x, y) -> - let+ x' = to_int_opt sizes_indexes x - and+ y' = to_int_opt sizes_indexes y in - Int.min x' y' + let+ x' = to_int_opt sizes_indexes x + and+ y' = to_int_opt sizes_indexes y in + Int.min x' y' | _ -> None let simplify' = let f = function - | Add (Zero, e) | Add (e, Zero) -> e - | Mul (Zero, _) | Mul (_, Zero) -> Zero - | Mul (One, e) | Mul (e, One) -> e - | Mul (Const c1, Const c2) -> Const (c1 * c2) - | Add (Const c1, Const c2) -> Const (c1 + c2) - | Sub (Const c1, Const c2) -> Const (c1 - c2) - | e -> e in + | Add (Zero, e) | Add (e, Zero) -> e + | Mul (Zero, _) | Mul (_, Zero) -> Zero + | Mul (One, e) | Mul (e, One) -> e + | Mul (Const c1, Const c2) -> Const (c1 * c2) + | Add (Const c1, Const c2) -> Const (c1 + c2) + | Sub (Const c1, Const c2) -> Const (c1 - c2) + | e -> e + in recursor f - let rec fix eq f e = - if eq (f e) e then e else fix eq f (f e) - + let rec fix eq f e = if eq (f e) e then e else fix eq f (f e) let simplify = fix equal simplify' + (* This could be nicer with Algebra *) - let rec expand e = match e with - | Const _ - | One - | Zero - | PlaceHolder _ - | SizeVar _ - | Var _ - | IndVar _ as e -> e - | Add (e1, e2) -> Add( expand e1, expand e2) - | Sub (e1, e2) -> Sub( expand e1, expand e2) - | Min (e1, e2) -> Min( expand e1, expand e2) - | Mul(e1, e2) -> - begin match expand e1, expand e2 with - | Add(e1, e2), Add(e3, e4) -> - Add(Add( Mul(e1, e3), Mul(e1, e4)), - Add(Mul(e2, e3), Mul(e2, e4))) - | Add(e1, e2), Sub(e3, e4) -> - Add(Sub( Mul(e1, e3), Mul(e1, e4)), - Sub(Mul(e2, e3), Mul(e2, e4))) - | Sub(e1, e2), Add(e3, e4) -> - Sub(Add( Mul(e1, e3), Mul(e1, e4)), - Add(Mul(e2, e3), Mul(e2, e4))) - | Sub(e1, e2), Sub(e3, e4) -> - Add(Sub( Mul(e1, e3), Mul(e1, e4)), - Sub(Mul(e2, e4), Mul(e2, e3))) - | e1, Add(e2, e3) -> - Add( Mul(e1, e2), Mul(e1, e3)) - | Add(e1, e2), e3 -> - Add( Mul(e1, e3), Mul(e2, e3)) - | e1, Sub(e2, e3) -> - Sub( Mul(e1, e2), Mul(e1, e3)) - | Sub(e1, e2), e3 -> - Sub( Mul(e1, e3), Mul(e2, e3)) - | e1, e2 -> - Mul(e1, e2) - end + let rec expand e = + match e with + | (Const _ | One | Zero | PlaceHolder _ | SizeVar _ | Var _ | IndVar _) as e + -> + e + | Add (e1, e2) -> Add (expand e1, expand e2) + | Sub (e1, e2) -> Sub (expand e1, expand e2) + | Min (e1, e2) -> Min (expand e1, expand e2) + | Mul (e1, e2) -> ( + match (expand e1, expand e2) with + | Add (e1, e2), Add (e3, e4) -> + Add + ( Add (Mul (e1, e3), Mul (e1, e4)), + Add (Mul (e2, e3), Mul (e2, e4)) ) + | Add (e1, e2), Sub (e3, e4) -> + Add + ( Sub (Mul (e1, e3), Mul (e1, e4)), + Sub (Mul (e2, e3), Mul (e2, e4)) ) + | Sub (e1, e2), Add (e3, e4) -> + Sub + ( Add (Mul (e1, e3), Mul (e1, e4)), + Add (Mul (e2, e3), Mul (e2, e4)) ) + | Sub (e1, e2), Sub (e3, e4) -> + Add + ( Sub (Mul (e1, e3), Mul (e1, e4)), + Sub (Mul (e2, e4), Mul (e2, e3)) ) + | e1, Add (e2, e3) -> Add (Mul (e1, e2), Mul (e1, e3)) + | Add (e1, e2), e3 -> Add (Mul (e1, e3), Mul (e2, e3)) + | e1, Sub (e2, e3) -> Sub (Mul (e1, e2), Mul (e1, e3)) + | Sub (e1, e2), e3 -> Sub (Mul (e1, e3), Mul (e2, e3)) + | e1, e2 -> Mul (e1, e2)) type sign = Plus | Minus - let negate = function Plus -> Minus - | Minus -> Plus + let negate = function Plus -> Minus | Minus -> Plus - let sign_mul s1 s2 = match s1, s2 with - Plus, Minus | Minus, Plus -> Minus - | _ -> Plus + let sign_mul s1 s2 = + match (s1, s2) with Plus, Minus | Minus, Plus -> Minus | _ -> Plus - let rec normalize e = let open Utils in + let rec normalize e = + let open Utils in match e with - | Const _ - | One - | Zero - | PlaceHolder _ - | SizeVar _ - | Var _ - | IndVar _ -> Left (Plus, [e]) - | Min (e1, e2) -> - begin match normalize e1, normalize e2 with - Left p1, Left p2 -> - Right ([p1; p2]) - | Right e1, Left p2 -> - Right (e1 @ [p2]) - | Left p1, Right l2 -> - Right (p1 :: l2) - | Right p1, Right p2 -> - Right (p1 @ p2) - end - | Add (e1, e2) -> - begin match normalize e1, normalize e2 with - Left p1, Left p2 -> - Right ([p1; p2]) - | Right e1, Left p2 -> - Right (e1 @ [p2]) - | Left p1, Right l2 -> - Right (p1 :: l2) - | Right p1, Right p2 -> - Right (p1 @ p2) - end - | Sub (e1, e2) -> - begin match normalize e1, normalize e2 with - Left p1, Left (s2, e2) -> - Right ([p1; negate s2, e2]) - | Right e1, Left (s2, e2) -> - Right (e1 @ [negate s2, e2]) - | Left p1, Right l2 -> - Right (p1 :: List.map (fun (s, e) -> negate s, e) l2) - | Right l1, Right l2 -> - Right (l1 @ List.map (fun (s, e) -> negate s, e) l2) - end - | Mul(e1, e2) -> - let distribute concat p lp = List.map - (fun (s2, l2) -> - let (s1, l1) = p in - sign_mul s1 s2, concat l1 l2) lp in - let distribute_left = distribute List.append in - let distribute_right = distribute (Fun.flip List.append) in - begin match normalize e1, normalize e2 with - | Left (s1,l1), Left (s2, l2) -> - Left (sign_mul s1 s2, l1 @ l2) - | Left p1, Right l2 -> - Right (distribute_left p1 l2) - | Right l1, Left p2 -> - Right (distribute_right p2 l1) + | Const _ | One | Zero | PlaceHolder _ | SizeVar _ | Var _ | IndVar _ -> + Left (Plus, [ e ]) + | Min (e1, e2) -> ( + match (normalize e1, normalize e2) with + | Left p1, Left p2 -> Right [ p1; p2 ] + | Right e1, Left p2 -> Right (e1 @ [ p2 ]) + | Left p1, Right l2 -> Right (p1 :: l2) + | Right p1, Right p2 -> Right (p1 @ p2)) + | Add (e1, e2) -> ( + match (normalize e1, normalize e2) with + | Left p1, Left p2 -> Right [ p1; p2 ] + | Right e1, Left p2 -> Right (e1 @ [ p2 ]) + | Left p1, Right l2 -> Right (p1 :: l2) + | Right p1, Right p2 -> Right (p1 @ p2)) + | Sub (e1, e2) -> ( + match (normalize e1, normalize e2) with + | Left p1, Left (s2, e2) -> Right [ p1; (negate s2, e2) ] + | Right e1, Left (s2, e2) -> Right (e1 @ [ (negate s2, e2) ]) + | Left p1, Right l2 -> + Right (p1 :: List.map (fun (s, e) -> (negate s, e)) l2) | Right l1, Right l2 -> - let l_fused = List.concat_map (fun p -> distribute_left p l2) l1 in - Right l_fused - end + Right (l1 @ List.map (fun (s, e) -> (negate s, e)) l2)) + | Mul (e1, e2) -> ( + let distribute concat p lp = + List.map + (fun (s2, l2) -> + let s1, l1 = p in + (sign_mul s1 s2, concat l1 l2)) + lp + in + let distribute_left = distribute List.append in + let distribute_right = distribute (Fun.flip List.append) in + match (normalize e1, normalize e2) with + | Left (s1, l1), Left (s2, l2) -> Left (sign_mul s1 s2, l1 @ l2) + | Left p1, Right l2 -> Right (distribute_left p1 l2) + | Right l1, Left p2 -> Right (distribute_right p2 l1) + | Right l1, Right l2 -> + let l_fused = List.concat_map (fun p -> distribute_left p l2) l1 in + Right l_fused) let size sid = SizeVar sid @@ -289,60 +252,62 @@ module Expr = struct * (potentially living in another "space") *) let rec map_size_index f = function | SizeVar id -> f id - | Const _ | PlaceHolder _ | One | Zero as c -> c + | (Const _ | PlaceHolder _ | One | Zero) as c -> c | e -> recursor (map_size_index f) e let instanciate_placeholder var c = - let f = function PlaceHolder s when String.equal var s -> Const c - | e -> e in + let f = function + | PlaceHolder s when String.equal var s -> Const c + | e -> e + in recursor f (* contain_var [var_expr] [var_id] returns true if Variable var_id is used in var_expr *) - let rec contain_var var_expr var_id = match var_expr with + let rec contain_var var_expr var_id = + match var_expr with | IndVar index -> Index.equal index var_id - | Zero | One | Const _ - | Var _ | SizeVar _ | PlaceHolder _ -> false + | Zero | One | Const _ | Var _ | SizeVar _ | PlaceHolder _ -> false | Add (e1, e2) | Sub (e1, e2) | Mul (e1, e2) | Min (e1, e2) -> - contain_var e1 var_id || contain_var e2 var_id + contain_var e1 var_id || contain_var e2 var_id let index ind_id = IndVar ind_id (* contain_var [var_expr] [var_id] returns true if Variable var_id is used in var_expr *) - let rec access_dim_id var_expr dim = match var_expr with + let rec access_dim_id var_expr dim = + match var_expr with | IndVar index -> Dim.equal_id (Index.dim index) dim - | Zero | One | Const _ -> false + | Zero | One | Const _ -> false | Var _ | SizeVar _ | PlaceHolder _ -> false | Add (e1, e2) | Sub (e1, e2) | Mul (e1, e2) | Min (e1, e2) -> - access_dim_id e1 dim || access_dim_id e2 dim + access_dim_id e1 dim || access_dim_id e2 dim (* contain_var [var_expr] [var_id] returns true if Variable var_id is used in var_expr *) - let access_dim var_expr dim = - access_dim_id var_expr (Dim.id dim) + let access_dim var_expr dim = access_dim_id var_expr (Dim.id dim) + (* Apply a function on all indexes by recursing on an expression *) let rec map_index f = function | IndVar vid -> f vid - | Const _ | One | Zero | SizeVar _ | PlaceHolder _ as e -> e + | (Const _ | One | Zero | SizeVar _ | PlaceHolder _) as e -> e | e -> recursor (map_index f) e let alpha_replace_pred pred var_expr expr = - let replace = fun vid -> if pred vid then var_expr else IndVar vid in + let replace vid = if pred vid then var_expr else IndVar vid in map_index replace expr let alpha_replace var_id var_expr expr = - alpha_replace_pred (Index.equal var_id) var_expr expr + alpha_replace_pred (Index.equal var_id) var_expr expr let alpha_replace_dim dim var_expr expr = - alpha_replace_pred (fun vid -> Dim.equal_id (Index.dim vid) dim) var_expr expr - - let alpha_conv old_vid new_vid = - alpha_replace old_vid (IndVar new_vid) + alpha_replace_pred + (fun vid -> Dim.equal_id (Index.dim vid) dim) + var_expr expr - let prop_const old_vid const expr = - alpha_replace old_vid (Const const) expr + let alpha_conv old_vid new_vid = alpha_replace old_vid (IndVar new_vid) + let prop_const old_vid const expr = alpha_replace old_vid (Const const) expr module I = struct - let (+) = add - let (-) = sub + let ( + ) = add + let ( - ) = sub let ( * ) = mul end end diff --git a/ml/lib/inst_sign.ml b/ml/lib/inst_sign.ml index 92e13e75922f8874df82fe94be4c3658d67cee80..19dc8ec51430bc36770796ece4dfaf53fa31b922 100644 --- a/ml/lib/inst_sign.ml +++ b/ml/lib/inst_sign.ml @@ -3,7 +3,7 @@ open Ids open Exprs module Concrete_types = struct - type inst_type = Vec | Scal [@@deriving eq, show {with_path = false}] + type inst_type = Vec | Scal [@@deriving eq, show { with_path = false }] type t = | Nop @@ -11,11 +11,12 @@ module Concrete_types = struct * in some way "out of scope", we don't expect to generate a fresh name for * them. Still uneasy with that though *) | Allocated of inst_type * string - - | External_call of {name: string ; - tensors_info : (Tensor.t * - Tensor.accesses) list; - var_exprs: Expr.t list} (* size that may vary *) + | External_call of { + name : string; + tensors_info : (Tensor.t * Tensor.accesses) list; + var_exprs : Expr.t list; + } + (* size that may vary *) (* void microk(float* C, float * A, float * B, * int strideI, int strideJ, int strideK)*) | Comment of string @@ -30,28 +31,44 @@ module Concrete_types = struct | Vsub of t * t | Vmul of t * t | Vbcst of t - | Vtranspose of Tensor.t * Tensor.accesses * Expr.t * Tensor.t* - Tensor.accesses * Expr.t - | Vwrite of t * Tensor.t * Tensor.accesses [@@deriving eq, show] + | Vtranspose of + Tensor.t + * Tensor.accesses + * Expr.t + * Tensor.t + * Tensor.accesses + * Expr.t + | Vwrite of t * Tensor.t * Tensor.accesses + [@@deriving eq, show] end module type Inst_t = sig - module A: Vec_arch_t - + module A : Vec_arch_t include module type of Concrete_types - val vectorize_on_dim: Dim.id -> t -> t - val map_expr: Dim.id -> (Expr.t -> Expr.t ) -> t -> t - val access: Dim.id -> t -> bool - val is_tensor_readonly: t list -> Tensor.t -> bool + val vectorize_on_dim : Dim.id -> t -> t + val map_expr : Dim.id -> (Expr.t -> Expr.t) -> t -> t + val access : Dim.id -> t -> bool + val is_tensor_readonly : t list -> Tensor.t -> bool type mem_accesses_inf = (Tensor.t * Tensor.accesses) * (inst_type * string) - val fact_loads_on_dims: Dim.id list -> mem_accesses_inf list -> t list - -> mem_accesses_inf list * t list * t list - val fact_stores_on_dims: Dim.id list -> mem_accesses_inf list -> t list - -> mem_accesses_inf list * t list * t list - val swap_tensor: Tensor.t -> Tensor.t -> t list -> t list - val swap_accesses: (Tensor.t -> Tensor.accesses -> Tensor.accesses) -> t list -> t list - val gen_code: string -> t list -> string list * string list * string -end + val fact_loads_on_dims : + Dim.id list -> + mem_accesses_inf list -> + t list -> + mem_accesses_inf list * t list * t list + + val fact_stores_on_dims : + Dim.id list -> + mem_accesses_inf list -> + t list -> + mem_accesses_inf list * t list * t list + + val swap_tensor : Tensor.t -> Tensor.t -> t list -> t list + + val swap_accesses : + (Tensor.t -> Tensor.accesses -> Tensor.accesses) -> t list -> t list + + val gen_code : string -> t list -> string list * string list * string +end diff --git a/ml/lib/instruction.ml b/ml/lib/instruction.ml index 934b1f42f4ff0ec11fdff9962b1497d011f2289d..b7517317aa715d256bb84aaa2a2cdc9d0e7b0579 100644 --- a/ml/lib/instruction.ml +++ b/ml/lib/instruction.ml @@ -5,49 +5,63 @@ module T = Tensor open Arch module type Inst_t = Inst_sign.Inst_t -module I(V: Vec_arch_t): Inst_t = struct + +module I (V : Vec_arch_t) : Inst_t = struct module A = V (* Definition for type inst *) include Inst_sign.Concrete_types - let string_of_inst_type = function - | Vec -> "vec" - | Scal -> "scal" + let string_of_inst_type = function Vec -> "vec" | Scal -> "scal" (* Allocated and Assign are supposed not to access any dimension, which is ... doubtful *) let rec access dim = function - | Read (tensor, accesses) | Vread (tensor, accesses) -> T.does_access tensor accesses dim - | Write (expr, tensor, accesses) | Vwrite (expr, tensor, accesses) -> access dim expr - || T.does_access tensor accesses dim - | External_call {tensors_info;_} -> - List.exists (fun (t, acc) -> T.does_access t acc dim) tensors_info - | Add (e1, e2) | Mul (e1, e2) | Sub (e1, e2) - | Vadd (e1, e2) | Vmul (e1, e2) | Vsub (e1, e2) -> access dim e1 || access dim e2 + | Read (tensor, accesses) | Vread (tensor, accesses) -> + T.does_access tensor accesses dim + | Write (expr, tensor, accesses) | Vwrite (expr, tensor, accesses) -> + access dim expr || T.does_access tensor accesses dim + | External_call { tensors_info; _ } -> + List.exists (fun (t, acc) -> T.does_access t acc dim) tensors_info + | Add (e1, e2) + | Mul (e1, e2) + | Sub (e1, e2) + | Vadd (e1, e2) + | Vmul (e1, e2) + | Vsub (e1, e2) -> + access dim e1 || access dim e2 | _ -> false - let rec map_expr dim f inst = match inst with - | Nop - | Comment _ - | Allocated _ - | Assign _ - | Vtranspose _ -> inst - | External_call ({tensors_info;_} as arg) -> - let tensors_info = List.map (fun (t, acc) -> - t, List.modify_assoc Dim.equal_id (Option.map f) dim acc) tensors_info in - External_call {arg with tensors_info} - | Read(s, access_list) -> Read(s, List.modify_assoc Dim.equal_id (Option.map f) dim access_list) - | Write(value, tensor, access_list) -> - Write(map_expr dim f value, tensor, List.modify_assoc Dim.equal_id (Option.map f) dim access_list) - | Add(v1, v2) -> Add(map_expr dim f v1, map_expr dim f v2) - | Sub(v1, v2) -> Sub(map_expr dim f v1, map_expr dim f v2) - | Mul(v1, v2) -> Mul(map_expr dim f v1, map_expr dim f v2) - | Vread(s, access_list) -> Vread(s, List.modify_assoc Dim.equal_id (Option.map f) dim access_list) - | Vwrite(value, tensor, access_list) -> - Vwrite(map_expr dim f value, tensor, List.modify_assoc Dim.equal_id (Option.map f) dim access_list) - | Vadd(v1, v2) -> Vadd(map_expr dim f v1, map_expr dim f v2) - | Vsub(v1, v2) -> Vsub(map_expr dim f v1, map_expr dim f v2) - | Vmul(v1, v2) -> Vmul(map_expr dim f v1, map_expr dim f v2) + let rec map_expr dim f inst = + match inst with + | Nop | Comment _ | Allocated _ | Assign _ | Vtranspose _ -> inst + | External_call ({ tensors_info; _ } as arg) -> + let tensors_info = + List.map + (fun (t, acc) -> + (t, List.modify_assoc Dim.equal_id (Option.map f) dim acc)) + tensors_info + in + External_call { arg with tensors_info } + | Read (s, access_list) -> + Read (s, List.modify_assoc Dim.equal_id (Option.map f) dim access_list) + | Write (value, tensor, access_list) -> + Write + ( map_expr dim f value, + tensor, + List.modify_assoc Dim.equal_id (Option.map f) dim access_list ) + | Add (v1, v2) -> Add (map_expr dim f v1, map_expr dim f v2) + | Sub (v1, v2) -> Sub (map_expr dim f v1, map_expr dim f v2) + | Mul (v1, v2) -> Mul (map_expr dim f v1, map_expr dim f v2) + | Vread (s, access_list) -> + Vread (s, List.modify_assoc Dim.equal_id (Option.map f) dim access_list) + | Vwrite (value, tensor, access_list) -> + Vwrite + ( map_expr dim f value, + tensor, + List.modify_assoc Dim.equal_id (Option.map f) dim access_list ) + | Vadd (v1, v2) -> Vadd (map_expr dim f v1, map_expr dim f v2) + | Vsub (v1, v2) -> Vsub (map_expr dim f v1, map_expr dim f v2) + | Vmul (v1, v2) -> Vmul (map_expr dim f v1, map_expr dim f v2) | Vbcst value -> Vbcst (map_expr dim f value) let rec vectorize_on_dim dim = @@ -56,232 +70,286 @@ module I(V: Vec_arch_t): Inst_t = struct * vectorized dimension, we can just broadcast the result. Else we vectorize * both expression *) let vec_pair e1 e2 = - match access dim e1, access dim e2 with + match (access dim e1, access dim e2) with | false, false -> Left (e1, e2) | true, false -> - let ve2 = Vbcst e2 - and ve1 = vectorize_on_dim dim e1 in - Right (ve1, ve2) + let ve2 = Vbcst e2 and ve1 = vectorize_on_dim dim e1 in + Right (ve1, ve2) | false, true -> - let ve1 = Vbcst e1 - and ve2 = vectorize_on_dim dim e2 in - Right (ve1, ve2) + let ve1 = Vbcst e1 and ve2 = vectorize_on_dim dim e2 in + Right (ve1, ve2) | true, true -> - let ve1 = vectorize_on_dim dim e1 - and ve2 = vectorize_on_dim dim e2 in - Right (ve1, ve2) in function - | Read (tensor, accesses) -> if T.does_access tensor accesses dim - then Vread (tensor, accesses) + let ve1 = vectorize_on_dim dim e1 and ve2 = vectorize_on_dim dim e2 in + Right (ve1, ve2) + in + function + | Read (tensor, accesses) -> + if T.does_access tensor accesses dim then Vread (tensor, accesses) else Vbcst (Read (tensor, accesses)) - | Add (e1, e2) -> - (match vec_pair e1 e2 with - | Left (e1, e2) -> Vbcst (Add (e1, e2)) - | Right (e1, e2) -> Vadd (e1, e2) - ) - | Sub (e1, e2) -> - (match vec_pair e1 e2 with - | Left (e1, e2) -> Vbcst (Sub (e1, e2)) - | Right (e1, e2) -> Vsub (e1, e2) - ) - | Mul (e1, e2) -> - (match vec_pair e1 e2 with - | Left (e1, e2) -> Vbcst (Mul (e1, e2)) - | Right (e1, e2) -> Vmul (e1, e2) - ) - | Write (e , tensor, accesses) -> if T.does_access tensor accesses dim - then Vwrite (vectorize_on_dim dim e, tensor, accesses) + | Add (e1, e2) -> ( + match vec_pair e1 e2 with + | Left (e1, e2) -> Vbcst (Add (e1, e2)) + | Right (e1, e2) -> Vadd (e1, e2)) + | Sub (e1, e2) -> ( + match vec_pair e1 e2 with + | Left (e1, e2) -> Vbcst (Sub (e1, e2)) + | Right (e1, e2) -> Vsub (e1, e2)) + | Mul (e1, e2) -> ( + match vec_pair e1 e2 with + | Left (e1, e2) -> Vbcst (Mul (e1, e2)) + | Right (e1, e2) -> Vmul (e1, e2)) + | Write (e, tensor, accesses) -> + if T.does_access tensor accesses dim then + Vwrite (vectorize_on_dim dim e, tensor, accesses) else Vbcst (Write (e, tensor, accesses)) - | Allocated (Scal, _) as e -> Vbcst (e) - (* This catch-all pattern feels a bit uneasy... Supposed to catch only vectorized instructions *) - | _ as e -> e + | Allocated (Scal, _) as e -> Vbcst e + (* This catch-all pattern feels a bit uneasy... Supposed to catch only vectorized instructions *) + | _ as e -> e let num_allocated_loads = ref 0 let num_allocated_vloads = ref 0 let num_allocated_stores = ref 0 let num_allocated_vstores = ref 0 - let reiinit_all () = let reinit r = r := 0 in - List.iter reinit [num_allocated_loads; num_allocated_vloads; - num_allocated_stores; num_allocated_vstores] + let reiinit_all () = + let reinit r = r := 0 in + List.iter reinit + [ + num_allocated_loads; + num_allocated_vloads; + num_allocated_stores; + num_allocated_vstores; + ] type mem_accesses_inf = (T.t * T.accesses) * (inst_type * string) + (* Finds every access that is independent of all dimensions in dims, factorize them out * We have a problem : How to prevent naming interference. We have to keep some * track of how many variables have been allocated. No clear idea how to do that now *) let fact_loads_on_dims dims prev_accesses inst_list = let gen_name acc_type accesses loads value tensor expr = - match List.assoc_eq [%eq: T.t * (Dim.id * Expr.t) list] (tensor, expr) accesses with - | None -> - (match List.assoc_eq [%eq: T.t * (Dim.id * Expr.t) list] (tensor, expr) prev_accesses with - None -> (match acc_type with - | Scal -> - let num_load = Ref.post_incr num_allocated_loads in - let load_name ="mem_scal_" ^ string_of_int num_load in - (((tensor, expr), (Scal, load_name))::accesses), - (Assign(value, Scal, load_name)::loads), - Allocated(Scal, load_name) - | Vec -> - let num_load = Ref.post_incr num_allocated_vloads in - let load_name ="mem_vec_" ^ string_of_int num_load in - (((tensor, expr), (Vec, load_name))::accesses), - (Assign(value, Vec, load_name)::loads), - Allocated(Vec, load_name) - ) - | Some (_, var) -> - (((tensor, expr), (acc_type, var))::accesses), - (Assign(value, acc_type, var)::loads), - Allocated(acc_type, var) - ) - | Some (_, var) -> - accesses, loads, Allocated(acc_type, var) in + match + List.assoc_eq [%eq: T.t * (Dim.id * Expr.t) list] (tensor, expr) + accesses + with + | None -> ( + match + List.assoc_eq [%eq: T.t * (Dim.id * Expr.t) list] (tensor, expr) + prev_accesses + with + | None -> ( + match acc_type with + | Scal -> + let num_load = Ref.post_incr num_allocated_loads in + let load_name = "mem_scal_" ^ string_of_int num_load in + ( ((tensor, expr), (Scal, load_name)) :: accesses, + Assign (value, Scal, load_name) :: loads, + Allocated (Scal, load_name) ) + | Vec -> + let num_load = Ref.post_incr num_allocated_vloads in + let load_name = "mem_vec_" ^ string_of_int num_load in + ( ((tensor, expr), (Vec, load_name)) :: accesses, + Assign (value, Vec, load_name) :: loads, + Allocated (Vec, load_name) )) + | Some (_, var) -> + ( ((tensor, expr), (acc_type, var)) :: accesses, + Assign (value, acc_type, var) :: loads, + Allocated (acc_type, var) )) + | Some (_, var) -> (accesses, loads, Allocated (acc_type, var)) + in let rec recurse_inst (accesses_list, loads) = function - | (Read (tensor, accesses) as r) when List.for_all (fun dim -> not (T.does_access tensor accesses dim)) dims -> - gen_name Scal accesses_list loads r tensor accesses - | (Vread (tensor, accesses) as r) when List.for_all (fun dim -> not (T.does_access tensor accesses dim)) dims -> - gen_name Vec accesses_list loads r tensor accesses - | (Nop | Allocated _ | External_call _ | Vtranspose _ | Comment _) as inst -> - accesses_list, loads, inst + | Read (tensor, accesses) as r + when List.for_all + (fun dim -> not (T.does_access tensor accesses dim)) + dims -> + gen_name Scal accesses_list loads r tensor accesses + | Vread (tensor, accesses) as r + when List.for_all + (fun dim -> not (T.does_access tensor accesses dim)) + dims -> + gen_name Vec accesses_list loads r tensor accesses + | (Nop | Allocated _ | External_call _ | Vtranspose _ | Comment _) as inst + -> + (accesses_list, loads, inst) | Assign (e, itype, ptr) -> - let accesses_list, loads, e = recurse_inst (accesses_list, loads) e in - accesses_list, loads, Assign (e, itype, ptr) - | (Read _) as r -> accesses_list, loads, r - | (Vread _) as r -> accesses_list, loads, r + let accesses_list, loads, e = recurse_inst (accesses_list, loads) e in + (accesses_list, loads, Assign (e, itype, ptr)) + | Read _ as r -> (accesses_list, loads, r) + | Vread _ as r -> (accesses_list, loads, r) | Add (e1, e2) -> - let accesses_list, loads, e1 = recurse_inst (accesses_list, loads) e1 in - let accesses_list, loads, e2 = recurse_inst (accesses_list, loads) e2 in - accesses_list, loads, Add(e1, e2) + let accesses_list, loads, e1 = + recurse_inst (accesses_list, loads) e1 + in + let accesses_list, loads, e2 = + recurse_inst (accesses_list, loads) e2 + in + (accesses_list, loads, Add (e1, e2)) | Sub (e1, e2) -> - let accesses_list, loads, e1 = recurse_inst (accesses_list, loads) e1 in - let accesses_list, loads, e2 = recurse_inst (accesses_list, loads) e2 in - accesses_list, loads, Sub(e1, e2) + let accesses_list, loads, e1 = + recurse_inst (accesses_list, loads) e1 + in + let accesses_list, loads, e2 = + recurse_inst (accesses_list, loads) e2 + in + (accesses_list, loads, Sub (e1, e2)) | Mul (e1, e2) -> - let accesses_list, loads, e1 = recurse_inst (accesses_list, loads) e1 in - let accesses_list, loads, e2 = recurse_inst (accesses_list, loads) e2 in - accesses_list, loads, Mul(e1, e2) + let accesses_list, loads, e1 = + recurse_inst (accesses_list, loads) e1 + in + let accesses_list, loads, e2 = + recurse_inst (accesses_list, loads) e2 + in + (accesses_list, loads, Mul (e1, e2)) | Vadd (e1, e2) -> - let accesses_list, loads, e1 = recurse_inst (accesses_list, loads) e1 in - let accesses_list, loads, e2 = recurse_inst (accesses_list, loads) e2 in - accesses_list, loads, Vadd(e1, e2) + let accesses_list, loads, e1 = + recurse_inst (accesses_list, loads) e1 + in + let accesses_list, loads, e2 = + recurse_inst (accesses_list, loads) e2 + in + (accesses_list, loads, Vadd (e1, e2)) | Vsub (e1, e2) -> - let accesses_list, loads, e1 = recurse_inst (accesses_list, loads) e1 in - let accesses_list, loads, e2 = recurse_inst (accesses_list, loads) e2 in - accesses_list, loads, Vsub(e1, e2) + let accesses_list, loads, e1 = + recurse_inst (accesses_list, loads) e1 + in + let accesses_list, loads, e2 = + recurse_inst (accesses_list, loads) e2 + in + (accesses_list, loads, Vsub (e1, e2)) | Vmul (e1, e2) -> - let accesses_list, loads, e1 = recurse_inst (accesses_list, loads) e1 in - let accesses_list, loads, e2 = recurse_inst (accesses_list, loads) e2 in - accesses_list, loads, Vmul(e1, e2) + let accesses_list, loads, e1 = + recurse_inst (accesses_list, loads) e1 + in + let accesses_list, loads, e2 = + recurse_inst (accesses_list, loads) e2 + in + (accesses_list, loads, Vmul (e1, e2)) | Vbcst e -> - let accesses_list, loads, e = recurse_inst (accesses_list, loads) e in - accesses_list, loads, Vbcst e - | Vwrite (value, tensor, expr) -> - let accesses_list, loads, e = recurse_inst (accesses_list, loads) value in - accesses_list, loads, Vwrite (e, tensor, expr) - | Write (value, tensor, expr) -> - let accesses_list, loads, e = recurse_inst (accesses_list, loads) value in - accesses_list, loads, Write (e, tensor, expr) + let accesses_list, loads, e = recurse_inst (accesses_list, loads) e in + (accesses_list, loads, Vbcst e) + | Vwrite (value, tensor, expr) -> + let accesses_list, loads, e = + recurse_inst (accesses_list, loads) value + in + (accesses_list, loads, Vwrite (e, tensor, expr)) + | Write (value, tensor, expr) -> + let accesses_list, loads, e = + recurse_inst (accesses_list, loads) value + in + (accesses_list, loads, Write (e, tensor, expr)) in let factorize (accesses_list, loads, filtered) inst = - let accesses_list, loads, new_inst = recurse_inst (accesses_list, loads) inst in - accesses_list, loads, (new_inst::filtered) + let accesses_list, loads, new_inst = + recurse_inst (accesses_list, loads) inst + in + (accesses_list, loads, new_inst :: filtered) in - List.fold_left factorize - ([], [], []) inst_list - |> fun (acc, lds, insts) -> List.(rev acc, rev lds, rev insts) + List.fold_left factorize ([], [], []) inst_list |> fun (acc, lds, insts) -> + List.(rev acc, rev lds, rev insts) let is_tensor_readonly inst_list tensor = let does_modify_tens = function | External_call _ -> true - | Write(_, tens, _) | Vwrite(_, tens, _) when T.equal_id (T.id tens) (T.id tensor) -> - true - | _ -> false in + | (Write (_, tens, _) | Vwrite (_, tens, _)) + when T.equal_id (T.id tens) (T.id tensor) -> + true + | _ -> false + in List.for_all (Bool.not % does_modify_tens) inst_list (* Finds every write access that is independent of all dimensions in dim, factorize them out *) let fact_stores_on_dims dims prev_accesses inst_list = let gen_name acc_type accesses stores filtered value tensor expr = let find_tens_access tensor expr accesses = - List.assoc_eq [%eq: T.t * (Dim.id * Expr.t) list] (tensor, expr) accesses in + List.assoc_eq [%eq: T.t * (Dim.id * Expr.t) list] (tensor, expr) + accesses + in match find_tens_access tensor expr accesses with - | None -> - (match find_tens_access tensor expr prev_accesses with - None -> (match acc_type with - | Scal -> - let num_stores = Ref.post_incr num_allocated_stores in - let store_name ="mem_scal_" ^ string_of_int num_stores in - let allocated = Allocated(Scal, store_name) in - (((tensor, expr), (Scal, store_name))::accesses), - (Write(allocated, tensor, expr)::stores), - (Assign(value, Scal, store_name)::filtered) - | Vec -> - let num_stores = Ref.post_incr num_allocated_vstores in - let store_name = "mem_vec_" ^ string_of_int num_stores in - let allocated = Allocated(Vec, store_name) in - (((tensor, expr), (Vec, store_name))::accesses), - (Vwrite(allocated, tensor, expr)::stores), - (Assign(value, Vec, store_name)::filtered) - ) - | Some (_, var) -> - (((tensor, expr), (acc_type, var))::accesses), - (let allocated = Allocated(acc_type, var) in - match acc_type with - Scal -> Write(allocated, tensor, expr)::stores - | Vec -> Vwrite(allocated, tensor, expr)::stores), - (Assign(value, acc_type, var)::filtered) - ) + | None -> ( + match find_tens_access tensor expr prev_accesses with + | None -> ( + match acc_type with + | Scal -> + let num_stores = Ref.post_incr num_allocated_stores in + let store_name = "mem_scal_" ^ string_of_int num_stores in + let allocated = Allocated (Scal, store_name) in + ( ((tensor, expr), (Scal, store_name)) :: accesses, + Write (allocated, tensor, expr) :: stores, + Assign (value, Scal, store_name) :: filtered ) + | Vec -> + let num_stores = Ref.post_incr num_allocated_vstores in + let store_name = "mem_vec_" ^ string_of_int num_stores in + let allocated = Allocated (Vec, store_name) in + ( ((tensor, expr), (Vec, store_name)) :: accesses, + Vwrite (allocated, tensor, expr) :: stores, + Assign (value, Vec, store_name) :: filtered )) + | Some (_, var) -> + ( ((tensor, expr), (acc_type, var)) :: accesses, + (let allocated = Allocated (acc_type, var) in + match acc_type with + | Scal -> Write (allocated, tensor, expr) :: stores + | Vec -> Vwrite (allocated, tensor, expr) :: stores), + Assign (value, acc_type, var) :: filtered )) | Some (_, var) -> - accesses, stores, (Assign (value, acc_type, var)::filtered) + (accesses, stores, Assign (value, acc_type, var) :: filtered) in let factorize (accesses_list, stores, filtered) = function (* There is no need to call recursively factor_stores on * instructions because no write instructions must appear in this * place *) | Write (value, tensor, accesses) - when List.for_all (fun dim -> not (T.does_access tensor accesses dim)) dims -> - let accesses, stores, filtered = - gen_name Scal accesses_list stores filtered value tensor accesses in - accesses, stores, filtered + when List.for_all + (fun dim -> not (T.does_access tensor accesses dim)) + dims -> + let accesses, stores, filtered = + gen_name Scal accesses_list stores filtered value tensor accesses + in + (accesses, stores, filtered) | Vwrite (value, tensor, accesses) - when List.for_all (fun dim -> not (T.does_access tensor accesses dim)) dims -> - let accesses, stores, filtered = - gen_name Vec accesses_list stores filtered value tensor accesses in - accesses, stores, filtered - | inst -> - accesses_list, stores, (inst::filtered) + when List.for_all + (fun dim -> not (T.does_access tensor accesses dim)) + dims -> + let accesses, stores, filtered = + gen_name Vec accesses_list stores filtered value tensor accesses + in + (accesses, stores, filtered) + | inst -> (accesses_list, stores, inst :: filtered) in - List.fold_left factorize ([], [], []) inst_list - |> fun (acc, lds, insts) -> List.(rev acc, rev lds, rev insts) + List.fold_left factorize ([], [], []) inst_list |> fun (acc, lds, insts) -> + List.(rev acc, rev lds, rev insts) let swap_tensor src_tensor dst_tensor inst_list = let rec swap_inst = function | Read (tensor, acc_expr) when T.equal tensor src_tensor -> - Read (dst_tensor, acc_expr) + Read (dst_tensor, acc_expr) | Vread (tensor, acc_expr) when T.equal tensor src_tensor -> - Vread (dst_tensor, acc_expr) - | Write (value, tensor, acc_expr) - when T.equal tensor src_tensor -> - let value = swap_inst value in - Write (value, dst_tensor, acc_expr) - | Vwrite (value, tensor, acc_expr) - when T.equal tensor src_tensor -> - let value = swap_inst value in - Vwrite(value, dst_tensor, acc_expr) + Vread (dst_tensor, acc_expr) + | Write (value, tensor, acc_expr) when T.equal tensor src_tensor -> + let value = swap_inst value in + Write (value, dst_tensor, acc_expr) + | Vwrite (value, tensor, acc_expr) when T.equal tensor src_tensor -> + let value = swap_inst value in + Vwrite (value, dst_tensor, acc_expr) | Write (value, tensor, acc_expr) -> - Write (swap_inst value, tensor, acc_expr) + Write (swap_inst value, tensor, acc_expr) | Vwrite (value, tensor, acc_expr) -> - Vwrite (swap_inst value, tensor, acc_expr) - | Assign (value, vtype, tensor) -> - Assign (swap_inst value, vtype, tensor) - | Add (v1, v2) -> Add(swap_inst v1, swap_inst v2) - | Sub (v1, v2) -> Sub(swap_inst v1, swap_inst v2) - | Mul (v1, v2) -> Mul(swap_inst v1, swap_inst v2) - | Vadd (v1, v2) -> Vadd(swap_inst v1, swap_inst v2) - | Vsub (v1, v2) ->Vsub(swap_inst v1, swap_inst v2) - | Vmul (v1, v2) -> Vmul(swap_inst v1, swap_inst v2) + Vwrite (swap_inst value, tensor, acc_expr) + | Assign (value, vtype, tensor) -> Assign (swap_inst value, vtype, tensor) + | Add (v1, v2) -> Add (swap_inst v1, swap_inst v2) + | Sub (v1, v2) -> Sub (swap_inst v1, swap_inst v2) + | Mul (v1, v2) -> Mul (swap_inst v1, swap_inst v2) + | Vadd (v1, v2) -> Vadd (swap_inst v1, swap_inst v2) + | Vsub (v1, v2) -> Vsub (swap_inst v1, swap_inst v2) + | Vmul (v1, v2) -> Vmul (swap_inst v1, swap_inst v2) | Vbcst v -> Vbcst (swap_inst v) - | External_call ({ tensors_info; _} as args) -> - let tensors_info = List.map (fun (t, acc) -> if T.equal t src_tensor then - (dst_tensor, acc) else (t, acc)) tensors_info in - External_call {args with tensors_info} + | External_call ({ tensors_info; _ } as args) -> + let tensors_info = + List.map + (fun (t, acc) -> + if T.equal t src_tensor then (dst_tensor, acc) else (t, acc)) + tensors_info + in + External_call { args with tensors_info } | inst -> inst in List.map swap_inst inst_list @@ -289,29 +357,29 @@ module I(V: Vec_arch_t): Inst_t = struct let swap_accesses f inst_list = let rec swap_inst = function | Read (tensor, acc_expr) -> - let accesses = f tensor acc_expr in - Read (tensor, accesses) + let accesses = f tensor acc_expr in + Read (tensor, accesses) | Vread (tensor, acc_expr) -> - let accesses = f tensor acc_expr in - Vread (tensor, accesses) + let accesses = f tensor acc_expr in + Vread (tensor, accesses) | Write (value, tensor, acc_expr) -> - let value = swap_inst value in - let accesses = f tensor acc_expr in - Write (value, tensor, accesses) + let value = swap_inst value in + let accesses = f tensor acc_expr in + Write (value, tensor, accesses) | Vwrite (value, tensor, acc_expr) -> - let value = swap_inst value in - let accesses = f tensor acc_expr in - Vwrite(value, tensor, accesses) - | Assign (value, vtype, tensor) -> - Assign (swap_inst value, vtype, tensor) - | Add (v1, v2) -> Add(swap_inst v1, swap_inst v2) - | Sub (v1, v2) -> Sub(swap_inst v1, swap_inst v2) - | Mul (v1, v2) -> Mul(swap_inst v1, swap_inst v2) - | Vadd (v1, v2) -> Vadd(swap_inst v1, swap_inst v2) - | Vsub (v1, v2) ->Vsub(swap_inst v1, swap_inst v2) - | Vmul (v1, v2) -> Vmul(swap_inst v1, swap_inst v2) + let value = swap_inst value in + let accesses = f tensor acc_expr in + Vwrite (value, tensor, accesses) + | Assign (value, vtype, tensor) -> Assign (swap_inst value, vtype, tensor) + | Add (v1, v2) -> Add (swap_inst v1, swap_inst v2) + | Sub (v1, v2) -> Sub (swap_inst v1, swap_inst v2) + | Mul (v1, v2) -> Mul (swap_inst v1, swap_inst v2) + | Vadd (v1, v2) -> Vadd (swap_inst v1, swap_inst v2) + | Vsub (v1, v2) -> Vsub (swap_inst v1, swap_inst v2) + | Vmul (v1, v2) -> Vmul (swap_inst v1, swap_inst v2) | Vbcst v -> Vbcst (swap_inst v) - | inst -> inst in + | inst -> inst + in List.map swap_inst inst_list type fresh = Fresh of string | Old of string @@ -331,221 +399,264 @@ module I(V: Vec_arch_t): Inst_t = struct let () = reiinit_all () in let forbidden_access = ref [] in let rec can_use_old_name = function - | Vadd (a, b) | Vmul (a, b) | Vsub (a, b) - | Add (a, b) | Mul (a, b) | Sub (a, b) -> - can_use_old_name a && can_use_old_name b - | Vwrite _ | Vtranspose _ | Write _ -> false + | Vadd (a, b) + | Vmul (a, b) + | Vsub (a, b) + | Add (a, b) + | Mul (a, b) + | Sub (a, b) -> + can_use_old_name a && can_use_old_name b + | Vwrite _ | Vtranspose _ | Write _ -> false | Comment _ | External_call _ | Allocated _ | Nop -> true - | Vbcst e | Assign (e, _, _) -> can_use_old_name e + | Vbcst e | Assign (e, _, _) -> can_use_old_name e | Vread (t, addr) | Read (t, addr) -> - not @@ List.exists ([%eq: Tensor.t * Tensor.accesses] (t, addr)) !forbidden_access + not + @@ List.exists + ([%eq: Tensor.t * Tensor.accesses] (t, addr)) + !forbidden_access in let gen_name inst_type state value = match value with | Allocated (_, name) -> Old name - | _ -> - let id, name_list = !state in - match List.assoc_eq equal value name_list with - | Some ident -> Old ident - | None -> - let new_ident = Format.sprintf "%s_%d" (string_of_inst_type inst_type) id in - state := id + 1, (value, new_ident)::name_list; Fresh new_ident - and scal_state =ref (0, []) - and vec_state =ref (0, []) + | _ -> ( + let id, name_list = !state in + match List.assoc_eq equal value name_list with + | Some ident -> Old ident + | None -> + let new_ident = + Format.sprintf "%s_%d" (string_of_inst_type inst_type) id + in + state := (id + 1, (value, new_ident) :: name_list); + Fresh new_ident) + and scal_state = ref (0, []) + and vec_state = ref (0, []) in + let add_write_accesses t_addr = + forbidden_access := t_addr :: !forbidden_access in - let add_write_accesses t_addr = - forbidden_access := t_addr::!forbidden_access in let gen_vec_name = gen_name Vec scal_state and gen_scal_name = gen_name Scal vec_state in - let rec get_code_var gen_name inst_type op = match (gen_name op, can_use_old_name op) with - | Fresh var, _ - | Old var, false -> - begin match inst_type with - |Scal -> - let var_list, vec_list, code = gen_inst op in - var::var_list, vec_list, var, code + let rec get_code_var gen_name inst_type op = + match (gen_name op, can_use_old_name op) with + | Fresh var, _ | Old var, false -> ( + match inst_type with + | Scal -> + let var_list, vec_list, code = gen_inst op in + (var :: var_list, vec_list, var, code) | Vec -> - let var_list, vec_list, code = gen_inst op in - var_list, var::vec_list, var, code - end - | Old var, true -> - [], [], var, "" - + let var_list, vec_list, code = gen_inst op in + (var_list, var :: vec_list, var, code)) + | Old var, true -> ([], [], var, "") and format_binop inst_type gen_name format_inst_string tab res op1 op2 = let f_a var_a = - let var_list1, vec_list1, var1, code1 = get_code_var gen_name inst_type op1 in - let var_list2, vec_list2, var2, code2 = get_code_var gen_name inst_type op2 in - let var_list = List.concat [var_list1 ; var_list2] - and vec_list = List.concat [vec_list1 ; vec_list2] - and code = String.concat "\n" [code1; code2] in - let op = Printf.sprintf format_inst_string var1 var2 in - var_list, vec_list, - Format.sprintf "%s\n%s%s = %s;" code tab var_a op in - match gen_name res, can_use_old_name res with - | Old var_a, true -> f_a var_a - | Old var_a, false - | Fresh var_a, _ -> - let var_list, vec_list, code = f_a var_a in - begin + let var_list1, vec_list1, var1, code1 = + get_code_var gen_name inst_type op1 + in + let var_list2, vec_list2, var2, code2 = + get_code_var gen_name inst_type op2 + in + let var_list = List.concat [ var_list1; var_list2 ] + and vec_list = List.concat [ vec_list1; vec_list2 ] + and code = String.concat "\n" [ code1; code2 ] in + let op = Printf.sprintf format_inst_string var1 var2 in + (var_list, vec_list, Format.sprintf "%s\n%s%s = %s;" code tab var_a op) + in + match (gen_name res, can_use_old_name res) with + | Old var_a, true -> f_a var_a + | Old var_a, false | Fresh var_a, _ -> ( + let var_list, vec_list, code = f_a var_a in match inst_type with - |Scal -> var_a::var_list, vec_list, code - | Vec -> var_list, var_a::vec_list, code - end - + | Scal -> (var_a :: var_list, vec_list, code) + | Vec -> (var_list, var_a :: vec_list, code)) and format_3op inst_type gen_name format_inst_string tab res op1 op2 op3 = let f_a var_a = - let var_list1, vec_list1, var1, code1 = get_code_var gen_name inst_type op1 in - let var_list2, vec_list2, var2, code2 = get_code_var gen_name inst_type op2 in - let var_list3, vec_list3, var3, code3 = get_code_var gen_name inst_type op3 in - let var_list = List.concat [var_list1 ; var_list2; var_list3] - and vec_list = List.concat [vec_list1 ; vec_list2; vec_list3] - and code = String.concat "\n" [code1; code2; code3] in - let op = Printf.sprintf format_inst_string var1 var2 var3 in - var_list, vec_list, - Format.sprintf "%s\n%s%s = %s;" code tab var_a op in + let var_list1, vec_list1, var1, code1 = + get_code_var gen_name inst_type op1 + in + let var_list2, vec_list2, var2, code2 = + get_code_var gen_name inst_type op2 + in + let var_list3, vec_list3, var3, code3 = + get_code_var gen_name inst_type op3 + in + let var_list = List.concat [ var_list1; var_list2; var_list3 ] + and vec_list = List.concat [ vec_list1; vec_list2; vec_list3 ] + and code = String.concat "\n" [ code1; code2; code3 ] in + let op = Printf.sprintf format_inst_string var1 var2 var3 in + (var_list, vec_list, Format.sprintf "%s\n%s%s = %s;" code tab var_a op) + in - match gen_name res, can_use_old_name res with - | Old var_a, true -> f_a var_a - | Old var_a, false - | Fresh var_a, _ -> - let var_list, vec_list, code = f_a var_a in - begin + match (gen_name res, can_use_old_name res) with + | Old var_a, true -> f_a var_a + | Old var_a, false | Fresh var_a, _ -> ( + let var_list, vec_list, code = f_a var_a in match inst_type with - |Scal -> var_a::var_list, vec_list, code - | Vec -> var_list, var_a::vec_list, code - end - + | Scal -> (var_a :: var_list, vec_list, code) + | Vec -> (var_list, var_a :: vec_list, code)) and gen_inst = function - | Nop -> [], [], Format.sprintf "%snop;" tab - | Comment c -> [], [], Format.sprintf "%s// %s" tab c - | External_call {name; tensors_info; var_exprs} -> - let strides = List.concat_map (fun (t, _) -> - List.map [%show:Expr.t] @@ T.strides_in_order t) tensors_info - |> String.concat ", " in - let var_exprs = String.concat ", " @@ List.map [%show: Expr.t] var_exprs in - let formatted_accesses = String.concat ", " - @@ List.map (fun (t, acc) -> "&" ^ T.gen_access t acc) tensors_info in - [], [], Format.sprintf "%s%s(%s, %s, %s);" tab name - formatted_accesses - var_exprs strides - | Allocated(Scal, name) -> [name], [], Format.sprintf "%s%s;" tab name - | Allocated(Vec, name) -> [], [name], Format.sprintf "%s%s;" tab name - | Vtranspose (tensor_out, accesses_out, stride_out, - tensor_in, accesses_in, stride_in) -> - let map2 f x y = f x, f y in - let acc_out, acc_in = T.gen_access tensor_out accesses_out, - T.gen_access tensor_in accesses_in - and str_out, str_in = map2 Expr.show stride_out stride_in in - [], [], Format.sprintf "%s%s(&%s, %s, &%s, %s);" - tab - A.transpose_func - acc_out str_out - acc_in str_in - | Assign(Read(tensor, accesses), Scal, name) -> - [name], [], Format.sprintf "%s%s = %s;" tab name (T.gen_access tensor accesses) - | Assign(value, Scal, name) -> - (match gen_scal_name value with - | Fresh var -> - let scal_list, vec_list, code = gen_inst value in - var::name::scal_list, vec_list, Format.sprintf "%s\n%s%s = %s;" code tab name var - | Old var -> - [name], [], Format.sprintf "%s%s = %s;" tab name var - ) - | Read (tensor, accesses) as r -> - (match gen_scal_name r with - | Fresh name -> - let read_str = Format.sprintf "%s = %s;" name (T.gen_access tensor accesses) in - [name], [], Format.sprintf "%s%s" tab read_str - | Old name -> - let read_str = Format.sprintf "%s = %s;" name (T.gen_access tensor accesses) in - [], [], Format.sprintf "%s%s" tab read_str) - | Write (value, tensor, accesses) -> - add_write_accesses (tensor, accesses); - (match gen_scal_name value, can_use_old_name value with - | Fresh var, _ - | Old var, false -> - let scal_list, vec_list, code = gen_inst value in - let wr_str = Format.sprintf "%s = %s;" (T.gen_access tensor accesses) var in - var::scal_list, vec_list, Format.sprintf "%s\n%s%s" code tab wr_str - | Old var, true -> - let wr_str = Format.sprintf "%s = %s;" (T.gen_access tensor accesses) var in - [], [], Format.sprintf "%s%s" tab wr_str - ) + | Nop -> ([], [], Format.sprintf "%snop;" tab) + | Comment c -> ([], [], Format.sprintf "%s// %s" tab c) + | External_call { name; tensors_info; var_exprs } -> + let strides = + List.concat_map + (fun (t, _) -> List.map [%show: Expr.t] @@ T.strides_in_order t) + tensors_info + |> String.concat ", " + in + let var_exprs = + String.concat ", " @@ List.map [%show: Expr.t] var_exprs + in + let formatted_accesses = + String.concat ", " + @@ List.map (fun (t, acc) -> "&" ^ T.gen_access t acc) tensors_info + in + ( [], + [], + Format.sprintf "%s%s(%s, %s, %s);" tab name formatted_accesses + var_exprs strides ) + | Allocated (Scal, name) -> ([ name ], [], Format.sprintf "%s%s;" tab name) + | Allocated (Vec, name) -> ([], [ name ], Format.sprintf "%s%s;" tab name) + | Vtranspose + ( tensor_out, + accesses_out, + stride_out, + tensor_in, + accesses_in, + stride_in ) -> + let map2 f x y = (f x, f y) in + let acc_out, acc_in = + ( T.gen_access tensor_out accesses_out, + T.gen_access tensor_in accesses_in ) + and str_out, str_in = map2 Expr.show stride_out stride_in in + ( [], + [], + Format.sprintf "%s%s(&%s, %s, &%s, %s);" tab A.transpose_func + acc_out str_out acc_in str_in ) + | Assign (Read (tensor, accesses), Scal, name) -> + ( [ name ], + [], + Format.sprintf "%s%s = %s;" tab name (T.gen_access tensor accesses) + ) + | Assign (value, Scal, name) -> ( + match gen_scal_name value with + | Fresh var -> + let scal_list, vec_list, code = gen_inst value in + ( var :: name :: scal_list, + vec_list, + Format.sprintf "%s\n%s%s = %s;" code tab name var ) + | Old var -> ([ name ], [], Format.sprintf "%s%s = %s;" tab name var)) + | Read (tensor, accesses) as r -> ( + match gen_scal_name r with + | Fresh name -> + let read_str = + Format.sprintf "%s = %s;" name (T.gen_access tensor accesses) + in + ([ name ], [], Format.sprintf "%s%s" tab read_str) + | Old name -> + let read_str = + Format.sprintf "%s = %s;" name (T.gen_access tensor accesses) + in + ([], [], Format.sprintf "%s%s" tab read_str)) + | Write (value, tensor, accesses) -> ( + add_write_accesses (tensor, accesses); + match (gen_scal_name value, can_use_old_name value) with + | Fresh var, _ | Old var, false -> + let scal_list, vec_list, code = gen_inst value in + let wr_str = + Format.sprintf "%s = %s;" (T.gen_access tensor accesses) var + in + ( var :: scal_list, + vec_list, + Format.sprintf "%s\n%s%s" code tab wr_str ) + | Old var, true -> + let wr_str = + Format.sprintf "%s = %s;" (T.gen_access tensor accesses) var + in + ([], [], Format.sprintf "%s%s" tab wr_str)) | Add (v1, v2) as a -> - format_binop Scal gen_scal_name "%s + %s" tab a v1 v2 + format_binop Scal gen_scal_name "%s + %s" tab a v1 v2 | Sub (v1, v2) as a -> - format_binop Scal gen_scal_name "%s - %s" tab a v1 v2 + format_binop Scal gen_scal_name "%s - %s" tab a v1 v2 | Mul (v1, v2) as a -> - format_binop Scal gen_scal_name "%s * %s" tab a v1 v2 - | Vbcst value as vb -> - (match gen_vec_name vb with - | Old var_a -> - (match gen_scal_name value with - | Fresh var1 -> + format_binop Scal gen_scal_name "%s * %s" tab a v1 v2 + | Vbcst value as vb -> ( + match gen_vec_name vb with + | Old var_a -> ( + match gen_scal_name value with + | Fresh var1 -> + let scal_list, vec_list, code = gen_inst value in + let op = Format.sprintf V.gen_broadcast var1 in + ( var1 :: scal_list, + vec_list, + Format.sprintf "%s\n%s%s = %s;" code tab var_a op ) + | Old var1 -> + let op = Format.sprintf V.gen_broadcast var1 in + ([], [], Format.sprintf "%s%s = %s;" tab var_a op)) + | Fresh var_a -> ( + match gen_scal_name value with + | Fresh var1 -> + let var_list, vec_list, code = gen_inst value in + let op = Format.sprintf V.gen_broadcast var1 in + ( var1 :: var_list, + var_a :: vec_list, + Format.sprintf "%s\n%s%s = %s;" code tab var_a op ) + | Old var1 -> + let op = Format.sprintf V.gen_broadcast var1 in + ([], [ var_a ], Format.sprintf "%s%s = %s;" tab var_a op))) + | Vread (tensor, accesses) as r -> ( + match gen_vec_name r with + | Fresh name -> + let op = + Format.sprintf V.gen_load (T.gen_access tensor accesses) + in + ([], [ name ], Format.sprintf "%s%s = %s;" tab name op) + | Old name -> + let op = + Format.sprintf V.gen_load (T.gen_access tensor accesses) + in + ([], [], Format.sprintf "%s%s = %s;" tab name op)) + | Assign (Vread (tensor, accesses), Vec, name) -> + let op = Format.sprintf V.gen_load (T.gen_access tensor accesses) in + ([], [ name ], Format.sprintf "%s%s = %s;" tab name op) + | Assign (value, Vec, name) -> ( + match gen_vec_name value with + | Fresh var -> let scal_list, vec_list, code = gen_inst value in - let op = Format.sprintf V.gen_broadcast var1 in - var1::scal_list, vec_list, Format.sprintf "%s\n%s%s = %s;" code tab var_a op - | Old var1 -> - let op = Format.sprintf V.gen_broadcast var1 in - [], [], Format.sprintf "%s%s = %s;" tab var_a op - ) - | Fresh var_a -> - (match gen_scal_name value with - | Fresh var1 -> - let var_list, vec_list, code = gen_inst value in - let op = Format.sprintf V.gen_broadcast var1 in - var1::var_list, var_a::vec_list, - Format.sprintf "%s\n%s%s = %s;" code tab var_a op - | Old var1 -> - let op = Format.sprintf V.gen_broadcast var1 in - [], [var_a], Format.sprintf "%s%s = %s;" tab var_a op - )) - | Vread (tensor, accesses) as r -> - (match gen_vec_name r with - | Fresh name -> - let op = Format.sprintf V.gen_load (T.gen_access tensor accesses) in - [], [name], Format.sprintf "%s%s = %s;" tab name op - | Old name -> - let op = Format.sprintf V.gen_load (T.gen_access tensor accesses) in - [], [], Format.sprintf "%s%s = %s;" tab name op) - | Assign(Vread(tensor, accesses), Vec, name) -> - let op = Format.sprintf V.gen_load (T.gen_access tensor accesses) in - [], [name], Format.sprintf "%s%s = %s;" tab name op - | Assign(value, Vec, name) -> - (match gen_vec_name value with - | Fresh var - -> - let scal_list, vec_list, code = gen_inst value in - scal_list, var::name::vec_list, Format.sprintf "%s\n%s%s = %s;" code tab name var - | Old var -> - [], [name], Format.sprintf "%s%s = %s;" tab name var - ) - | Vwrite (value, tensor, accesses) -> - add_write_accesses (tensor, accesses); - (match gen_vec_name value, can_use_old_name value with - | Fresh var, _ - | Old var, false -> - let scal_list, vec_list, code = gen_inst value in - let op = Format.sprintf V.gen_store (T.gen_access tensor accesses) var in - scal_list, var::vec_list, Format.sprintf "%s\n%s%s;" code tab op - | Old var, true -> - let op = Format.sprintf V.gen_store (T.gen_access tensor accesses) var in - [], [], Format.sprintf "%s%s;"tab op - ) - | Vadd (Vmul(v1, v2), v3) | Vadd (v3, Vmul(v1, v2)) as a -> - format_3op Vec gen_vec_name V.gen_fma tab a v1 v2 v3 + ( scal_list, + var :: name :: vec_list, + Format.sprintf "%s\n%s%s = %s;" code tab name var ) + | Old var -> ([], [ name ], Format.sprintf "%s%s = %s;" tab name var)) + | Vwrite (value, tensor, accesses) -> ( + add_write_accesses (tensor, accesses); + match (gen_vec_name value, can_use_old_name value) with + | Fresh var, _ | Old var, false -> + let scal_list, vec_list, code = gen_inst value in + let op = + Format.sprintf V.gen_store (T.gen_access tensor accesses) var + in + ( scal_list, + var :: vec_list, + Format.sprintf "%s\n%s%s;" code tab op ) + | Old var, true -> + let op = + Format.sprintf V.gen_store (T.gen_access tensor accesses) var + in + ([], [], Format.sprintf "%s%s;" tab op)) + | (Vadd (Vmul (v1, v2), v3) | Vadd (v3, Vmul (v1, v2))) as a -> + format_3op Vec gen_vec_name V.gen_fma tab a v1 v2 v3 | Vadd (v1, v2) as a -> - format_binop Vec gen_vec_name V.gen_add tab a v1 v2 + format_binop Vec gen_vec_name V.gen_add tab a v1 v2 | Vsub (v1, v2) as s -> - format_binop Vec gen_vec_name V.gen_sub tab s v1 v2 + format_binop Vec gen_vec_name V.gen_sub tab s v1 v2 | Vmul (v1, v2) as m -> - format_binop Vec gen_vec_name V.gen_mul tab m v1 v2 + format_binop Vec gen_vec_name V.gen_mul tab m v1 v2 in - let scal_insts, vec_insts, code = List.fold_left + let scal_insts, vec_insts, code = + List.fold_left (fun (scal, vec, code) inst -> - let scal', vec', code' = gen_inst inst in - List.append scal' scal, List.append vec' vec, code'::code) - ([], [], []) value_list in - List.rev scal_insts, List.rev vec_insts, String.concat "\n" (List.rev code) + let scal', vec', code' = gen_inst inst in + (List.append scal' scal, List.append vec' vec, code' :: code)) + ([], [], []) value_list + in + (List.rev scal_insts, List.rev vec_insts, String.concat "\n" (List.rev code)) end diff --git a/ml/lib/instruction.mli b/ml/lib/instruction.mli index 31c78065fbfcc47c64dfae1e50d3d23ae563c0bc..f092dcb889cd332c3bd02ea9a34d0ce7762f4a78 100644 --- a/ml/lib/instruction.mli +++ b/ml/lib/instruction.mli @@ -2,4 +2,4 @@ open Arch module type Inst_t = Inst_sign.Inst_t -module I(A:Vec_arch_t) : Inst_t with module A = A +module I (A : Vec_arch_t) : Inst_t with module A = A diff --git a/ml/lib/loop_nest.ml b/ml/lib/loop_nest.ml index c0afcd2b0a830d53f0fd222def5baf31b08e0d5d..d7a2b3539a61224554dce52e85cee24094534dc7 100644 --- a/ml/lib/loop_nest.ml +++ b/ml/lib/loop_nest.ml @@ -4,395 +4,617 @@ open Utils open Instruction include Loop_nest_types - -module L(Inst: Inst_t) = struct +module L (Inst : Inst_t) = struct type loop = | Once of loop list | Statement of Inst.t list - | Unroll of {comment: string list; dim_index_id: Index.t; start: Expr.t ; - size: int; increment: Expr.t; body: loop list} - | Loop of {comment: string list; pragma: string option; dim_index_id: Index.t; - aux: (Index.t * Expr.t) list; start: Expr.t ; - halt: Expr.t; increment: Expr.t; body: loop list; - vars_decl: (string * Expr.t) list; - } [@@deriving show] + | Unroll of { + comment : string list; + dim_index_id : Index.t; + start : Expr.t; + size : int; + increment : Expr.t; + body : loop list; + } + | Loop of { + comment : string list; + pragma : string option; + dim_index_id : Index.t; + aux : (Index.t * Expr.t) list; + start : Expr.t; + halt : Expr.t; + increment : Expr.t; + body : loop list; + vars_decl : (string * Expr.t) list; + } + [@@deriving show] let rec show_alt = function - | Once l -> "Once {\n" ^ (String.concat "\n" (List.map show_alt l)) ^ "} // Once\n" + | Once l -> + "Once {\n" ^ String.concat "\n" (List.map show_alt l) ^ "} // Once\n" | Statement _ -> "Insts\n" - | Unroll {dim_index_id; size; body;_ } -> Printf.sprintf "Unroll %s %d {\n%s\n} // Unroll\n" (Index.show_id_of_t dim_index_id) size - (String.concat "\n" (List.map show_alt body)) - | Loop {dim_index_id; body;_ } -> Printf.sprintf "Loop %s {\n%s\n}\n" (Index.show_id_of_t dim_index_id) - (String.concat "\n" (List.map show_alt body)) + | Unroll { dim_index_id; size; body; _ } -> + Printf.sprintf "Unroll %s %d {\n%s\n} // Unroll\n" + (Index.show_id_of_t dim_index_id) + size + (String.concat "\n" (List.map show_alt body)) + | Loop { dim_index_id; body; _ } -> + Printf.sprintf "Loop %s {\n%s\n}\n" + (Index.show_id_of_t dim_index_id) + (String.concat "\n" (List.map show_alt body)) let rec map_all_insts f = function - Once l -> Once (List.map (map_all_insts f) l) - | Statement insts -> Statement (f insts) (* Do we need to map here ? *) - | Unroll ({start; increment; body; _} as arg) -> - let body = List.map (map_all_insts f) body in - Unroll {arg with body; start; increment} - | Loop ({ body; _} as arg) -> - let body = List.map (map_all_insts f) body in - Loop {arg with body; } + | Once l -> Once (List.map (map_all_insts f) l) + | Statement insts -> Statement (f insts) (* Do we need to map here ? *) + | Unroll ({ start; increment; body; _ } as arg) -> + let body = List.map (map_all_insts f) body in + Unroll { arg with body; start; increment } + | Loop ({ body; _ } as arg) -> + let body = List.map (map_all_insts f) body in + Loop { arg with body } let rec map_const_expr f = function - Once l -> Once (List.map (map_const_expr f) l) - | Statement insts -> Statement insts (* Do we need to map here ? *) - | Unroll ({start; increment;body; _} as arg) -> - let body = List.map (map_const_expr f) body - and start = f start - and increment = f increment in - Unroll {arg with body; start; increment} - | Loop ({start; halt;increment;body; _} as arg) -> - let body = List.map (map_const_expr f) body - and start = f start - and halt = f halt - and increment = f increment in - Loop {arg with body; halt; start; increment} + | Once l -> Once (List.map (map_const_expr f) l) + | Statement insts -> Statement insts (* Do we need to map here ? *) + | Unroll ({ start; increment; body; _ } as arg) -> + let body = List.map (map_const_expr f) body + and start = f start + and increment = f increment in + Unroll { arg with body; start; increment } + | Loop ({ start; halt; increment; body; _ } as arg) -> + let body = List.map (map_const_expr f) body + and start = f start + and halt = f halt + and increment = f increment in + Loop { arg with body; halt; start; increment } (* Returns Some list of instructions if all subloops in loop are instructions * Returns None otherwise *) let to_instructions_list loop_list = let rec to_inst_lst acc = function - | [] -> Some (acc) - | Statement inst::tail -> to_inst_lst (inst @ acc) tail + | [] -> Some acc + | Statement inst :: tail -> to_inst_lst (inst @ acc) tail | _ -> None in to_inst_lst [] loop_list let set_comment comment = function - | Loop args -> Loop {args with comment} - | Unroll args -> Unroll {args with comment} + | Loop args -> Loop { args with comment } + | Unroll args -> Unroll { args with comment } (* Other sorts of loops don't support comment yet *) | l -> l - let unroll_insts_dim _comments dim start size increment insts = + let unroll_insts_dim _comments dim start size increment insts = let rewrite_insts cnst inst = Inst.map_expr dim - (Expr.alpha_replace_dim dim (Expr.(I.(start + const cnst * increment )))) inst in - List.init size Fun.id |> List.concat_map (fun cnst -> List.map (rewrite_insts cnst) insts) - + (Expr.alpha_replace_dim dim Expr.(I.(start + (const cnst * increment)))) + inst + in + List.init size Fun.id + |> List.concat_map (fun cnst -> List.map (rewrite_insts cnst) insts) module Zipper = struct type zipper = - | Top + | Top (* | OnceChild of loop list * int * loop list *) - | LoopChild of {parent : zipper ; pragma : string option; aux : (Index.t * Expr.t) list; - comment: string list; - halt : Expr.t; dim_index_id: Index.t; start: Expr.t ; - increment: Expr.t; - left : loop list; right : loop list; - vars_decl: (string * Expr.t) list; - } - | UnrChild of { parent : zipper ; comment: string list; dim_index_id: Index.t; start: Expr.t ; - size: int; increment: Expr.t; left : loop list ; - right: loop list} - | OnceChild of {parent : zipper; left : loop list; right : loop list} [@@deriving show] + | LoopChild of { + parent : zipper; + pragma : string option; + aux : (Index.t * Expr.t) list; + comment : string list; + halt : Expr.t; + dim_index_id : Index.t; + start : Expr.t; + increment : Expr.t; + left : loop list; + right : loop list; + vars_decl : (string * Expr.t) list; + } + | UnrChild of { + parent : zipper; + comment : string list; + dim_index_id : Index.t; + start : Expr.t; + size : int; + increment : Expr.t; + left : loop list; + right : loop list; + } + | OnceChild of { parent : zipper; left : loop list; right : loop list } + [@@deriving show] type t = zipper * loop [@@deriving show] let go_down i = function | _, Statement _ -> None - | z, Loop {pragma ; aux ; comment; halt ; dim_index_id; start ; - increment; body; vars_decl} -> - begin match List.takedrop i body with - | left, elem::right -> - Some (LoopChild {parent = z; left; right; - pragma; aux; comment; halt; - dim_index_id; start; increment; vars_decl}, - elem) - | _ -> None - end - | z, Unroll { size ; comment; dim_index_id; start ; increment; body} -> - begin match List.takedrop i body with - | left, elem::right -> - Some (UnrChild {parent = z; left; right; - comment; size; dim_index_id; start; increment}, - elem) - | _ -> None - end - | z, Once l -> - begin match List.takedrop i l with - | left, elem::right -> - Some (OnceChild {parent = z; left; right;}, elem) - | _ -> None - end + | ( z, + Loop + { + pragma; + aux; + comment; + halt; + dim_index_id; + start; + increment; + body; + vars_decl; + } ) -> ( + match List.takedrop i body with + | left, elem :: right -> + Some + ( LoopChild + { + parent = z; + left; + right; + pragma; + aux; + comment; + halt; + dim_index_id; + start; + increment; + vars_decl; + }, + elem ) + | _ -> None) + | z, Unroll { size; comment; dim_index_id; start; increment; body } -> ( + match List.takedrop i body with + | left, elem :: right -> + Some + ( UnrChild + { + parent = z; + left; + right; + comment; + size; + dim_index_id; + start; + increment; + }, + elem ) + | _ -> None) + | z, Once l -> ( + match List.takedrop i l with + | left, elem :: right -> + Some (OnceChild { parent = z; left; right }, elem) + | _ -> None) let go_up_with_path = function | Top, _ -> None - | OnceChild {parent; left; right}, t -> - Some (parent, Once (left @ [t] @ right), - Option.get % go_down (List.length left)) - | LoopChild {parent; left; aux; halt; right; comment; - pragma; dim_index_id; start; increment; vars_decl }, t -> - Some (parent, - Loop {comment; aux; halt; pragma; dim_index_id; start; - increment; - body = left @ [t] @ right; - vars_decl - }, - Option.get % go_down (List.length left)) - | UnrChild {parent; left; size; right; comment; dim_index_id; start; increment;}, t -> - Some (parent, - Unroll {comment; size; dim_index_id; start; increment; body = left @ [t] @ right}, - Option.get % go_down (List.length left)) - - - let go_up = - Option.map (fun (z, l, _) -> z, l) % go_up_with_path - + | OnceChild { parent; left; right }, t -> + Some + ( parent, + Once (left @ [ t ] @ right), + Option.get % go_down (List.length left) ) + | ( LoopChild + { + parent; + left; + aux; + halt; + right; + comment; + pragma; + dim_index_id; + start; + increment; + vars_decl; + }, + t ) -> + Some + ( parent, + Loop + { + comment; + aux; + halt; + pragma; + dim_index_id; + start; + increment; + body = left @ [ t ] @ right; + vars_decl; + }, + Option.get % go_down (List.length left) ) + | ( UnrChild + { + parent; + left; + size; + right; + comment; + dim_index_id; + start; + increment; + }, + t ) -> + Some + ( parent, + Unroll + { + comment; + size; + dim_index_id; + start; + increment; + body = left @ [ t ] @ right; + }, + Option.get % go_down (List.length left) ) + + let go_up = Option.map (fun (z, l, _) -> (z, l)) % go_up_with_path let set_top_comment comment = (* back is a "back path". It brings original argument back where it was *) - let rec aux back (z,l) = match go_up_with_path (z,l) with + let rec aux back (z, l) = + match go_up_with_path (z, l) with | Some (z', l', b) -> aux (back %> b) (z', l') - | None -> back (z, set_comment comment l) in + | None -> back (z, set_comment comment l) + in aux Fun.id - + let map_all_insts f (z, l) = let rec loop = function | Top -> Top - | LoopChild ({parent; left; right;_} as arg) -> - let left = List.map (map_all_insts f) left - and right = List.map (map_all_insts f) right - and parent = loop parent in - LoopChild {arg with left; right; parent} - | UnrChild ({parent; left; right;_} as arg) -> - let left = List.map (map_all_insts f) left - and right = List.map (map_all_insts f) right - and parent = loop parent in - UnrChild {arg with left; right; parent} - | OnceChild {parent; left; right} -> - let left = List.map (map_all_insts f) left - and right = List.map (map_all_insts f) right - and parent = loop parent in - OnceChild {left; right; parent} in + | LoopChild ({ parent; left; right; _ } as arg) -> + let left = List.map (map_all_insts f) left + and right = List.map (map_all_insts f) right + and parent = loop parent in + LoopChild { arg with left; right; parent } + | UnrChild ({ parent; left; right; _ } as arg) -> + let left = List.map (map_all_insts f) left + and right = List.map (map_all_insts f) right + and parent = loop parent in + UnrChild { arg with left; right; parent } + | OnceChild { parent; left; right } -> + let left = List.map (map_all_insts f) left + and right = List.map (map_all_insts f) right + and parent = loop parent in + OnceChild { left; right; parent } + in (loop z, map_all_insts f l) let map_const_expr f (z, l) = let rec loop = function | Top -> Top - | LoopChild ({parent; start; halt; increment; left; right;vars_decl;_} as arg) -> - let left = List.map (map_const_expr f) left - and right = List.map (map_const_expr f) right - and vars_decl = List.map (Utils.map_snd f) vars_decl - and start = f start - and halt = f halt - and increment = f increment - and parent = loop parent in - LoopChild {arg with left; right; start; halt; increment; parent;vars_decl} - | UnrChild ({parent; start; increment; left; right;_} as arg) -> - let left = List.map (map_const_expr f) left - and right = List.map (map_const_expr f) right - and start = f start - and increment = f increment - and parent = loop parent in - UnrChild {arg with left; right; start; increment; parent} - | OnceChild {parent; left; right} -> - let left = List.map (map_const_expr f) left - and right = List.map (map_const_expr f) right - and parent = loop parent in - OnceChild {left; right; parent} in + | LoopChild + ({ parent; start; halt; increment; left; right; vars_decl; _ } as + arg) -> + let left = List.map (map_const_expr f) left + and right = List.map (map_const_expr f) right + and vars_decl = List.map (Utils.map_snd f) vars_decl + and start = f start + and halt = f halt + and increment = f increment + and parent = loop parent in + LoopChild + { + arg with + left; + right; + start; + halt; + increment; + parent; + vars_decl; + } + | UnrChild ({ parent; start; increment; left; right; _ } as arg) -> + let left = List.map (map_const_expr f) left + and right = List.map (map_const_expr f) right + and start = f start + and increment = f increment + and parent = loop parent in + UnrChild { arg with left; right; start; increment; parent } + | OnceChild { parent; left; right } -> + let left = List.map (map_const_expr f) left + and right = List.map (map_const_expr f) right + and parent = loop parent in + OnceChild { left; right; parent } + in (loop z, map_const_expr f l) let to_tree = - let rec loop (z, ln) = match go_up (z, ln) with - | Some (z', ln') -> loop (z', ln') - | None -> ln in + let rec loop (z, ln) = + match go_up (z, ln) with Some (z', ln') -> loop (z', ln') | None -> ln + in loop let map_inst f = function - (z, Statement insts) -> z, Statement (f insts) + | z, Statement insts -> (z, Statement (f insts)) | _ -> failwith "Not an instruction" let unroll_inst dim start size increment = let rewrite_insts inst cnst = Inst.map_expr dim - (Expr.alpha_replace_dim dim (Expr.(I.(start + const cnst * increment )))) inst in + (Expr.alpha_replace_dim dim + Expr.(I.(start + (const cnst * increment)))) + inst + in let unroll_inst inst = - List.init size Fun.id |> List.map (rewrite_insts inst) in - function (z, Statement stmts) -> z, Statement (List.concat_map unroll_inst stmts) - | _ -> failwith "Cannot unroll on anything but a statement" + List.init size Fun.id |> List.map (rewrite_insts inst) + in + function + | z, Statement stmts -> (z, Statement (List.concat_map unroll_inst stmts)) + | _ -> failwith "Cannot unroll on anything but a statement" let embed_before t = function - | Top, Once t2 -> Top, Once (t :: t2) - | Top, (Statement _ as s) -> Top, Once (t :: [s]) - | Top, l -> Top, Once ( t :: [l]) - | OnceChild {left; parent; right}, l -> OnceChild{left = t :: left; parent; right}, l - | LoopChild ({left; _} as arg), l -> LoopChild {arg with left = t :: left}, l - | UnrChild ({left; _} as arg), l -> UnrChild {arg with left = t :: left}, l + | Top, Once t2 -> (Top, Once (t :: t2)) + | Top, (Statement _ as s) -> (Top, Once (t :: [ s ])) + | Top, l -> (Top, Once (t :: [ l ])) + | OnceChild { left; parent; right }, l -> + (OnceChild { left = t :: left; parent; right }, l) + | LoopChild ({ left; _ } as arg), l -> + (LoopChild { arg with left = t :: left }, l) + | UnrChild ({ left; _ } as arg), l -> + (UnrChild { arg with left = t :: left }, l) let embed_before_at_top t (z, l) = let rec aux = function - | Top -> OnceChild {parent = Top; left = [t]; right = []} - | LoopChild ({parent; _} as arg) -> - LoopChild {arg with parent = aux parent} - | UnrChild ({parent; _} as arg) -> - UnrChild {arg with parent = aux parent} - | OnceChild ({parent; _} as arg) -> - OnceChild {arg with parent = aux parent} in + | Top -> OnceChild { parent = Top; left = [ t ]; right = [] } + | LoopChild ({ parent; _ } as arg) -> + LoopChild { arg with parent = aux parent } + | UnrChild ({ parent; _ } as arg) -> + UnrChild { arg with parent = aux parent } + | OnceChild ({ parent; _ } as arg) -> + OnceChild { arg with parent = aux parent } + in (aux z, l) let embed_after t = function - | Top, Once t2 -> Top, Once (t2 @ [t]) - | Top, l -> Top, Once ( [l] @ [t]) - | OnceChild {left; parent; right}, l -> OnceChild{left; parent; right = right @ [t]}, l - | LoopChild ({right; _} as arg), l -> LoopChild {arg with right = right @ [t]}, l - | UnrChild ({right; _} as arg), l -> UnrChild {arg with right = right @ [t]}, l + | Top, Once t2 -> (Top, Once (t2 @ [ t ])) + | Top, l -> (Top, Once ([ l ] @ [ t ])) + | OnceChild { left; parent; right }, l -> + (OnceChild { left; parent; right = right @ [ t ] }, l) + | LoopChild ({ right; _ } as arg), l -> + (LoopChild { arg with right = right @ [ t ] }, l) + | UnrChild ({ right; _ } as arg), l -> + (UnrChild { arg with right = right @ [ t ] }, l) let embed_after_at_top t (z, l) = let rec aux = function - | Top -> OnceChild {parent = Top; left = []; right = [t]} - | LoopChild ({parent; _} as arg) -> - LoopChild {arg with parent = aux parent} - | UnrChild ({parent; _} as arg) -> - UnrChild {arg with parent = aux parent} - | OnceChild ({parent; _} as arg) -> - OnceChild {arg with parent = aux parent} in + | Top -> OnceChild { parent = Top; left = []; right = [ t ] } + | LoopChild ({ parent; _ } as arg) -> + LoopChild { arg with parent = aux parent } + | UnrChild ({ parent; _ } as arg) -> + UnrChild { arg with parent = aux parent } + | OnceChild ({ parent; _ } as arg) -> + OnceChild { arg with parent = aux parent } + in (aux z, l) - - let new_seq ?(comment=[]) ?pragma ?aux ?(vars_decl=[]) dim_index_id start halt increment = + let new_seq ?(comment = []) ?pragma ?aux ?(vars_decl = []) dim_index_id + start halt increment = let aux = match aux with Some l -> l | None -> [] in let rec wrap_up = function | Top -> - LoopChild { parent = Top; dim_index_id; start; aux; pragma; vars_decl; halt; - increment; comment; - left = []; right = []} - | LoopChild ({parent;_} as arg) -> - LoopChild {arg with parent = wrap_up parent} - | UnrChild ({parent;_} as arg) -> - UnrChild{arg with parent = wrap_up parent} - | OnceChild ({parent;_} as arg) -> - OnceChild{arg with parent = wrap_up parent} in - fun (z, tree) -> wrap_up z, tree - - let new_unroll ?(comment=[]) dim_index_id start size increment = + LoopChild + { + parent = Top; + dim_index_id; + start; + aux; + pragma; + vars_decl; + halt; + increment; + comment; + left = []; + right = []; + } + | LoopChild ({ parent; _ } as arg) -> + LoopChild { arg with parent = wrap_up parent } + | UnrChild ({ parent; _ } as arg) -> + UnrChild { arg with parent = wrap_up parent } + | OnceChild ({ parent; _ } as arg) -> + OnceChild { arg with parent = wrap_up parent } + in + fun (z, tree) -> (wrap_up z, tree) + + let new_unroll ?(comment = []) dim_index_id start size increment = let rec wrap_up = function | Top -> - UnrChild { parent = Top; dim_index_id; start; size; - increment; comment; left = []; right = []} - | LoopChild ({parent;_} as arg) -> - LoopChild {arg with parent = wrap_up parent} - | UnrChild ({parent;_} as arg) -> - UnrChild{arg with parent = wrap_up parent} - | OnceChild ({parent;_} as arg) -> - OnceChild{arg with parent = wrap_up parent} in - fun (z, tree) -> wrap_up z, tree + UnrChild + { + parent = Top; + dim_index_id; + start; + size; + increment; + comment; + left = []; + right = []; + } + | LoopChild ({ parent; _ } as arg) -> + LoopChild { arg with parent = wrap_up parent } + | UnrChild ({ parent; _ } as arg) -> + UnrChild { arg with parent = wrap_up parent } + | OnceChild ({ parent; _ } as arg) -> + OnceChild { arg with parent = wrap_up parent } + in + fun (z, tree) -> (wrap_up z, tree) end - - let unroll ?st:(start= Expr.Zero) size dim increment body = - Unroll {comment = []; dim_index_id = Index.from_dim dim; start; size; increment; body} - - let loop ?st:(start= Expr.Zero) ?halt ?incr:(incr= Expr.One) ?pragma - ?(vars_decl=[]) dim body = + let unroll ?st:(start = Expr.Zero) size dim increment body = + Unroll + { + comment = []; + dim_index_id = Index.from_dim dim; + start; + size; + increment; + body; + } + + let loop ?st:(start = Expr.Zero) ?halt ?(incr = Expr.One) ?pragma + ?(vars_decl = []) dim body = let halt = Option.default (Expr.size (Dim.size_id dim)) halt in - Loop {dim_index_id = Index.from_dim dim; aux=[]; start; halt; - vars_decl; - increment = incr; body; pragma; comment = []} + Loop + { + dim_index_id = Index.from_dim dim; + aux = []; + start; + halt; + vars_decl; + increment = incr; + body; + pragma; + comment = []; + } let rec set_aux index aux = function - | Loop ({dim_index_id;_} as arg) - when Index.equal index dim_index_id -> - Loop {arg with aux} - | Loop ({body;_} as arg) -> - Loop {arg with body = List.map (set_aux index aux) body} - | Unroll ({body;_} as arg) -> - Unroll {arg with body = List.map (set_aux index aux) body} + | Loop ({ dim_index_id; _ } as arg) when Index.equal index dim_index_id -> + Loop { arg with aux } + | Loop ({ body; _ } as arg) -> + Loop { arg with body = List.map (set_aux index aux) body } + | Unroll ({ body; _ } as arg) -> + Unroll { arg with body = List.map (set_aux index aux) body } | Statement stmts -> Statement stmts - | Once body -> Once ( List.map (set_aux index aux) body) + | Once body -> Once (List.map (set_aux index aux) body) let rec propagate_expr old_id expr = function | Statement stmt -> - let f = Inst.map_expr (Index.dim old_id) (Expr.alpha_replace old_id expr) in - Statement (List.map f stmt) - | Unroll ({ start; body; _} as l_args) -> - let start = Expr.alpha_replace old_id expr start in - let body = List.map (propagate_expr old_id expr) body in - Unroll {l_args with start; body} - | Loop ({ start; halt; vars_decl; increment; body; _} as l_args) -> - let map_3tuple f x y z = f x, f y, f z in - let start, increment, halt = map_3tuple (Expr.alpha_replace old_id expr) start increment halt in - let vars_decl = List.map (Utils.map_snd @@ Expr.alpha_replace old_id expr) vars_decl in - let body = List.map (propagate_expr old_id expr) body in - Loop {l_args with start; increment; body; halt; vars_decl} + let f = + Inst.map_expr (Index.dim old_id) (Expr.alpha_replace old_id expr) + in + Statement (List.map f stmt) + | Unroll ({ start; body; _ } as l_args) -> + let start = Expr.alpha_replace old_id expr start in + let body = List.map (propagate_expr old_id expr) body in + Unroll { l_args with start; body } + | Loop ({ start; halt; vars_decl; increment; body; _ } as l_args) -> + let map_3tuple f x y z = (f x, f y, f z) in + let start, increment, halt = + map_3tuple (Expr.alpha_replace old_id expr) start increment halt + in + let vars_decl = + List.map (Utils.map_snd @@ Expr.alpha_replace old_id expr) vars_decl + in + let body = List.map (propagate_expr old_id expr) body in + Loop { l_args with start; increment; body; halt; vars_decl } | Once loop -> Once (List.map (propagate_expr old_id expr) loop) - let rec gen_loop_code tab size_indexes loop_nest: string list * string list * string = + let rec gen_loop_code tab loop_nest : string list * string list * string = (* This call is probably to be removed *) - let gen_body body = begin match to_instructions_list body with - Some inst_list -> - Inst.gen_code (tab ^ "\t") inst_list + let gen_body body = + match to_instructions_list body with + | Some inst_list -> Inst.gen_code (tab ^ "\t") inst_list | None -> - let scal_list, vec_list, code_list = - List.split3 @@ List.map (gen_loop_code (tab ^ "\t") size_indexes) body in - List.concat scal_list, List.concat vec_list, String.concat "\n" code_list end in + let scal_list, vec_list, code_list = + List.split3 @@ List.map (gen_loop_code (tab ^ "\t")) body + in + ( List.concat scal_list, + List.concat vec_list, + String.concat "\n" code_list ) + in match loop_nest with - | Statement inst -> - Inst.gen_code tab inst - | Once loops -> - gen_body loops - | Unroll { comment; dim_index_id; start; size; body; increment; _;} -> - let comment = comment - |> List.map (fun c -> tab ^ " // " ^ c ) - |> String.concat "\n" - |> fun s -> if String.is_empty s then s else s ^ "\n" - in - (* List of numbers from 0 to loop_size - 1 *) - List.init size Fun.id - |> List.map (fun cnst_index -> List.map - (propagate_expr dim_index_id (Expr.(I.(start + increment * const cnst_index)))) body) - |> List.flatten - |> gen_body - |> fun (sl, vl, code) -> sl, vl, comment ^ code - | Loop { comment; pragma; dim_index_id; aux; start; halt; increment; body; vars_decl;_;} -> - let scal_list, vec_list, body_str = gen_body body in - let index_str = (Index.show_id (Index.id dim_index_id)) in - let start_str = - (Expr.show start) in - let halt_str = (Expr.show @@ Expr.simplify halt) in - let increment_str = (Expr.show @@ Expr.simplify increment) in - let aux_list = List.map (fun (index, start) -> Printf.sprintf "%s = %s" - (Index.show_id (Index.id index)) (Expr.show @@ Expr.simplify start)) aux in - let comment = comment - |> List.map (fun c -> tab ^ "// " ^ c ) - |> String.concat "\n" - |> fun s -> if String.is_empty s then s else s ^ "\n" in - let unzip l = let l1, l2 = List.fold_left (fun (la, lb) (a, b) -> a :: la, b :: lb) - ([], []) l in - List.rev l1, List.rev l2 in - let names, exprs = unzip vars_decl in - let names = String.concat ", " names - and exprs = List.map Expr.show exprs |> String.concat ", " in - let var_decls = if List.is_empty vars_decl then "" - else Printf.sprintf "int %s = %s;" names exprs in - let pragma = Option.map_default (fun c -> tab ^ "#pragma " ^ c ^ "\n") "" - pragma in - let aux_start = match aux_list with [] -> "" - | _ -> ", " ^ String.concat ", " aux_list in - let aux_list = List.map (fun (index, _) -> Printf.sprintf "%s += %s" - (Index.show_id (Index.id index)) (Expr.show @@ Expr.simplify increment)) aux in - let aux_incr = match aux_list with [] -> "" - | _ -> ", " ^ String.concat ", " aux_list in - scal_list, vec_list, - Format.sprintf "%s%s%sfor (%s = %s%s;\n%s\t%s < %s;\n%s\t%s += %s%s){\n\t%s%s\n%s\n%s}" - comment pragma tab - index_str start_str aux_start tab - index_str halt_str tab index_str increment_str aux_incr - tab var_decls - body_str tab - - let gen_code sizes_indexes loop_nest = -(* print_endline (show_alt loop_nest); *) - let scal_list, vec_list, code = gen_loop_code "" sizes_indexes loop_nest in - let scal_decl = match scal_list with - [] -> "" - | _ -> String.concat " ," (List.sort_unique String.compare scal_list) - |> Printf.sprintf "%s %s;\n" Inst.A.base_type_name in - let vec_decl = match vec_list with - [] -> "" - | _ -> String.concat " ," (List.sort_unique String.compare vec_list) - |> Printf.sprintf "%s %s;\n" Inst.A.vec_type_name in + | Statement inst -> Inst.gen_code tab inst + | Once loops -> gen_body loops + | Unroll { comment; dim_index_id; start; size; body; increment; _ } -> + let comment = + comment |> List.map (fun c -> tab ^ " // " ^ c) |> String.concat "\n" + |> fun s -> if String.is_empty s then s else s ^ "\n" + in + (* List of numbers from 0 to loop_size - 1 *) + List.init size Fun.id + |> List.map (fun cnst_index -> + List.map + (propagate_expr dim_index_id + Expr.(I.(start + (increment * const cnst_index)))) + body) + |> List.flatten |> gen_body + |> fun (sl, vl, code) -> (sl, vl, comment ^ code) + | Loop + { + comment; + pragma; + dim_index_id; + aux; + start; + halt; + increment; + body; + vars_decl; + _; + } -> + let scal_list, vec_list, body_str = gen_body body in + let index_str = Index.show_id (Index.id dim_index_id) in + let start_str = Expr.show start in + let halt_str = Expr.show @@ Expr.simplify halt in + let increment_str = Expr.show @@ Expr.simplify increment in + let aux_list = + List.map + (fun (index, start) -> + Printf.sprintf "%s = %s" + (Index.show_id (Index.id index)) + (Expr.show @@ Expr.simplify start)) + aux + in + let comment = + comment |> List.map (fun c -> tab ^ "// " ^ c) |> String.concat "\n" + |> fun s -> if String.is_empty s then s else s ^ "\n" + in + let vars_decls_str = + List.map + (fun (name, expr) -> + Printf.sprintf "int %s = %s;" name (Expr.show expr)) + vars_decl + in + let var_decls = + if List.is_empty vars_decl then "" + else String.concat "\n" vars_decls_str + in + let pragma = + Option.map_default (fun c -> tab ^ "#pragma " ^ c ^ "\n") "" pragma + in + let aux_start = + match aux_list with + | [] -> "" + | _ -> ", " ^ String.concat ", " aux_list + in + let aux_list = + List.map + (fun (index, _) -> + Printf.sprintf "%s += %s" + (Index.show_id (Index.id index)) + (Expr.show @@ Expr.simplify increment)) + aux + in + let aux_incr = + match aux_list with + | [] -> "" + | _ -> ", " ^ String.concat ", " aux_list + in + ( scal_list, + vec_list, + Format.sprintf + "%s%s%sfor (%s = %s%s;\n\ + %s\t%s < %s;\n\ + %s\t%s += %s%s){\n\ + \t%s%s\n\ + %s\n\ + %s}" + comment pragma tab index_str start_str aux_start tab index_str + halt_str tab index_str increment_str aux_incr tab var_decls body_str + tab ) + + let gen_code loop_nest = + (* print_endline (show_alt loop_nest); *) + let scal_list, vec_list, code = gen_loop_code "" loop_nest in + let scal_decl = + match scal_list with + | [] -> "" + | _ -> + String.concat " ," (List.sort_unique String.compare scal_list) + |> Printf.sprintf "%s %s;\n" Inst.A.base_type_name + in + let vec_decl = + match vec_list with + | [] -> "" + | _ -> + String.concat " ," (List.sort_unique String.compare vec_list) + |> Printf.sprintf "%s %s;\n" Inst.A.vec_type_name + in scal_decl ^ vec_decl ^ code end diff --git a/ml/lib/loop_nest.mli b/ml/lib/loop_nest.mli index b6f09aad916b9d2d662f1de95b68bcda08bbf6dc..a3d25f1da2c35d8a8bff3928bb7cfbdcc2e922b0 100644 --- a/ml/lib/loop_nest.mli +++ b/ml/lib/loop_nest.mli @@ -1,5 +1,3 @@ open Instruction - include module type of Loop_nest_types - -module L(Inst: Inst_t) : Loopnest_t with module Inst := Inst +module L (Inst : Inst_t) : Loopnest_t with module Inst := Inst diff --git a/ml/lib/loop_nest_types.ml b/ml/lib/loop_nest_types.ml index aaf55fec755f0bda1156647e57a169b500e39997..a4154ca4a9073e047078f9ee030a27ffa9a8eb56 100644 --- a/ml/lib/loop_nest_types.ml +++ b/ml/lib/loop_nest_types.ml @@ -3,67 +3,117 @@ open Exprs open Instruction module type Loopnest_t = sig - module Inst: Inst_t + module Inst : Inst_t + type loop = | Once of loop list | Statement of Inst.t list - | Unroll of {comment: string list; dim_index_id: Index.t; start: Expr.t ; - size: int; increment: Expr.t; body: loop list} - | Loop of {comment: string list; pragma: string option; dim_index_id: Index.t; - aux: (Index.t * Expr.t) list; start: Expr.t ; - halt: Expr.t; increment: Expr.t; body: loop list; - vars_decl: (string * Expr.t) list; - } [@@deriving show] + | Unroll of { + comment : string list; + dim_index_id : Index.t; + start : Expr.t; + size : int; + increment : Expr.t; + body : loop list; + } + | Loop of { + comment : string list; + pragma : string option; + dim_index_id : Index.t; + aux : (Index.t * Expr.t) list; + start : Expr.t; + halt : Expr.t; + increment : Expr.t; + body : loop list; + vars_decl : (string * Expr.t) list; + } + [@@deriving show] val map_all_insts : (Inst.t list -> Inst.t list) -> loop -> loop val map_const_expr : (Expr.t -> Expr.t) -> loop -> loop + module Zipper : sig type zipper = - | Top - | LoopChild of {parent : zipper ; pragma : string option; - aux : (Index.t * Expr.t) list; comment: string list; - halt : Expr.t; dim_index_id: Index.t; - start: Expr.t ; increment: Expr.t; - left : loop list; right : loop - list; - vars_decl: (string * Expr.t) list; - } - | UnrChild of { parent : zipper ; comment: string list; dim_index_id: Index.t; - start: Expr.t ; size: int; increment: Expr.t; left : loop list ; - right: loop list} - | OnceChild of {parent : zipper; left : loop list; right : loop list} + | Top + | LoopChild of { + parent : zipper; + pragma : string option; + aux : (Index.t * Expr.t) list; + comment : string list; + halt : Expr.t; + dim_index_id : Index.t; + start : Expr.t; + increment : Expr.t; + left : loop list; + right : loop list; + vars_decl : (string * Expr.t) list; + } + | UnrChild of { + parent : zipper; + comment : string list; + dim_index_id : Index.t; + start : Expr.t; + size : int; + increment : Expr.t; + left : loop list; + right : loop list; + } + | OnceChild of { parent : zipper; left : loop list; right : loop list } type t = zipper * loop [@@deriving show] - val go_down: int -> t -> t option - val go_up_with_path: t -> (zipper * loop * (t -> t)) option - val go_up: t -> t option - + val go_down : int -> t -> t option + val go_up_with_path : t -> (zipper * loop * (t -> t)) option + val go_up : t -> t option val set_top_comment : string list -> t -> t - val to_tree : t -> loop val map_const_expr : (Expr.t -> Expr.t) -> t -> t val map_all_insts : (Inst.t list -> Inst.t list) -> t -> t val unroll_inst : Dim.id -> Expr.t -> int -> Expr.t -> t -> t val map_inst : (Inst.t list -> Inst.t list) -> t -> t val embed_before : loop -> t -> t - val embed_before_at_top: loop -> t -> t + val embed_before_at_top : loop -> t -> t val embed_after : loop -> t -> t - val embed_after_at_top: loop -> t -> t - val new_seq : ?comment:string list -> ?pragma:string -> ?aux:(Index.t * Expr.t) list - -> ?vars_decl:(string * Expr.t) list - -> Index.t -> Expr.t -> Expr.t -> Expr.t -> t -> t - val new_unroll: ?comment:string list -> Index.t -> Expr.t -> int -> Expr.t -> t -> t + val embed_after_at_top : loop -> t -> t + + val new_seq : + ?comment:string list -> + ?pragma:string -> + ?aux:(Index.t * Expr.t) list -> + ?vars_decl:(string * Expr.t) list -> + Index.t -> + Expr.t -> + Expr.t -> + Expr.t -> + t -> + t + + val new_unroll : + ?comment:string list -> Index.t -> Expr.t -> int -> Expr.t -> t -> t end - val show_alt: loop -> string - val unroll: ?st:Expr.t -> int -> Dim.t -> Expr.t -> loop list -> loop - val unroll_insts_dim: string list -> Dim.id -> Expr.t -> int -> Expr.t - -> Inst.t list -> Inst.t list - val loop: ?st:Expr.t -> ?halt:Expr.t -> ?incr:Expr.t -> + val show_alt : loop -> string + val unroll : ?st:Expr.t -> int -> Dim.t -> Expr.t -> loop list -> loop + + val unroll_insts_dim : + string list -> + Dim.id -> + Expr.t -> + int -> + Expr.t -> + Inst.t list -> + Inst.t list + + val loop : + ?st:Expr.t -> + ?halt:Expr.t -> + ?incr:Expr.t -> ?pragma:string -> ?vars_decl:(string * Expr.t) list -> - Dim.t -> loop list -> loop - val set_aux: Index.t -> (Index.t * Expr.t) list -> loop -> loop - val gen_code: (Size.id * int) list -> loop -> string + Dim.t -> + loop list -> + loop + + val set_aux : Index.t -> (Index.t * Expr.t) list -> loop -> loop + val gen_code : loop -> string end diff --git a/ml/lib/packing.ml b/ml/lib/packing.ml index 7c4bc7c0785d472ef0cfef8621d5c42109fcb4ab..9d6b9ad991e7c5f7078b75b86927810128b1ed39 100644 --- a/ml/lib/packing.ml +++ b/ml/lib/packing.ml @@ -3,11 +3,10 @@ open Ids open Exprs open Loop_nest module T = Tensor - open Instruction -module PI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct - module D = Dim_info.DI(Inst)(LN) +module PI (Inst : Inst_t) (LN : Loopnest_t with module Inst := Inst) = struct + module D = Dim_info.DI (Inst) (LN) module Zip = LN.Zipper module DimMap = D.DimMap @@ -18,61 +17,63 @@ module PI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct let build_dims_snapshot tensor dim_map = let dims = T.dims_list tensor in List.map - (fun dim -> - Dim.id dim, - DimMap.find dim dim_map - |> D.freeze dim) + (fun dim -> (Dim.id dim, DimMap.find dim dim_map |> D.freeze dim)) dims let fold_dims vectorize small_snapshot big_snapshot (indexes, glob_accesses, loc_accesses, loop_nest) = - let open Tensor in function - Single d -> - let d_id = Dim.id d in - let small_frozen = Option.get @@ List.assoc_eq Dim.equal_id d_id small_snapshot - and big_frozen = Option.get @@ List.assoc_eq Dim.equal_id d_id small_snapshot in - let to_vectorize = vectorize d_id in - let new_indexes, glob_index, local_index, enclose = - D.Frozen.interval_loop_tile big_frozen small_frozen to_vectorize in - List.rev_append new_indexes indexes, - (d_id, glob_index)::glob_accesses, - (d_id, local_index)::loc_accesses, - enclose % loop_nest + let open Tensor in + function + | Single d -> + let d_id = Dim.id d in + let small_frozen = + Option.get @@ List.assoc_eq Dim.equal_id d_id small_snapshot + and big_frozen = + Option.get @@ List.assoc_eq Dim.equal_id d_id small_snapshot + in + let to_vectorize = vectorize d_id in + let new_indexes, glob_index, local_index, enclose = + D.Frozen.interval_loop_tile big_frozen small_frozen to_vectorize + in + ( List.rev_append new_indexes indexes, + (d_id, glob_index) :: glob_accesses, + (d_id, local_index) :: loc_accesses, + enclose % loop_nest ) (* TODO we may want to change something here - Or do we ? *) | Join (main_dim, aux_dim, iter_dim, _) -> - let main_id = Dim.id main_dim - and aux_id = Dim.id aux_dim in - let get_frozen snapshot d = Option.get @@ List.assoc_eq - Dim.equal_id d snapshot in - let small_frozen_main = get_frozen small_snapshot main_id - and big_frozen_main = get_frozen big_snapshot main_id - and small_frozen_aux = get_frozen small_snapshot aux_id - and big_frozen_aux = get_frozen big_snapshot aux_id in - let new_indexes, new_glob_accesses, new_local_accesses, enclose = - D.Frozen.interval_loop_tile_join iter_dim main_dim aux_dim - big_frozen_main small_frozen_main - big_frozen_aux small_frozen_aux - false in - List.rev_append new_indexes indexes, - List.rev_append new_glob_accesses glob_accesses, - List.rev_append new_local_accesses loc_accesses, - enclose % loop_nest + let main_id = Dim.id main_dim and aux_id = Dim.id aux_dim in + let get_frozen snapshot d = + Option.get @@ List.assoc_eq Dim.equal_id d snapshot + in + let small_frozen_main = get_frozen small_snapshot main_id + and big_frozen_main = get_frozen big_snapshot main_id + and small_frozen_aux = get_frozen small_snapshot aux_id + and big_frozen_aux = get_frozen big_snapshot aux_id in + let new_indexes, new_glob_accesses, new_local_accesses, enclose = + D.Frozen.interval_loop_tile_join iter_dim main_dim aux_dim + big_frozen_main small_frozen_main big_frozen_aux small_frozen_aux + false + in + ( List.rev_append new_indexes indexes, + List.rev_append new_glob_accesses glob_accesses, + List.rev_append new_local_accesses loc_accesses, + enclose % loop_nest ) let build_loop_base dims vectorize small_snapshot big_snapshot = - List.fold_left - (fold_dims vectorize small_snapshot big_snapshot) - ([], [], [], Fun.id) dims + List.fold_left + (fold_dims vectorize small_snapshot big_snapshot) + ([], [], [], Fun.id) dims let build_transpose_loop transpose_dims glob_inner_dim loc_inner_dim - small_snapshot big_snapshot = - let vectorize_dim id = - Dim.equal_id (Dim.id glob_inner_dim) id - || Dim.equal_id (Dim.id loc_inner_dim) id in - build_loop_base transpose_dims vectorize_dim small_snapshot big_snapshot + small_snapshot big_snapshot = + let vectorize_dim id = + Dim.equal_id (Dim.id glob_inner_dim) id + || Dim.equal_id (Dim.id loc_inner_dim) id + in + build_loop_base transpose_dims vectorize_dim small_snapshot big_snapshot let build_loop dims inner_dim small_snapshot big_snapshot = - let vectorize_dim id = - Dim.equal_id (Dim.id inner_dim) id in + let vectorize_dim id = Dim.equal_id (Dim.id inner_dim) id in build_loop_base dims vectorize_dim small_snapshot big_snapshot (* Tensor access body-loop generation *) @@ -81,71 +82,107 @@ module PI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct let open Inst in let stride_src = Option.get @@ T.stride src_tensor dst_inner_dim and stride_dst = Option.get @@ T.stride dst_tensor src_inner_dim in - [Vtranspose(src_tensor, src_accesses, stride_src, - dst_tensor, dst_accesses, stride_dst) + [ + Vtranspose + ( src_tensor, + src_accesses, + stride_src, + dst_tensor, + dst_accesses, + stride_dst ); ] let access to_vectorize src_tensor dst_tensor src_accesses dst_accesses = let open Inst in if to_vectorize then - [Vwrite(Vread(src_tensor, src_accesses), dst_tensor, dst_accesses)] - else [Write(Read(src_tensor, src_accesses), dst_tensor, dst_accesses)] + [ Vwrite (Vread (src_tensor, src_accesses), dst_tensor, dst_accesses) ] + else [ Write (Read (src_tensor, src_accesses), dst_tensor, dst_accesses) ] type goal = Tensor.t -> D.t DimMap.t -> Zip.t -> Zip.t (* Build intermediate structure holding temporary packing information *) - let build_pack local_dim_map tensor local_tensor is_readonly glob_dim_map glob_tensor ln = - let load_comment = Printf.sprintf "Pack %s into %s" (T.show_tid tensor) (T.show_tid local_tensor) in - let store_comment = Printf.sprintf "Unpack %s into %s" (T.show_tid local_tensor) - (T.show_tid tensor) in + let build_pack local_dim_map tensor local_tensor is_readonly glob_dim_map + glob_tensor ln = + let load_comment = + Printf.sprintf "Pack %s into %s" (T.show_tid tensor) + (T.show_tid local_tensor) + in + let store_comment = + Printf.sprintf "Unpack %s into %s" (T.show_tid local_tensor) + (T.show_tid tensor) + in (* is it ok to vectorize on a join dim ? *) - let get_inner_dim = Tensor.inner_dim - %> function Single d -> d - | Join (d,_, _, _) -> d in + let get_inner_dim = + Tensor.inner_dim %> function Single d -> d | Join (d, _, _, _) -> d + in let inner_dim = get_inner_dim tensor in - let inner_dim_size = DimMap.find inner_dim local_dim_map - |> function Some dmap -> begin match D.incr dmap with - Expr.Const i -> i - | _ -> raise Not_found end - | None -> 1 in - let to_vectorize = inner_dim_size >= Inst.A.vec_size && inner_dim_size mod Inst.A.vec_size = 0 in + let inner_dim_size = + DimMap.find inner_dim local_dim_map |> function + | Some dmap -> ( + match D.incr dmap with Expr.Const i -> i | _ -> raise Not_found) + | None -> 1 + in + let to_vectorize = + inner_dim_size >= Inst.A.vec_size + && inner_dim_size mod Inst.A.vec_size = 0 + in (* For the moment we only support packing with both inner dims vectorisable *) - let _ = if not to_vectorize then - let dim_name = Dim.show_id_of_t inner_dim in - let error_msg = Printf.sprintf "Dimension '%s' of size %d (packed tensor) is not vectorizable (vector size %d doesn't divide it)." dim_name inner_dim_size Inst.A.vec_size in - let rule_msg = "Both inner dimensions of packed and global tensor should be vectorisable." in - raise (DimNotVectorisable (Printf.sprintf "%s %s" error_msg rule_msg)) - else - () in + let _ = + if not to_vectorize then + let dim_name = Dim.show_id_of_t inner_dim in + let error_msg = + Printf.sprintf + "Dimension '%s' of size %d (packed tensor) is not vectorizable \ + (vector size %d doesn't divide it)." + dim_name inner_dim_size Inst.A.vec_size + in + let rule_msg = + "Both inner dimensions of packed and global tensor should be \ + vectorisable." + in + raise (DimNotVectorisable (Printf.sprintf "%s %s" error_msg rule_msg)) + else () + in let small_snapshot = build_dims_snapshot tensor local_dim_map in let big_snapshot = build_dims_snapshot tensor glob_dim_map in - let glob_inner_dim = get_inner_dim glob_tensor - and loc_inner_dim = get_inner_dim local_tensor in - - let indexes, glob_accesses, loc_accesses, loop_enclose = - if Dim.equal glob_inner_dim loc_inner_dim - then build_loop (Tensor.t_dims_list local_tensor) glob_inner_dim small_snapshot big_snapshot - else let dims = T.t_dims_list local_tensor in - build_transpose_loop dims glob_inner_dim loc_inner_dim small_snapshot big_snapshot in - let load_kernel = - if Dim.equal glob_inner_dim loc_inner_dim - then access to_vectorize glob_tensor local_tensor glob_accesses loc_accesses - else transpose_access glob_inner_dim loc_inner_dim glob_tensor local_tensor - glob_accesses loc_accesses in - let store_kernel = - if Dim.equal glob_inner_dim loc_inner_dim - then access to_vectorize local_tensor glob_tensor loc_accesses glob_accesses - else transpose_access loc_inner_dim glob_inner_dim local_tensor glob_tensor - loc_accesses glob_accesses in - let load_loop = loop_enclose (Zip.Top, LN.Statement load_kernel) - |> Zip.set_top_comment [load_comment] - |> Zip.to_tree - and store_loop = loop_enclose (Zip.Top, LN.Statement store_kernel) - |> Zip.set_top_comment [store_comment] - |> Zip.to_tree in - indexes, - ln - |> Zip.embed_before_at_top load_loop - |> if is_readonly then Fun.id else Zip.embed_after_at_top store_loop - + let glob_inner_dim = get_inner_dim glob_tensor + and loc_inner_dim = get_inner_dim local_tensor in + + let indexes, glob_accesses, loc_accesses, loop_enclose = + if Dim.equal glob_inner_dim loc_inner_dim then + build_loop + (Tensor.t_dims_list local_tensor) + glob_inner_dim small_snapshot big_snapshot + else + let dims = T.t_dims_list local_tensor in + build_transpose_loop dims glob_inner_dim loc_inner_dim small_snapshot + big_snapshot + in + let load_kernel = + if Dim.equal glob_inner_dim loc_inner_dim then + access to_vectorize glob_tensor local_tensor glob_accesses loc_accesses + else + transpose_access glob_inner_dim loc_inner_dim glob_tensor local_tensor + glob_accesses loc_accesses + in + let store_kernel = + if Dim.equal glob_inner_dim loc_inner_dim then + access to_vectorize local_tensor glob_tensor loc_accesses glob_accesses + else + transpose_access loc_inner_dim glob_inner_dim local_tensor glob_tensor + loc_accesses glob_accesses + in + let load_loop = + loop_enclose (Zip.Top, LN.Statement load_kernel) + |> Zip.set_top_comment [ load_comment ] + |> Zip.to_tree + and store_loop = + loop_enclose (Zip.Top, LN.Statement store_kernel) + |> Zip.set_top_comment [ store_comment ] + |> Zip.to_tree + in + ( indexes, + ln + |> Zip.embed_before_at_top load_loop + |> if is_readonly then Fun.id else Zip.embed_after_at_top store_loop ) end diff --git a/ml/lib/tensor.ml b/ml/lib/tensor.ml index bc05871da097a8c91753d98714b0468586128892..e76a8986c81d13286e314b38dc0235fef28708dd 100644 --- a/ml/lib/tensor.ml +++ b/ml/lib/tensor.ml @@ -2,46 +2,50 @@ open Ids open Exprs open Utils - module Args = struct let prefix = "T" end -module T_id = Id(Args) + +module T_id = Id (Args) + type id = T_id.t [@@deriving show, ord, eq] (* A join dimension means that two distinct dimensions - meaning distinct loops * iterate on the same elements. Represents i + w for example *) -type t_dims = Single of Dim.t - | Join of Dim.t * Dim.t * Dim.t * int option [@@deriving eq, show] +type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t * int option +[@@deriving eq, show] let single dim = Single dim let join_dims d1 d2 = - let main_name = Dim.show_id_of_t d1 - and aux_name = Dim.show_id_of_t d2 in + let main_name = Dim.show_id_of_t d1 and aux_name = Dim.show_id_of_t d2 in let name = main_name ^ aux_name in let iter_dim, _ = Dim.fresh_gen name in Join (d1, d2, iter_dim (), None) let join_dims_stride d1 d2 str = - let main_name = Dim.show_id_of_t d1 - and aux_name = Dim.show_id_of_t d2 in + let main_name = Dim.show_id_of_t d1 and aux_name = Dim.show_id_of_t d2 in let name = main_name ^ aux_name in let iter_dim, _ = Dim.fresh_gen name in Join (d1, d2, iter_dim (), Some str) -type t = {id: id; dims: t_dims array; sizes:(t_dims * Expr.t) list; - strides: (t_dims * Expr.t) list} [@@deriving eq, show] +type t = { + id : id; + dims : t_dims array; + sizes : (t_dims * Expr.t) list; + strides : (t_dims * Expr.t) list; +} +[@@deriving eq, show] type gen = unit -> t - type accesses = (Dim.id * Expr.t) list [@@deriving eq, show] let get_dim tensor dim = Array.find_opt - (function Single d when Dim.equal d dim -> true - | Join (d1, d2, _, _) when Dim.equal dim d1 || Dim.equal dim d2 -> true - | _ -> false) + (function + | Single d when Dim.equal d dim -> true + | Join (d1, d2, _, _) when Dim.equal dim d1 || Dim.equal dim d2 -> true + | _ -> false) tensor.dims let stride tensor dim = @@ -51,44 +55,51 @@ let stride tensor dim = let dim_size tensor dim = List.assoc_eq equal_t_dims (single dim) tensor.sizes -let size {sizes;_} = - let non_zero_sizes = List.filter (Bool.not % Expr.(equal zero)) (List.map snd sizes) in +let size { sizes; _ } = + let non_zero_sizes = + List.filter (Bool.not % Expr.(equal zero)) (List.map snd sizes) + in if List.is_empty non_zero_sizes then Expr.zero - else List.fold_left Expr.mul Expr.one non_zero_sizes + else List.fold_left Expr.mul Expr.one non_zero_sizes -let id {id;_} = id -let show_tid {id;_} = show_id id +let id { id; _ } = id +let show_tid { id; _ } = show_id id -let fresh_gen (): ?name:string -> unit -> id = +let fresh_gen () : ?name:string -> unit -> id = let tensor_gen, _ = T_id.gen_id () in - fun ?name -> match name with - Some name -> fun () ->T_id.set_prefix name (tensor_gen ()) + fun ?name -> + match name with + | Some name -> fun () -> T_id.set_prefix name (tensor_gen ()) | None -> tensor_gen -let gen_clone ({id;_} as tensor) () = let f_id, gen_id = T_id.stack id () in - {tensor with id = f_id}, fun () -> {tensor with id = gen_id () } +let gen_clone ({ id; _ } as tensor) () = + let f_id, gen_id = T_id.stack id () in + ({ tensor with id = f_id }, fun () -> { tensor with id = gen_id () }) +let t_dims { dims; _ } = dims +let t_dims_list { dims; _ } = Array.to_list dims -let t_dims {dims;_} = dims -let t_dims_list {dims;_} = Array.to_list dims -let dims_list {dims;_} = Array.fold_right - (fun dim l -> match dim with Single d -> d::l - | Join (d1, d2, _, _) -> d1::d2::l) +let dims_list { dims; _ } = + Array.fold_right + (fun dim l -> + match dim with Single d -> d :: l | Join (d1, d2, _, _) -> d1 :: d2 :: l) dims [] -let strides_in_order {strides;_} = - List.map snd strides +let strides_in_order { strides; _ } = List.map snd strides (* Should take dims from smaller stride to bigger * Return a list where first stride is the biggest and last is one *) let strides_from_size dims dim_size_list = - List.fold_left (fun (strides, current_stride) dim -> + List.fold_left + (fun (strides, current_stride) dim -> let dim_size = Option.default Expr.zero (List.assoc dim dim_size_list) in - if Expr.equal Expr.zero dim_size then (dim, current_stride)::strides, current_stride - else let stride = Expr.(I.( current_stride * dim_size)) in - (dim, current_stride)::strides, stride) - ([], Expr.one) dims |> fst |> List.rev - + if Expr.equal Expr.zero dim_size then + ((dim, current_stride) :: strides, current_stride) + else + let stride = Expr.(I.(current_stride * dim_size)) in + ((dim, current_stride) :: strides, stride)) + ([], Expr.one) dims + |> fst |> List.rev (* Create a tensor with dims specified as an array of Dim, * with strides and sizes that can be optionally specified, @@ -98,51 +109,64 @@ let strides_from_size dims dim_size_list = let make ?name ?strides ?sizes dims = (* For now list are stored in reverse order so that last specified size get stride one *) let dims_list = Array.to_list dims in - let sizes = match sizes with - | None -> List.map (fun d -> single d, Expr.size @@ Dim.size_id d) dims_list + let sizes = + match sizes with + | None -> + List.map (fun d -> (single d, Expr.size @@ Dim.size_id d)) dims_list | Some sizes -> - List.map (fun dim -> - single dim, - Option.default (Expr.size @@ Dim.size_id dim) @@ List.assoc_eq Dim.equal dim - sizes) dims_list in - let strides = match strides with + List.map + (fun dim -> + ( single dim, + Option.default (Expr.size @@ Dim.size_id dim) + @@ List.assoc_eq Dim.equal dim sizes )) + dims_list + in + let strides = + match strides with | None -> strides_from_size (List.map single dims_list) sizes - | Some strides -> List.map (map_fst single) strides in - {id = fresh_gen () ?name (); dims = Array.map single dims; sizes; strides } + | Some strides -> List.map (map_fst single) strides + in + { id = fresh_gen () ?name (); dims = Array.map single dims; sizes; strides } -let make_join ?name ?strides ?sizes (dims: t_dims array) = +let make_join ?name ?strides ?sizes (dims : t_dims array) = let dims_list = Array.to_list dims in let sizes = - let csize d =Expr.size @@ Dim.size_id d in + let csize d = Expr.size @@ Dim.size_id d in match sizes with - | None -> List.map ( - function Single d as sd -> sd, csize d - | Join (d1, d2, _, None) as jd -> - jd, Expr.(I.(csize d1 + csize d2 - const 1)) - | Join (d1, d2, _, Some stride) as jd -> - jd, Expr.(I.(const stride * csize d1 + csize d2 - const 1)) - ) dims_list + | None -> + List.map + (function + | Single d as sd -> (sd, csize d) + | Join (d1, d2, _, None) as jd -> + (jd, Expr.(I.(csize d1 + csize d2 - const 1))) + | Join (d1, d2, _, Some stride) as jd -> + (jd, Expr.(I.((const stride * csize d1) + csize d2 - const 1)))) + dims_list | Some sizes -> - List.map (function - Single d as sd -> - sd, - Option.default (csize d) @@ List.assoc sd sizes - | Join (d1, d2, _, None) as jd -> - jd, - Option.default (Expr.(I.(csize d1 + csize d2 - const 1))) @@ List.assoc jd sizes - | Join (d1, d2, _, Some stride) as jd -> - jd, - Option.default (Expr.(I.(const stride * csize d1 + csize d2 - const 1))) @@ List.assoc jd sizes - ) - dims_list in - let strides = match strides with + List.map + (function + | Single d as sd -> + (sd, Option.default (csize d) @@ List.assoc sd sizes) + | Join (d1, d2, _, None) as jd -> + ( jd, + Option.default Expr.(I.(csize d1 + csize d2 - const 1)) + @@ List.assoc jd sizes ) + | Join (d1, d2, _, Some stride) as jd -> + ( jd, + Option.default + Expr.(I.((const stride * csize d1) + csize d2 - const 1)) + @@ List.assoc jd sizes )) + dims_list + in + let strides = + match strides with | None -> strides_from_size dims_list sizes - | Some strides -> strides in - {id = fresh_gen () ?name (); dims; sizes; strides } - + | Some strides -> strides + in + { id = fresh_gen () ?name (); dims; sizes; strides } let strides_from_size_list tens dim_size_list = - let dims = t_dims_list tens in + let dims = t_dims_list tens in strides_from_size dims dim_size_list let _validate_accesses tensor accesses = @@ -153,10 +177,11 @@ let acc_does_access accesses dim = (* FIXME: accesses_list seems broken, for now we stay conservative *) let tens_does_access tensor dim = - Array.exists + Array.exists (let check_dim d = Dim.equal_id (Dim.id d) dim in - function Single d -> check_dim d - | Join (d1, d2, _, _) -> check_dim d1 || check_dim d2) + function + | Single d -> check_dim d + | Join (d1, d2, _, _) -> check_dim d1 || check_dim d2) tensor.dims (* && (List.assoc_opt dim accesses |> Option.map Expr.is_constant @@ -166,127 +191,136 @@ let tens_does_access tensor dim = *) let does_access tensor accesses dim = - tens_does_access tensor dim - || acc_does_access accesses dim + tens_does_access tensor dim || acc_does_access accesses dim let sort_dim dim_list d1 d2 = if Dim.equal d1 d2 then 0 - else let rec sort_dim = - function - [] -> failwith "None of d1 or d2 is present" - | h::_ when Dim.equal h d1 -> -1 - | h::_ when Dim.equal h d2 -> 1 - | _::t -> sort_dim t in + else + let rec sort_dim = function + | [] -> failwith "None of d1 or d2 is present" + | h :: _ when Dim.equal h d1 -> -1 + | h :: _ when Dim.equal h d2 -> 1 + | _ :: t -> sort_dim t + in sort_dim dim_list -let get_main_dim = function - | Join (d, _, _, _) -> d - | Single d -> d +let get_main_dim = function Join (d, _, _, _) -> d | Single d -> d let sort_tdim dim_list d1 d2 = - let d1 = get_main_dim d1 - and d2 = get_main_dim d2 in + let d1 = get_main_dim d1 and d2 = get_main_dim d2 in sort_dim dim_list d1 d2 (* Maybe useful someday *) -let reorder_dim dim_list size_list = +let reorder_dim dim_list size_list = let sort_dim (d1, _) (d2, _) = - let d1, d2 = get_main_dim d1, get_main_dim d2 in - sort_dim dim_list d1 d2 in + let d1, d2 = (get_main_dim d1, get_main_dim d2) in + sort_dim dim_list d1 d2 + in List.sort sort_dim size_list let reorder_layout tensor dim_list = - let map_t_dim dim = match List.find - (function Single d | Join (d, _, _, _) -> - Dim.equal dim d) - (t_dims_list tensor) with - | Some t_dim -> t_dim - | None -> failwith @@ Printf.sprintf "Dimension %s could not be found in tensor %s\n" - (Dim.show_id_of_t dim) (show_tid tensor) in + let map_t_dim dim = + match + List.find + (function Single d | Join (d, _, _, _) -> Dim.equal dim d) + (t_dims_list tensor) + with + | Some t_dim -> t_dim + | None -> + failwith + @@ Printf.sprintf "Dimension %s could not be found in tensor %s\n" + (Dim.show_id_of_t dim) (show_tid tensor) + in let strides = strides_from_size (List.map map_t_dim dim_list) tensor.sizes in - let sorted_strides = reorder_dim dim_list strides in - let sorted_sizes = reorder_dim dim_list tensor.sizes in - let dims = Array.to_list (t_dims tensor) - |> List.sort (sort_tdim dim_list) - |> Array.of_list in - {tensor with dims; strides=sorted_strides; sizes = sorted_sizes} + let sorted_strides = reorder_dim dim_list strides in + let sorted_sizes = reorder_dim dim_list tensor.sizes in + let dims = + Array.to_list (t_dims tensor) + |> List.sort (sort_tdim dim_list) + |> Array.of_list + in + { tensor with dims; strides = sorted_strides; sizes = sorted_sizes } let modify_size_stride tensor size_list = let dims = t_dims_list tensor in let sizes = - List.map (fun dim -> - dim, - Option.default Expr.zero @@ List.assoc dim size_list) - dims in + List.map + (fun dim -> (dim, Option.default Expr.zero @@ List.assoc dim size_list)) + dims + in let strides = strides_from_size_list tensor sizes in - {tensor with strides; sizes} + { tensor with strides; sizes } let map_stride tensor f = let strides = List.map (fun (d, cexpr) -> (d, f cexpr)) tensor.strides in - {tensor with strides} + { tensor with strides } let modify_stride tensor dim f = - let strides = List.map (map_fst - (function Single d -> d - | _ -> failwith "Does not support join for the moment")) - tensor.strides - |> List.modify_assoc Dim.equal f dim - |> List.map (map_fst single) in - {tensor with strides} + let strides = + List.map + (map_fst (function + | Single d -> d + | _ -> failwith "Does not support join for the moment")) + tensor.strides + |> List.modify_assoc Dim.equal f dim + |> List.map (map_fst single) + in + { tensor with strides } let modify_accesses accesses dim f = List.modify_assoc Dim.equal_id f dim accesses let gen_access tensor accesses = let dim_stride_expr_pair = - let get_expr d accesses = - List.assoc_eq Dim.equal_id (Dim.id d) accesses in - List.map (function - Single d, stride -> - Option.get @@ get_expr d accesses, stride - | Join (d1, d2, iter_dim, str), stride -> - let main = get_expr d1 accesses - and aux = get_expr d2 accesses - and iter = get_expr iter_dim accesses in - let mul_stride e = Option.map_default - (fun s -> Expr.(I.(const s * e))) e str in - match main, aux, iter with - | Some main, Some aux, None -> - Expr.add (mul_stride main) aux, stride - | None, None, Some iter -> - iter, stride - | Some _, _, Some _ | _, Some _, Some _ -> - failwith "Should access either main and aux or iter dimension but - not both" - | _ -> - failwith "Missing accesses" - ) - tensor.strides in - let accesses_expr = List.fold_left - (fun acc (acc_expr, stride) -> - Expr.(I.(stride * acc_expr + acc)) - ) Expr.zero dim_stride_expr_pair in + let get_expr d accesses = List.assoc_eq Dim.equal_id (Dim.id d) accesses in + List.map + (function + | Single d, stride -> (Option.get @@ get_expr d accesses, stride) + | Join (d1, d2, iter_dim, str), stride -> ( + let main = get_expr d1 accesses + and aux = get_expr d2 accesses + and iter = get_expr iter_dim accesses in + let mul_stride e = + Option.map_default (fun s -> Expr.(I.(const s * e))) e str + in + match (main, aux, iter) with + | Some main, Some aux, None -> + (Expr.add (mul_stride main) aux, stride) + | None, None, Some iter -> (iter, stride) + | Some _, _, Some _ | _, Some _, Some _ -> + failwith + "Should access either main and aux or iter dimension but\n\ + \ not both" + | _ -> failwith "Missing accesses")) + tensor.strides + in + let accesses_expr = + List.fold_left + (fun acc (acc_expr, stride) -> Expr.(I.((stride * acc_expr) + acc))) + Expr.zero dim_stride_expr_pair + in Format.sprintf "%s[%s]" (show_id tensor.id) (Expr.show accesses_expr) let compare_dims_id tensor dim1 dim2 = let rec cmp = function - | (Single d)::_ when Dim.equal_id (Dim.id d) dim1 -> 1 - | (Single d)::_ when Dim.equal_id (Dim.id d) dim2 -> -1 - | Join (d1, d2, _, _)::_ when (Dim.equal_id (Dim.id d1) dim1 - && Dim.equal_id (Dim.id d2) dim2) - || (Dim.equal_id (Dim.id d2) dim1 - && Dim.equal_id (Dim.id d1) dim2)-> 0 - | Join (d1, d2, _, _)::_ when Dim.equal_id (Dim.id d1) dim1 - || Dim.equal_id (Dim.id d2) dim1 -> 1 - | Join (d1, d2, _, _)::_ when Dim.equal_id (Dim.id d1) dim2 - || Dim.equal_id (Dim.id d2) dim2 -> -1 - | _::t -> cmp t - | [] -> failwith "Dim1 and dim2 are not present in tensor dims" in - if Dim.equal_id dim1 dim2 then 0 - else cmp @@ t_dims_list @@ tensor + | Single d :: _ when Dim.equal_id (Dim.id d) dim1 -> 1 + | Single d :: _ when Dim.equal_id (Dim.id d) dim2 -> -1 + | Join (d1, d2, _, _) :: _ + when (Dim.equal_id (Dim.id d1) dim1 && Dim.equal_id (Dim.id d2) dim2) + || (Dim.equal_id (Dim.id d2) dim1 && Dim.equal_id (Dim.id d1) dim2) + -> + 0 + | Join (d1, d2, _, _) :: _ + when Dim.equal_id (Dim.id d1) dim1 || Dim.equal_id (Dim.id d2) dim1 -> + 1 + | Join (d1, d2, _, _) :: _ + when Dim.equal_id (Dim.id d1) dim2 || Dim.equal_id (Dim.id d2) dim2 -> + -1 + | _ :: t -> cmp t + | [] -> failwith "Dim1 and dim2 are not present in tensor dims" + in + if Dim.equal_id dim1 dim2 then 0 else cmp @@ t_dims_list @@ tensor let inner_dim tensor = (* TODO Does it make sense for us to have a zero-dimensional tensor ? *) - t_dims_list tensor - |> List.hd - |> Option.get + t_dims_list tensor |> List.hd |> Option.get diff --git a/ml/lib/tensor.mli b/ml/lib/tensor.mli index e1d1a64d5a1aaa04a615c803b5a901faf1a196a0..cca3d7b0ee13436ed680a567e2d8ece2bbfe0c60 100644 --- a/ml/lib/tensor.mli +++ b/ml/lib/tensor.mli @@ -3,40 +3,58 @@ open Exprs type t [@@deriving eq, show] type id [@@deriving eq, ord, show] -type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t * int option [@@deriving eq, show] + +type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t * int option +[@@deriving eq, show] + type gen = unit -> t -type accesses = (Dim.id * Expr.t) list [@@deriving eq, show] - -val single: Dim.t -> t_dims -val join_dims: Dim.t -> Dim.t -> t_dims -val join_dims_stride: Dim.t -> Dim.t -> int -> t_dims -val make: ?name:string -> ?strides:(Dim.t * Expr.t) list -> ?sizes:(Dim.t * Expr.t) list - -> Dim.t array -> t -val make_join: ?name:string -> ?strides:(t_dims * Expr.t) list -> ?sizes:(t_dims * Expr.t) list - -> t_dims array -> t -val t_dims: t -> t_dims array -val t_dims_list: t -> t_dims list -val dims_list: t -> Dim.t list - -val id: t -> id -val show_tid: t -> string - -val gen_clone: t -> unit -> (t * (unit -> t)) -val strides_from_size_list: t -> (t_dims * Expr.t) list -> (t_dims * Expr.t) list -val strides_in_order: t -> Expr.t list +type accesses = (Dim.id * Expr.t) list [@@deriving eq, show] + +val single : Dim.t -> t_dims +val join_dims : Dim.t -> Dim.t -> t_dims +val join_dims_stride : Dim.t -> Dim.t -> int -> t_dims + +val make : + ?name:string -> + ?strides:(Dim.t * Expr.t) list -> + ?sizes:(Dim.t * Expr.t) list -> + Dim.t array -> + t + +val make_join : + ?name:string -> + ?strides:(t_dims * Expr.t) list -> + ?sizes:(t_dims * Expr.t) list -> + t_dims array -> + t + +val t_dims : t -> t_dims array +val t_dims_list : t -> t_dims list +val dims_list : t -> Dim.t list +val id : t -> id +val show_tid : t -> string +val gen_clone : t -> unit -> t * (unit -> t) + +val strides_from_size_list : + t -> (t_dims * Expr.t) list -> (t_dims * Expr.t) list + +val strides_in_order : t -> Expr.t list + (* val sort_by_stride: t -> t_dims list -> t_dims list *) -val size: t -> Expr.t -val reorder_layout: t -> Dim.t list -> t -val dim_size: t -> Dim.t -> Expr.t option -val stride: t -> Dim.t -> Expr.t option -val acc_does_access: accesses -> Dim.id -> bool -val tens_does_access: t -> Dim.id -> bool -val does_access: t -> accesses -> Dim.id -> bool -val modify_size_stride: t -> (t_dims * Expr.t) list -> t -val map_stride: t -> (Expr.t -> Expr.t) -> t -val modify_stride: t -> Dim.t -> (Expr.t option -> Expr.t option) -> t -val modify_accesses: accesses -> Dim.id -> (Expr.t option -> Expr.t option) -> accesses -val gen_access: t -> accesses -> string - -val compare_dims_id: t -> Dim.id -> Dim.id -> int -val inner_dim: t -> t_dims +val size : t -> Expr.t +val reorder_layout : t -> Dim.t list -> t +val dim_size : t -> Dim.t -> Expr.t option +val stride : t -> Dim.t -> Expr.t option +val acc_does_access : accesses -> Dim.id -> bool +val tens_does_access : t -> Dim.id -> bool +val does_access : t -> accesses -> Dim.id -> bool +val modify_size_stride : t -> (t_dims * Expr.t) list -> t +val map_stride : t -> (Expr.t -> Expr.t) -> t +val modify_stride : t -> Dim.t -> (Expr.t option -> Expr.t option) -> t + +val modify_accesses : + accesses -> Dim.id -> (Expr.t option -> Expr.t option) -> accesses + +val gen_access : t -> accesses -> string +val compare_dims_id : t -> Dim.id -> Dim.id -> int +val inner_dim : t -> t_dims diff --git a/ml/lib/tensor_loops.ml b/ml/lib/tensor_loops.ml index cb358c6c8363035974096321b143cbebe44b4027..eb750d92359d4ef9e2899868e31afffc1dd47134 100644 --- a/ml/lib/tensor_loops.ml +++ b/ml/lib/tensor_loops.ml @@ -1,13 +1,12 @@ module Tensor = Tensor module Ids = Ids -type iter = Tensor_transform.iter = Iter of int [@@deriving show] +type iter = Tensor_transform.iter = Iter of int [@@deriving show] type arg = Tensor_transform.arg = Arg of int [@@deriving show] module Tensor_tile = Tensor_transform.Tensor_tile module Exprs = Exprs module Instruction = Instruction.I module Loop_nest = Loop_nest - module Arch = Arch module Sign = Kernels_sign diff --git a/ml/lib/tensor_transform.ml b/ml/lib/tensor_transform.ml index e3787b3376a4352fa0eeb4f3205238c71af11f61..7fff0c7313c4bd240c57a756488cdb4c1532c962 100644 --- a/ml/lib/tensor_transform.ml +++ b/ml/lib/tensor_transform.ml @@ -2,1024 +2,1645 @@ open Utils open Ids open Exprs module T = Tensor - - open Kernels_sign -type iter = Iter of int [@@deriving show {with_path = false}] -type arg = Arg of int [@@deriving show {with_path = false}] +type iter = Iter of int [@@deriving show { with_path = false }] +type arg = Arg of int [@@deriving show { with_path = false }] -module Tensor_tile(A: Arch.Vec_arch_t) = struct - module Inst = Instruction.I(A) - module LN = Loop_nest.L(Inst) +module Tensor_tile (A : Arch.Vec_arch_t) = struct + module Inst = Instruction.I (A) + module LN = Loop_nest.L (Inst) module Zip = LN.Zipper - let show_arg_iter = [%show: (iter * arg) list ] - - type ext_call_arg = {name: string; - fixed_size : (Dim.t * int) list; - var_dim: Dim.t; - dynamic_sizes : (Dim.t * int) list; - tensors_list : Tensor.t list; - max_range: int} [@@deriving show] - - let pp_ext_call_arg fmt {name;_} = Format.fprintf fmt "<%s>" - name - - type loop_type = U of int * Dim.t [@printer fun fmt (unroll, dim) -> fprintf fmt - "U (%d, %s)" unroll (Dim.show_id_of_t dim)] - | V of Dim.t [@printer fun fmt dim -> fprintf fmt - "V %s" (Dim.show_id_of_t dim)] - | External_call of ext_call_arg - (* Partial_Tile (d, size) does size iteration or less *) - | Tile_partial of int * Dim.t [@printer fun fmt (size, dim) -> fprintf fmt - "Tile_partial (%d, %s)" size (Dim.show_id_of_t dim)] - (* generalized exact tile *) - | Tile_gexact of int * Dim.t [@printer fun fmt (size, dim) -> fprintf fmt - "Tile_gexact (%d, %s)" size (Dim.show_id_of_t dim)] - | Tile_exact of int * Dim.t (* Tile_Exact (n, d) does exactly n - iterations - no multiplication *) - [@printer fun fmt (size, dim) -> fprintf fmt - "Tile_exact (%d, %s)" size (Dim.show_id_of_t dim)] - | ULambda of Dim.t [@printer fun fmt dim -> fprintf fmt - "ULambda %s" (Dim.show_id_of_t dim)] - | TLambda of Dim.t [@printer fun fmt dim -> fprintf fmt - "TLambda %s" (Dim.show_id_of_t dim)] - | Lambda_apply of Dim.t * (iter * arg) list [@printer fun fmt (dim, l) -> - fprintf fmt "Lambda_apply %s %s" (Dim.show_id_of_t dim) (show_arg_iter l)] - | T of int * Dim.t [@printer fun fmt (size, dim) -> fprintf fmt - "T (%d, %s)" size (Dim.show_id_of_t dim)] - | T_par of int * Dim.t [@printer fun fmt (size, dim) -> fprintf fmt - "T_par (%d, %s)" size (Dim.show_id_of_t dim)] - | Pack_tens of T.t [@printer fun fmt tens -> fprintf fmt - "Pack %s" (T.show_tid tens)] - | Pack_trans of T.t * Dim.t list [@printer fun fmt (tens, _) -> fprintf fmt - "Pack/transpose %s" (T.show_tid tens)] - | Hoist_vars of Dim.t list [@printer fun fmt dim_list -> fprintf fmt - "Hoist_vars [%s]" (String.concat ";" @@ List.map Dim.show_id_of_t dim_list)] - | R of Dim.t [@printer fun fmt dim -> fprintf fmt - "R %s" (Dim.show_id_of_t dim)] + let show_arg_iter = [%show: (iter * arg) list] + + type ext_call_arg = { + name : string; + fixed_size : (Dim.t * int) list; + var_dim : Dim.t; + dynamic_sizes : (Dim.t * int) list; + tensors_list : Tensor.t list; + max_range : int; + } + [@@deriving show] + + let pp_ext_call_arg fmt { name; _ } = Format.fprintf fmt "<%s>" name + + type loop_type = + | U of int * Dim.t + [@printer + fun fmt (unroll, dim) -> + fprintf fmt "U (%d, %s)" unroll (Dim.show_id_of_t dim)] + | V of Dim.t + [@printer fun fmt dim -> fprintf fmt "V %s" (Dim.show_id_of_t dim)] + | External_call of ext_call_arg + (* Partial_Tile (d, size) does size iteration or less *) + | Tile_partial of int * Dim.t + [@printer + fun fmt (size, dim) -> + fprintf fmt "Tile_partial (%d, %s)" size (Dim.show_id_of_t dim)] + (* generalized exact tile *) + | Tile_gexact of int * Dim.t + [@printer + fun fmt (size, dim) -> + fprintf fmt "Tile_gexact (%d, %s)" size (Dim.show_id_of_t dim)] + | Tile_exact of + int + * Dim.t + (* Tile_Exact (n, d) does exactly n + iterations - no multiplication *) + [@printer + fun fmt (size, dim) -> + fprintf fmt "Tile_exact (%d, %s)" size (Dim.show_id_of_t dim)] + | ULambda of Dim.t + [@printer + fun fmt dim -> fprintf fmt "ULambda %s" (Dim.show_id_of_t dim)] + | TLambda of Dim.t + [@printer + fun fmt dim -> fprintf fmt "TLambda %s" (Dim.show_id_of_t dim)] + | Lambda_apply of Dim.t * (iter * arg) list + [@printer + fun fmt (dim, l) -> + fprintf fmt "Lambda_apply %s %s" (Dim.show_id_of_t dim) + (show_arg_iter l)] + | T of int * Dim.t + [@printer + fun fmt (size, dim) -> + fprintf fmt "T (%d, %s)" size (Dim.show_id_of_t dim)] + | T_par of int * Dim.t + [@printer + fun fmt (size, dim) -> + fprintf fmt "T_par (%d, %s)" size (Dim.show_id_of_t dim)] + | Fused_T_pars of (int * Dim.t) list + [@printer + fun fmt list_s_d -> + fprintf fmt "Fused_T_pars [%s]" + (String.concat "; " + (List.map + (fun (size, dim) -> + Printf.sprintf "(%d, %s)" size (Dim.show_id_of_t dim)) + list_s_d))] + | Pack_tens of T.t + [@printer fun fmt tens -> fprintf fmt "Pack %s" (T.show_tid tens)] + | Pack_trans of T.t * Dim.t list + [@printer + fun fmt (tens, _) -> fprintf fmt "Pack/transpose %s" (T.show_tid tens)] + | Hoist_vars of Dim.t list + [@printer + fun fmt dim_list -> + fprintf fmt "Hoist_vars [%s]" + (String.concat ";" @@ List.map Dim.show_id_of_t dim_list)] + | R of Dim.t + [@printer fun fmt dim -> fprintf fmt "R %s" (Dim.show_id_of_t dim)] [@@deriving show { with_path = false }] - - - type tile_scheme = loop_type list let tile_to_string = [%show: loop_type list] module Payload = struct - (* Module holding all necessary tensor information - for now just which clone to use *) module Te = struct - type t = {tensor_clone: T.gen; tensor_list: (T.t * Expr.t) list;} + type t = { tensor_clone : T.gen; tensor_list : (T.t * Expr.t) list } let get_new_tens_size new_tens size_list = let new_tensor = T.modify_size_stride new_tens size_list in let new_size = T.size new_tensor in - new_tensor, new_size + (new_tensor, new_size) let make tensor size_list = let tens, tensor_clone = T.gen_clone tensor () in let tens, size = get_new_tens_size tens size_list in - {tensor_clone; tensor_list = [tens, size]} + { tensor_clone; tensor_list = [ (tens, size) ] } - let pack ({tensor_clone;tensor_list;_} as state) size_list () = + let pack ({ tensor_clone; tensor_list; _ } as state) size_list () = let new_tens = tensor_clone () in let new_tensor, new_size = get_new_tens_size new_tens size_list in - {state with tensor_list = (new_tensor, new_size)::tensor_list} + { state with tensor_list = (new_tensor, new_size) :: tensor_list } (* Fails if clone was not called before *) - let current_tensor {tensor_list;_} = + let current_tensor { tensor_list; _ } = List.hd tensor_list |> Option.get |> fst let page_size = 1024 - let decl_alloc_free_tens {tensor_list;_} = + + let decl_alloc_free_tens { tensor_list; _ } = tensor_list |> List.map (fun (t, s) -> - Printf.sprintf "%s * %s = (%s *)ALLOC(%d, sizeof(%s) * %s);" - Inst.A.base_type_name (T.show_id (T.id t)) - Inst.A.base_type_name (*(Inst.A.vec_size * 4)*) page_size - Inst.A.base_type_name (Expr.show s), - Printf.sprintf "FREE(%s);"(T.show_id (T.id t)) - ) - |> List.split - |> fun (allocs, frees) -> (String.concat "\n" allocs),(String.concat "\n" frees) + ( Printf.sprintf "%s * %s = (%s *)ALLOC(%d, sizeof(%s) * %s);" + Inst.A.base_type_name + (T.show_id (T.id t)) + Inst.A.base_type_name (*(Inst.A.vec_size * 4)*) page_size + Inst.A.base_type_name (Expr.show s), + Printf.sprintf "FREE(%s);" (T.show_id (T.id t)) )) + |> List.split + |> fun (allocs, frees) -> + (String.concat "\n" allocs, String.concat "\n" frees) end (* Building Map modules needed to hold dimension and tensor infos *) - module TidMap = Make_map(struct type t = T.id let compare = T.compare_id end) + module TidMap = Make_map (struct + type t = T.id - module D = Dim_info.DI(Inst)(LN) + let compare = T.compare_id + end) + + module D = Dim_info.DI (Inst) (LN) module DimMap = D.DimMap - module PI = Packing.PI(Inst)(LN) + module PI = Packing.PI (Inst) (LN) - type lnest = Concrete of Index.t list * Zip.t - | Lambda of Dim.t * (int -> lnest) - | WaitPack of (Tensor.t * ( D.t DimMap.t -> T.t -> lnest)) - | WaitTile of Dim.t * (Expr.t -> lnest) + type lnest = + | Concrete of Index.t list * Zip.t + | Lambda of Dim.t * (int -> lnest) + | WaitPack of (Tensor.t * (D.t DimMap.t -> T.t -> lnest)) + | WaitTile of Dim.t * (Expr.t -> lnest) let rec compose_lnest f = function | Concrete (indexes, ln) -> Concrete (indexes, f ln) - | WaitPack (tensor, thunk) -> - WaitPack (tensor, fun t di -> - compose_lnest f (thunk t di)) - | Lambda (dim, thunk) -> - Lambda (dim, fun i -> compose_lnest f (thunk i)) + | WaitPack (tensor, thunk) -> + WaitPack (tensor, fun t di -> compose_lnest f (thunk t di)) + | Lambda (dim, thunk) -> Lambda (dim, fun i -> compose_lnest f (thunk i)) | WaitTile (dim, thunk) -> - WaitTile (dim, fun e -> compose_lnest f (thunk e)) - + WaitTile (dim, fun e -> compose_lnest f (thunk e)) let introduce_wait_tile dim f = function - | Concrete (indexes, ln) -> WaitTile (dim, fun e -> Concrete (indexes, f e ln)) + | Concrete (indexes, ln) -> + WaitTile (dim, fun e -> Concrete (indexes, f e ln)) | WaitPack (tensor, thunk) -> - WaitTile (dim, - fun i -> WaitPack (tensor, fun t di -> compose_lnest (f i) (thunk t di))) - | WaitTile (dim', thunk) -> WaitTile (dim, fun e -> WaitTile (dim', - compose_lnest (f e) - % thunk)) + WaitTile + ( dim, + fun i -> + WaitPack (tensor, fun t di -> compose_lnest (f i) (thunk t di)) + ) + | WaitTile (dim', thunk) -> + WaitTile (dim, fun e -> WaitTile (dim', compose_lnest (f e) % thunk)) | Lambda (dim', thunk) -> - WaitTile (dim, fun e -> Lambda (dim', compose_lnest (f e) % thunk)) + WaitTile (dim, fun e -> Lambda (dim', compose_lnest (f e) % thunk)) let rec eliminate_wait_tile dim tile_size = function | Concrete _ as c -> c - | WaitTile (dim', thunk) when Dim.equal dim dim' -> - thunk tile_size + | WaitTile (dim', thunk) when Dim.equal dim dim' -> thunk tile_size | WaitTile (dim', thunk) -> - WaitTile (dim', - fun e -> eliminate_wait_tile dim tile_size @@ thunk e) + WaitTile (dim', fun e -> eliminate_wait_tile dim tile_size @@ thunk e) | Lambda (dim', thunk) -> - Lambda (dim', - fun i -> eliminate_wait_tile dim tile_size @@ thunk i) + Lambda (dim', fun i -> eliminate_wait_tile dim tile_size @@ thunk i) | WaitPack (tensor, thunk) -> - WaitPack (tensor, - fun di t -> eliminate_wait_tile dim tile_size @@ thunk di t ) + WaitPack + (tensor, fun di t -> eliminate_wait_tile dim tile_size @@ thunk di t) let introduce_lambda dim f = function - | Concrete (indexes, ln) -> Lambda (dim, fun i -> Concrete (indexes, f i ln)) + | Concrete (indexes, ln) -> + Lambda (dim, fun i -> Concrete (indexes, f i ln)) | WaitPack (tensor, thunk) -> - Lambda (dim, - fun i -> WaitPack (tensor, fun t di -> compose_lnest (f i) (thunk t di))) - | WaitTile (dim', thunk) -> Lambda (dim, fun i -> WaitTile (dim', - compose_lnest (f i) - % thunk)) + Lambda + ( dim, + fun i -> + WaitPack (tensor, fun t di -> compose_lnest (f i) (thunk t di)) + ) + | WaitTile (dim', thunk) -> + Lambda (dim, fun i -> WaitTile (dim', compose_lnest (f i) % thunk)) | Lambda (dim', thunk) -> - Lambda (dim, fun i -> Lambda (dim', compose_lnest (f i) % thunk)) + Lambda (dim, fun i -> Lambda (dim', compose_lnest (f i) % thunk)) - let rec compose_pair f = function + let rec compose_pair f = function | Concrete (indexes, ln) -> - let idx', ln' = f ln in - Concrete (indexes @ idx', ln') - | WaitPack (tensor, thunk) -> - WaitPack (tensor, fun di t -> - compose_pair f (thunk di t)) + let idx', ln' = f ln in + Concrete (indexes @ idx', ln') + | WaitPack (tensor, thunk) -> + WaitPack (tensor, fun di t -> compose_pair f (thunk di t)) | WaitTile (dim, thunk) -> - WaitTile (dim, fun i -> compose_pair f (thunk i)) - | Lambda (dim, thunk) -> - Lambda (dim, fun i -> compose_pair f (thunk i)) - + WaitTile (dim, fun i -> compose_pair f (thunk i)) + | Lambda (dim, thunk) -> Lambda (dim, fun i -> compose_pair f (thunk i)) let introduce_wait_pack tensor f = function | Concrete (indexes, ln) -> - WaitPack (tensor, fun di t -> - let indexes', ln = f di t ln in - Concrete (indexes @ indexes', ln) - ) + WaitPack + ( tensor, + fun di t -> + let indexes', ln = f di t ln in + Concrete (indexes @ indexes', ln) ) | WaitPack (tensor', thunk) -> - WaitPack (tensor, - fun di t -> - WaitPack (tensor', fun di' t' -> - compose_pair (f di t) (thunk di' t'))) + WaitPack + ( tensor, + fun di t -> + WaitPack + (tensor', fun di' t' -> compose_pair (f di t) (thunk di' t')) + ) | Lambda (dim, thunk) -> - WaitPack (tensor, fun di t -> - Lambda(dim, compose_pair (f di t) % thunk) - ) + WaitPack + (tensor, fun di t -> Lambda (dim, compose_pair (f di t) % thunk)) | WaitTile (dim, thunk) -> - WaitPack (tensor, fun di t -> - WaitTile(dim, compose_pair (f di t) % thunk) - ) - -(* - let map_ln f (ln, pi, k) = (f ln, pi, k) - let map_pi f (ln, pi, k) = ( ln, f pi, k) - let _map_k f (ln, pi, k) = ( ln, f pi, k) -*) - - type t = {dmap: D.t DimMap.t; loop_nest: lnest; - dim_list : Dim.t list; - is_top: bool; - kernel : Inst.t list; - tmap: Te.t TidMap.t;} + WaitPack + (tensor, fun di t -> WaitTile (dim, compose_pair (f di t) % thunk)) + + type t = { + dmap : D.t DimMap.t; + loop_nest : lnest; + dim_list : Dim.t list; + is_top : bool; + kernel : Inst.t list; + tmap : Te.t TidMap.t; + vectorize_dim : Dim.t option; + } let init dim_list inst = - let statement = Zip.Top, LN.Statement inst in - {dim_list; dmap = DimMap.empty; kernel = inst; tmap = TidMap.empty; is_top = true; - loop_nest = Concrete ([], statement)} - - let finalize_loop_nest {dmap;loop_nest;_} = + let statement = (Zip.Top, LN.Statement inst) in + { + dim_list; + dmap = DimMap.empty; + kernel = inst; + tmap = TidMap.empty; + is_top = true; + loop_nest = Concrete ([], statement); + vectorize_dim = None; + } + + let finalize_loop_nest { dmap; loop_nest; _ } = let rec loop = function - Concrete (indexes, ln) -> indexes, ln + | Concrete (indexes, ln) -> (indexes, ln) | WaitTile (dim, _) -> - failwith @@ Printf.sprintf "WaitTile %s was not applied" (Dim.show_id_of_t dim) + failwith + @@ Printf.sprintf "WaitTile %s was not applied" + (Dim.show_id_of_t dim) | Lambda (dim, _) -> - failwith @@ Printf.sprintf "lambda %s was not applied" (Dim.show_id_of_t dim) - | WaitPack (tensor, thunk) -> - loop (thunk dmap tensor) in + failwith + @@ Printf.sprintf "lambda %s was not applied" (Dim.show_id_of_t dim) + | WaitPack (tensor, thunk) -> loop (thunk dmap tensor) + in loop loop_nest - let show_current_state dim_list {dmap; _} = - dim_list + let show_current_state dim_list { dmap; _ } = + dim_list |> List.map (fun dim -> - DimMap.find_opt dim dmap - |> function Some dinfo -> - let incr = D.incr dinfo - and _next_index = D.next_index dinfo in - Printf.sprintf "%s = %s" - (Dim.show_id_of_t dim) ( Expr.show incr) - | None -> - Printf.sprintf "%s = 1" - ((Dim.show_id_of_t dim)) - ) + DimMap.find_opt dim dmap |> function + | Some dinfo -> + let incr = D.incr dinfo and _next_index = D.next_index dinfo in + Printf.sprintf "%s = %s" (Dim.show_id_of_t dim) + (Expr.show incr) + | None -> Printf.sprintf "%s = 1" (Dim.show_id_of_t dim)) |> String.concat ", " - let apply_dmap f payload = f payload.dmap - |> fun dmap -> {payload with dmap} - + let apply_dmap f payload = + f payload.dmap |> fun dmap -> { payload with dmap } let find_dim dim payload = DimMap.find dim payload.dmap - let modify_map_dim dim f payload = let dmap, res = DimMap.modify_map dim f payload.dmap in - {payload with dmap}, res + ({ payload with dmap }, res) (* Is p true for all (dim, dp) in payload *) let _for_all_dims p payload = DimMap.for_all p payload.dmap - let apply_ln f ({loop_nest;_} as payload) = + let apply_ln f ({ loop_nest; _ } as payload) = let loop_nest = f loop_nest in - {payload with loop_nest} - + { payload with loop_nest } let modify_block f = - apply_ln - (compose_lnest (LN.(function z, Statement stmts -> z, Statement (f stmts) - | _ -> failwith "is not a block" - ))) - + apply_ln + (compose_lnest + LN.( + function + | z, Statement stmts -> (z, Statement (f stmts)) + | _ -> failwith "is not a block")) let external_call payload call_name var_dim max_range tensors_list dim_sizes - dynamic_sizes = + dynamic_sizes = let payload, cur_bound_var = apply_dmap - (fun p -> List.fold_left + (fun p -> + List.fold_left (fun p (d, s) -> - DimMap.modify_opt d (function - Some _ -> failwith - "External_call is only valid at innermost position" - | None -> Some ( - D.default d - |> D.map_incr (fun _ -> Expr.const s ) - |> D.set_div_constraint (Dim_info.Div) - ) - ) p) - p (dim_sizes @ dynamic_sizes)) + DimMap.modify_opt d + (function + | Some _ -> + failwith + "External_call is only valid at innermost position" + | None -> + Some + (D.default d + |> D.map_incr (fun _ -> Expr.const s) + |> D.set_div_constraint Dim_info.Div)) + p) + p + (dim_sizes @ dynamic_sizes)) payload - |> modify_map_dim var_dim ( - function None -> - let dp = D.default var_dim - |> D.map_incr (fun _ -> Expr.const max_range) - |> D.set_div_constraint (Dim_info.Flexible 1) in - let cur_bound_var = D.cur_bound_var dp in - Some dp , - cur_bound_var - | Some _ -> failwith - "External_call is only valid at innermost position" - ) in + |> modify_map_dim var_dim (function + | None -> + let dp = + D.default var_dim + |> D.map_incr (fun _ -> Expr.const max_range) + |> D.set_div_constraint (Dim_info.Flexible 1) + in + let cur_bound_var = D.cur_bound_var dp in + (Some dp, cur_bound_var) + | Some _ -> + failwith "External_call is only valid at innermost position") + in let dyn_sizes = List.map (fun (_, c) -> Expr.const c) dynamic_sizes in apply_ln - (introduce_wait_tile var_dim - (fun _ (z, _) -> - let tensors_info = List.map (fun t ->t, - List.map - (fun d -> Dim.id d, Expr.index @@ Index.from_dim d) - @@ Tensor.dims_list t - ) tensors_list in - let var_expr = Expr.Var cur_bound_var in - z, LN.Statement [Inst.External_call {name=call_name; - tensors_info ; - var_exprs = var_expr :: dyn_sizes}] - ) - ) payload + (introduce_wait_tile var_dim (fun _ (z, _) -> + let tensors_info = + List.map + (fun t -> + ( t, + List.map (fun d -> + (Dim.id d, Expr.index @@ Index.from_dim d)) + @@ Tensor.dims_list t )) + tensors_list + in + let var_expr = Expr.Var cur_bound_var in + ( z, + LN.Statement + [ + Inst.External_call + { + name = call_name; + tensors_info; + var_exprs = var_expr :: dyn_sizes; + }; + ] ))) + payload let vectorize_dim payload dim = payload |> apply_dmap - (DimMap.modify_opt dim - (function - Some _ -> failwith "Dim has already been defined" - | None -> Some (D.vectorize @@ D.default dim)) - ) + (DimMap.modify_opt dim (function + | Some _ -> failwith "Dim has already been defined" + | None -> Some (D.vectorize @@ D.default dim))) - let vectorize ({is_top;_} as payload) dim = + let vectorize ({ is_top; _ } as payload) dim = assert is_top; let vect kernel = List.map (Inst.vectorize_on_dim @@ Dim.id dim) kernel in let payload = vectorize_dim payload dim in let payload = modify_block vect payload in payload - let unroll ({loop_nest; dim_list; _ } as payload) dim unroll_size = - let unroll_incr = match unroll_size with - Some unroll_size -> Expr.const unroll_size - | None -> Expr.PlaceHolder (Printf.sprintf "ph_%s" (Dim.show_id_of_t dim)) in + let unroll ({ loop_nest; dim_list; _ } as payload) dim unroll_size = + let unroll_incr = + match unroll_size with + | Some unroll_size -> Expr.const unroll_size + | None -> + Expr.PlaceHolder (Printf.sprintf "ph_%s" (Dim.show_id_of_t dim)) + in let payload, (cur_index, increment, next_incr) = - modify_map_dim dim (function + modify_map_dim dim + (function | Some elem -> - let incr = D.incr elem in - let next_incr = Expr.(I.((unroll_incr * incr ))) in - Some (D.map_incr - (fun _ -> next_incr) elem |> D.set_div_constraint Div), - (D.current_index elem, incr, next_incr ) + let incr = D.incr elem in + let next_incr = Expr.(I.(unroll_incr * incr)) in + ( Some + (D.map_incr (fun _ -> next_incr) elem + |> D.set_div_constraint Div), + (D.current_index elem, incr, next_incr) ) | None -> - Some (D.map_incr (Fun.const (unroll_incr)) (D.default dim) - |> D.set_div_constraint Div), - (Index.from_dim dim, Expr.one, unroll_incr) - ) payload in + ( Some + (D.map_incr (Fun.const unroll_incr) (D.default dim) + |> D.set_div_constraint Div), + (Index.from_dim dim, Expr.one, unroll_incr) )) + payload + in let gen_unroll_loop usize loop = let start = Expr.index cur_index in let state_comment = show_current_state dim_list payload in let unr_comment = - Printf.sprintf "U (%s, %d), (%s / %s)" - (Dim.show_id_of_t dim) - usize (Expr.show next_incr) - (Expr.show increment) in - let comment = [state_comment; unr_comment] in - Zip.new_unroll ~comment cur_index start usize increment loop in + Printf.sprintf "U (%s, %d), (%s / %s)" (Dim.show_id_of_t dim) usize + (Expr.show next_incr) (Expr.show increment) + in + let comment = [ state_comment; unr_comment ] in + Zip.new_unroll ~comment cur_index start usize increment loop + in let gen_kernel usize kernel = - let glob_comm = Printf.sprintf "U (%s, %d), (%s / %s)" (Dim.show_id_of_t dim) - usize (Expr.show next_incr) - (Expr.show increment) in - LN.unroll_insts_dim [glob_comm] (Dim.id dim) (Expr.index cur_index) - usize increment kernel in - let apply_unroll unroll_size = let open LN in - function Zip.Top, Statement insts -> - Zip.Top, Statement (gen_kernel unroll_size insts) - | zln -> - gen_unroll_loop unroll_size zln in - let unroll_ln = match unroll_size with - Some unroll_size -> - compose_lnest @@ apply_unroll unroll_size - | None -> - introduce_lambda dim apply_unroll in - {payload with loop_nest = unroll_ln loop_nest} + let glob_comm = + Printf.sprintf "U (%s, %d), (%s / %s)" (Dim.show_id_of_t dim) usize + (Expr.show next_incr) (Expr.show increment) + in + LN.unroll_insts_dim [ glob_comm ] (Dim.id dim) (Expr.index cur_index) + usize increment kernel + in + let apply_unroll unroll_size = + let open LN in + function + | Zip.Top, Statement insts -> + (Zip.Top, Statement (gen_kernel unroll_size insts)) + | zln -> gen_unroll_loop unroll_size zln + in + let unroll_ln = + match unroll_size with + | Some unroll_size -> compose_lnest @@ apply_unroll unroll_size + | None -> introduce_lambda dim apply_unroll + in + { payload with loop_nest = unroll_ln loop_nest } type ('a, 'b, 'c) ternary = One of 'a | Two of 'b | Three of 'c - let lambda_apply ({loop_nest; dmap;_} as payload) dim iter_args_list = - let cur_index, next_index, incr = + let lambda_apply ({ loop_nest; dmap; _ } as payload) dim iter_args_list = + let cur_index, next_index, incr = match DimMap.find dim dmap with - Some dinfo -> - let cur_index = D.current_index dinfo in - let next_index = D.next_index dinfo in - let incr = D.incr dinfo in - cur_index, next_index, incr + | Some dinfo -> + let cur_index = D.current_index dinfo in + let next_index = D.next_index dinfo in + let incr = D.incr dinfo in + (cur_index, next_index, incr) | None -> - let dinfo = D.default dim in - let cur_index = D.current_index dinfo in - let next_index = D.next_index dinfo in - cur_index, next_index, Expr.one in - let placeholder_name = "ph_" ^ (Dim.show_id_of_t dim) in - let rec apply = function - Concrete _ -> failwith (Printf.sprintf "Dim %s wasn't closured" (Dim.show_id_of_t dim)) + let dinfo = D.default dim in + let cur_index = D.current_index dinfo in + let next_index = D.next_index dinfo in + (cur_index, next_index, Expr.one) + in + let placeholder_name = "ph_" ^ Dim.show_id_of_t dim in + let rec apply = function + | Concrete _ -> + failwith + (Printf.sprintf "Dim %s wasn't closured" (Dim.show_id_of_t dim)) | Lambda (d, g) when Dim.equal d dim -> - let rec handle_ln_list l = - (* Collect results of applying lambda. Each of them should be of - * the same variant *) - List.fold_right (fun ln llist -> - match llist, ln with - | None, Concrete (indexes, ln) -> Some (One [indexes, ln]) - | None, Lambda (d, g) -> Some (Two (d, [g])) - | None, WaitPack (t, h) -> Some (Three (t, [h])) - | Some (One l), Concrete (idx, ln) -> Some (One ((idx, ln) :: l)) - | Some (Two (d, l)), Lambda (d', f) when Dim.equal d d' -> - Some (Two (d, f :: l)) - | Some (Three (t, l)), WaitPack (t', f) when T.equal t t' -> - Some (Three (t, f :: l)) - | _ -> failwith "Incoherent state") - l None - |> function None -> failwith "Empty list" - | Some (One(llist)) -> - let idxs, ln = Option.get @@ List.reduce - (fun (idx, ln) (idx', g) -> idx @ idx', - Zip.embed_after_at_top (Zip.to_tree ln) g) - llist in - Concrete (idxs, ln) - | Some (Two (d, llist)) -> - Lambda(d, fun i -> handle_ln_list @@ List.map (fun f -> f i) llist) - | Some (Three (t, llist)) -> - WaitPack(t, fun pt di -> handle_ln_list @@ List.map (fun f -> f pt di) llist) - in - List.fold_left (fun (expr_start, l) (Iter i, Arg a) -> - let uincr = Expr.instanciate_placeholder placeholder_name a incr in - let halt = Expr.(I.(expr_start + const i * incr)) in - let ln_a = compose_lnest - ( Zip.map_const_expr - (Expr.simplify % Expr.instanciate_placeholder placeholder_name a) + let rec handle_ln_list l = + (* Collect results of applying lambda. Each of them should be of + * the same variant *) + List.fold_right + (fun ln llist -> + match (llist, ln) with + | None, Concrete (indexes, ln) -> Some (One [ (indexes, ln) ]) + | None, Lambda (d, g) -> Some (Two (d, [ g ])) + | None, WaitPack (t, h) -> Some (Three (t, [ h ])) + | Some (One l), Concrete (idx, ln) -> + Some (One ((idx, ln) :: l)) + | Some (Two (d, l)), Lambda (d', f) when Dim.equal d d' -> + Some (Two (d, f :: l)) + | Some (Three (t, l)), WaitPack (t', f) when T.equal t t' -> + Some (Three (t, f :: l)) + | _ -> failwith "Incoherent state") + l None + |> function + | None -> failwith "Empty list" + | Some (One llist) -> + let idxs, ln = + Option.get + @@ List.reduce + (fun (idx, ln) (idx', g) -> + ( idx @ idx', + Zip.embed_after_at_top (Zip.to_tree ln) g )) + llist + in + Concrete (idxs, ln) + | Some (Two (d, llist)) -> + Lambda + (d, fun i -> handle_ln_list @@ List.map (fun f -> f i) llist) + | Some (Three (t, llist)) -> + WaitPack + ( t, + fun pt di -> + handle_ln_list @@ List.map (fun f -> f pt di) llist ) + in + List.fold_left + (fun (expr_start, l) (Iter i, Arg a) -> + let uincr = + Expr.instanciate_placeholder placeholder_name a incr + in + let halt = Expr.(I.(expr_start + (const i * incr))) in + let ln_a = + compose_lnest + (Zip.map_const_expr + (Expr.simplify + % Expr.instanciate_placeholder placeholder_name a) % Zip.new_seq cur_index expr_start halt uincr) - (g a) in - Expr.(I.(expr_start + const i * uincr)), ln_a::l - ) (Expr.index next_index, []) iter_args_list - |> snd - |> handle_ln_list - | Lambda (d, g) -> - Lambda (d, fun i -> apply (g i)) - | WaitTile (d, g) -> - WaitTile (d, fun i -> apply (g i)) - | WaitPack (t, thunk) -> - WaitPack (t, fun t di -> apply (thunk t di)) in - let body_size = List.fold_left (fun s (Iter i, Arg a) -> s + i * a) - 0 iter_args_list in - let new_incr = Expr.(I.(const body_size * instanciate_placeholder placeholder_name 1 incr)) in - let dmap = DimMap.modify dim (D.bump_index % D.map_incr (Fun.const new_incr)) dmap in - {payload with dmap; loop_nest = apply loop_nest} - - let hoist_var ({loop_nest;_} as payload) dim_list = - let loop_nest = compose_lnest ( - function (z, LN.Statement kernel) -> - let accesses, loads, kernel = - Inst.fact_loads_on_dims (List.map Dim.id dim_list) [] kernel in - let _, stores, kernel = - Inst.fact_stores_on_dims (List.map Dim.id dim_list) accesses kernel in - let loads = LN.Statement loads - and stores = LN.Statement stores in - (z, LN.Statement kernel) |> Zip.embed_before_at_top loads |> Zip.embed_after_at_top stores - | _ -> failwith "Ambiguous kernel" - ) loop_nest in - {payload with loop_nest} + (g a) + in + (Expr.(I.(expr_start + (const i * uincr))), ln_a :: l)) + (Expr.index next_index, []) + iter_args_list + |> snd |> handle_ln_list + | Lambda (d, g) -> Lambda (d, fun i -> apply (g i)) + | WaitTile (d, g) -> WaitTile (d, fun i -> apply (g i)) + | WaitPack (t, thunk) -> WaitPack (t, fun t di -> apply (thunk t di)) + in + let body_size = + List.fold_left (fun s (Iter i, Arg a) -> s + (i * a)) 0 iter_args_list + in + let new_incr = + Expr.( + I.(const body_size * instanciate_placeholder placeholder_name 1 incr)) + in + let dmap = + DimMap.modify dim (D.bump_index % D.map_incr (Fun.const new_incr)) dmap + in + { payload with dmap; loop_nest = apply loop_nest } + + let hoist_var ({ loop_nest; _ } as payload) dim_list = + let loop_nest = + compose_lnest + (function + | z, LN.Statement kernel -> + let accesses, loads, kernel = + Inst.fact_loads_on_dims (List.map Dim.id dim_list) [] kernel + in + let _, stores, kernel = + Inst.fact_stores_on_dims (List.map Dim.id dim_list) accesses + kernel + in + let loads = LN.Statement loads + and stores = LN.Statement stores in + (z, LN.Statement kernel) + |> Zip.embed_before_at_top loads + |> Zip.embed_after_at_top stores + | _ -> failwith "Ambiguous kernel") + loop_nest + in + { payload with loop_nest } (* This function is "generic" toward both inner and outer level * both its bound and its increment are flexible *) - let tile_generalized_exact_partial dim_list payload to_outer to_inner - dim tile_size = - let payload, (cur_index, next_index, cur_bound_name, next_bound_name, incr, next_incr) = + let tile_generalized_exact_partial dim_list payload to_outer to_inner dim + tile_size = + let ( payload, + ( cur_index, + next_index, + cur_bound_name, + next_bound_name, + incr, + next_incr ) ) = modify_map_dim dim (function | Some dpayload -> - let div_const = D.div_constraint dpayload in - let cur_index = D.current_index dpayload in - let next_index = D.next_index dpayload in - let incr = D.incr dpayload in - let open Dim_info in - let new_div_constraint = match div_const with - | Div -> - (if (not to_inner) then - match incr with - One -> () - | Const c -> - if (tile_size mod c <> 0) then failwith - @@ Printf.sprintf "T_partial was called but dim %s is not flexible" - @@ Dim.show_id_of_t dim - | _ -> - Printf.printf "Incr %s : %s\n" (Dim.show_id_of_t dim) @@ [%show: Expr.t] incr; - failwith - @@ Printf.sprintf "T_partial was called but dim %s is not flexible" - @@ Dim.show_id_of_t dim - ); - Div - | Flexible factor -> - if to_inner then assert (tile_size mod factor = 0); - if to_outer then Flexible factor - else Div in - let next_incr = Expr.(Const tile_size) in - let cur_bound_name = D.cur_bound_var dpayload in - let next_bound_name = D.next_bound_var dpayload in - Some (D.( - bump_index - % map_incr (fun _ -> next_incr) - % set_div_constraint new_div_constraint - ) - dpayload - ), - (cur_index, next_index, cur_bound_name, next_bound_name, incr, next_incr) + let div_const = D.div_constraint dpayload in + let cur_index = D.current_index dpayload in + let next_index = D.next_index dpayload in + let incr = D.incr dpayload in + let open Dim_info in + let new_div_constraint = + match div_const with + | Div -> + (if not to_inner then + match incr with + | One -> () + | Const c -> + if tile_size mod c <> 0 then + failwith + @@ Printf.sprintf + "T_partial was called but dim %s is not \ + flexible" + @@ Dim.show_id_of_t dim + | _ -> + Printf.printf "Incr %s : %s\n" (Dim.show_id_of_t dim) + @@ [%show: Expr.t] incr; + failwith + @@ Printf.sprintf + "T_partial was called but dim %s is not \ + flexible" + @@ Dim.show_id_of_t dim); + Div + | Flexible factor -> + if to_inner then assert (tile_size mod factor = 0); + if to_outer then Flexible factor else Div + in + let next_incr = Expr.(Const tile_size) in + let cur_bound_name = D.cur_bound_var dpayload in + let next_bound_name = D.next_bound_var dpayload in + ( Some + (D.( + bump_index + % map_incr (fun _ -> next_incr) + % set_div_constraint new_div_constraint) + dpayload), + ( cur_index, + next_index, + cur_bound_name, + next_bound_name, + incr, + next_incr ) ) | None -> - let () = if not to_outer || to_inner then failwith @@ - "Tile Exact cannot be the first specifier on dimension " ^ - Dim.show_id_of_t dim in - let dpayload = D.default dim in - let cur_index = D.current_index dpayload in - let next_index = D.next_index dpayload in - let cur_bound_name = D.cur_bound_var dpayload in - let next_bound_name = D.next_bound_var dpayload in - let div_const = Dim_info.Flexible 1 in - let incr = D.incr dpayload in - let next_incr = Expr.(I.(const tile_size * incr)) in - Some (D.( - bump_index - % map_incr (fun _ -> next_incr) - % set_div_constraint div_const - ) - dpayload - ), - (cur_index, next_index, cur_bound_name, next_bound_name, incr, next_incr) - ) payload in + let () = + if (not to_outer) || to_inner then + failwith + @@ "Tile Exact cannot be the first specifier on dimension " + ^ Dim.show_id_of_t dim + in + let dpayload = D.default dim in + let cur_index = D.current_index dpayload in + let next_index = D.next_index dpayload in + let cur_bound_name = D.cur_bound_var dpayload in + let next_bound_name = D.next_bound_var dpayload in + let div_const = Dim_info.Flexible 1 in + let incr = D.incr dpayload in + let next_incr = Expr.(I.(const tile_size * incr)) in + ( Some + (D.( + bump_index + % map_incr (fun _ -> next_incr) + % set_div_constraint div_const) + dpayload), + ( cur_index, + next_index, + cur_bound_name, + next_bound_name, + incr, + next_incr ) )) + payload + in let gen_loop tile_expr = let start = Expr.index next_index (*and halt = Expr.(Infix.(min (next_index --+ incr * (const tile_size)) * (size (Dim.size_id dim)))) in*) - and halt = let open Expr in let open I in + and halt = + let open Expr in + let open I in if to_outer then index next_index + Var next_bound_name - else index next_index + const tile_size in + else index next_index + const tile_size + in let glob_state = show_current_state dim_list payload in - let loop_comment = match to_outer, to_inner with - | true, false -> - Printf.sprintf "T_partial (%s, %d) (%s / %s)" - (Dim.show_id_of_t dim) tile_size - (Expr.show next_incr) (Expr.show incr) - | true, true -> - Printf.sprintf "T_gexact (%s, %d) (%s / %s)" - (Dim.show_id_of_t dim) tile_size - (Expr.show next_incr) (Expr.show incr) - | false, true -> Printf.sprintf "T_exact (%s, %d) (%s / %s)" - (Dim.show_id_of_t dim) tile_size - (Expr.show next_incr) (Expr.show incr) - | false, false -> assert false in - let comment = [glob_state; loop_comment] in - let vars_decl = [cur_bound_name, - Expr.(I.(Min(incr, tile_expr - (index cur_index - - index next_index))))] in - if to_inner then Zip.new_seq ~comment ~vars_decl cur_index start halt incr - else Zip.new_seq ~comment cur_index start halt incr in - let new_loop = - (if to_outer then - introduce_wait_tile dim gen_loop - else compose_lnest (gen_loop (Const tile_size))) + let loop_comment = + match (to_outer, to_inner) with + | true, false -> + Printf.sprintf "T_partial (%s, %d) (%s / %s)" + (Dim.show_id_of_t dim) tile_size (Expr.show next_incr) + (Expr.show incr) + | true, true -> + Printf.sprintf "T_gexact (%s, %d) (%s / %s)" + (Dim.show_id_of_t dim) tile_size (Expr.show next_incr) + (Expr.show incr) + | false, true -> + Printf.sprintf "T_exact (%s, %d) (%s / %s)" (Dim.show_id_of_t dim) + tile_size (Expr.show next_incr) (Expr.show incr) + | false, false -> assert false + in + let comment = [ glob_state; loop_comment ] in + let vars_decl = + [ + ( cur_bound_name, + Expr.( + I.(Min (incr, tile_expr - (index cur_index - index next_index)))) + ); + ] + in + if to_inner then + Zip.new_seq ~comment ~vars_decl cur_index start halt incr + else Zip.new_seq ~comment cur_index start halt incr + in + let new_loop = + (if to_outer then introduce_wait_tile dim gen_loop + else compose_lnest (gen_loop (Const tile_size))) % - (if to_inner then - eliminate_wait_tile dim (Var cur_bound_name) - else - eliminate_wait_tile dim (Const tile_size)) + if to_inner then eliminate_wait_tile dim (Var cur_bound_name) + else eliminate_wait_tile dim (Const tile_size) in - {payload with loop_nest = new_loop payload.loop_nest} - + { payload with loop_nest = new_loop payload.loop_nest } (* TILING : could be factorized with tile_generalized_exact_partial *) - let tile_on_dim ?(par=false) dim_list payload dim tile_size_opt = - let tile_size = match tile_size_opt with - Some tile_size -> Expr.const tile_size - | None -> Expr.PlaceHolder (Printf.sprintf "ph_%s" (Dim.show_id_of_t dim)) in + let tile_on_dim ?(par = false) dim_list payload dim tile_size_opt = + let tile_size = + match tile_size_opt with + | Some tile_size -> Expr.const tile_size + | None -> + Expr.PlaceHolder (Printf.sprintf "ph_%s" (Dim.show_id_of_t dim)) + in let payload, (cur_index, next_index, incr, next_incr) = let set_par_dim = if par then D.set_par else Fun.id in modify_map_dim dim (function | Some dpayload -> - let cur_index = D.current_index dpayload in - let next_index = D.next_index dpayload in - let incr = D.incr dpayload in - (* weird test, should find a proper semantic here *) - let () = assert (D.div_constraint dpayload = Div - || D.div_constraint dpayload = Flexible 1) in - let next_incr = Expr.(I.(tile_size * incr)) in - Some (D.( - set_par_dim - % bump_index - % map_incr (fun _ -> next_incr) - ) - dpayload - ), - (cur_index, next_index, incr, next_incr) + let cur_index = D.current_index dpayload in + let next_index = D.next_index dpayload in + let incr = D.incr dpayload in + (* weird test, should find a proper semantic here *) + let () = + assert ( + D.div_constraint dpayload = Div + || D.div_constraint dpayload = Flexible 1) + in + let next_incr = Expr.(I.(tile_size * incr)) in + ( Some + (D.( + set_par_dim % bump_index % map_incr (fun _ -> next_incr)) + dpayload), + (cur_index, next_index, incr, next_incr) ) | None -> - let next_incr = tile_size in - let dpayload = D.default dim - |> D.map_incr (Fun.const next_incr) in - let cur_index = D.current_index dpayload in - let next_index = D.next_index dpayload in - Some D.(set_par_dim @@ bump_index dpayload), - (cur_index, next_index, Expr.one, next_incr) - ) payload in + let next_incr = tile_size in + let dpayload = + D.default dim |> D.map_incr (Fun.const next_incr) + in + let cur_index = D.current_index dpayload in + let next_index = D.next_index dpayload in + ( Some D.(set_par_dim @@ bump_index dpayload), + (cur_index, next_index, Expr.one, next_incr) )) + payload + in + let gen_loop tile_size = - let start = Expr.index next_index - (*and halt = Expr.(Infix.(min (next_index --+ incr * (const tile_size)) - * (size (Dim.size_id dim)))) in*) - and halt = Expr.(I.(index next_index + incr * const tile_size)) in + let start = + if par then + Expr.I.((incr * Expr.Var "thread_it_min") + Expr.index next_index) + else Expr.index next_index + and halt = + if par then Expr.I.(incr * Expr.Var "thread_it_max") + else Expr.(I.(index next_index + (incr * const tile_size))) + in let glob_state = show_current_state dim_list payload in - let loop_comment = Printf.sprintf "T (%s, %d) (%s / %s)" - (Dim.show_id_of_t dim) tile_size - (Expr.show next_incr) (Expr.show incr) + let suffix_par = if par then "_par" else "" in + let loop_comment = + Printf.sprintf "T%s (%s, %d) (%s / %s)" suffix_par + (Dim.show_id_of_t dim) tile_size (Expr.show next_incr) + (Expr.show incr) in - let comment = [glob_state; loop_comment] in - let pragma = "omp parallel for" in - if par then Zip.new_seq ~comment ~pragma cur_index start halt incr - else Zip.new_seq ~comment cur_index start halt incr in - let new_loop = match tile_size_opt with + let comment = [ glob_state; loop_comment ] in + Zip.new_seq ~comment cur_index start halt incr + in + let new_loop = + match tile_size_opt with (* WARNING : I don't know how this should interact with lambda - for now * it crashes *) - | Some tile_size -> compose_lnest (gen_loop tile_size) - | None -> introduce_lambda dim gen_loop in - {payload with loop_nest = new_loop payload.loop_nest} + | Some tile_size -> compose_lnest (gen_loop tile_size) + | None -> introduce_lambda dim gen_loop + in + { payload with loop_nest = new_loop payload.loop_nest } + + (**inspired by tile_on_dim to merge all parallel loops declared inside Fused_T_pars *) + let fuse_parallel_loops size_dim_list payload = + (* regroup T_par on the same dimension to one bigger T_par *) + let size_dim_list = + let dims = List.map snd size_dim_list in + let unique_dims = List.unique ~eq:Dim.equal dims in + List.map + (fun udim -> + let sizes_udim = + List.filter_map + (function + | size, dim when Dim.equal dim udim -> Some size | _ -> None) + size_dim_list + in + let total_size = + List.fold (fun acc size -> acc * size) 1 sizes_udim + in + (total_size, udim)) + unique_dims + in + if List.length size_dim_list < 2 then + failwith + "Fused_T_pars should have at least 2 loops to fuse. Use T_par if you \ + only want to run parallelisation on 1 dimension."; + let size_dim_list = + List.map (fun (s, d) -> (Expr.const s, d)) size_dim_list + in + let payload, tmp_info = + List.fold + (fun (old_payload, tuples) (tile_size, dim) -> + let new_payload, tuple = + modify_map_dim dim + (function + | Some dpayload -> + let cur_index = D.current_index dpayload in + let _next_index = D.next_index dpayload in + let incr = D.incr dpayload in + (* weird test, should find a proper semantic here *) + let () = + assert ( + D.div_constraint dpayload = Div + || D.div_constraint dpayload = Flexible 1) + in + let next_incr = Expr.(I.(tile_size * incr)) in + ( Some + (D.( + D.set_par % bump_index + % map_incr (fun _ -> next_incr)) + dpayload), + (cur_index, incr, tile_size) ) + | None -> + let next_incr = tile_size in + let dpayload = + D.default dim |> D.map_incr (Fun.const next_incr) + in + let cur_index = D.current_index dpayload in + let _next_index = D.next_index dpayload in + ( Some D.(set_par @@ bump_index dpayload), + (cur_index, Expr.const 1, tile_size) )) + old_payload + in + (* since payload is updated it is important to keep the updated version at each call of modify_map*) + (new_payload, tuple :: tuples)) + (payload, []) size_dim_list + in + let comment = + [ + Printf.sprintf "Fused_T_pars %s" + @@ [%show: (Expr.t * string) list] + (List.map (fun (s, d) -> (s, Dim.show_id_of_t d)) size_dim_list); + ] + in + let gen_fake_dim, _ = Dim.fresh_gen "thread_it" in + let cur_index = + Index.from_dim (gen_fake_dim ()) + (* hack create new dimension to create our own index here for thread work sharing*) + and start = Expr.Var "thread_it_min" + and halt = Expr.Var "thread_it_max" + and incr = Expr.Const 1 in + + (* need to be careful with the dimension sizes from inner to outter for the right choice of % and / to compute indices from thread iter*) + (* example: Fused_T_pars (I, i) (J, j) (K, k) -> thread_it = i + I*j + IJ*k + so we must do thread_it % I to find i, (thread_it / I) % J to find j, thread_it / (JJ) to find k*) + let vars_decl = + List.fold_lefti + (fun acc i (index, incr, size) -> + let product_previous_sizes = + match acc with [] -> 1 | h :: _ -> fst h + in + let size_int, incr_int = + match (size, incr) with + | Expr.Const si, Expr.Const incr -> (si, incr) + | _ -> failwith "Expr.const expected in tile_size" + in + let str = + if i = 0 then + Printf.sprintf "(thread_it %% %d) * %d" size_int incr_int + else if i = List.length tmp_info - 1 then + Printf.sprintf "(thread_it / %d) * %d" product_previous_sizes + incr_int + else + Printf.sprintf "((thread_it / %d) %% %d) * %d" + product_previous_sizes size_int incr_int + in + ( product_previous_sizes * size_int, + (Index.show_id_of_t index, Expr.Var str) ) + :: acc) + [] tmp_info + in + let vars_decl = List.map snd vars_decl in + let new_loop = + compose_lnest + (Zip.new_seq ~comment ~vars_decl cur_index start halt incr) + in + { payload with loop_nest = new_loop payload.loop_nest } (* Same semantic as List.fold_left f init (DimMap.bindings payload.dmap) *) - let fold_dims f init payload = - DimMap.fold f payload.dmap init + let fold_dims f init payload = DimMap.fold f payload.dmap init (* Get a list of every indexes used in loops generated so far *) let all_indexes payload = - let dim_indexes = fold_dims (fun _ dp ind_list -> - let aux = D.indexes_list dp - |> List.map (fun (ind, _, l) -> ind, l) in - List.fold_left - (fun l (ind, ind_l) -> ind::(List.rev_append (List.map fst ind_l) l)) - ind_list aux) [] payload in + let dim_indexes = + fold_dims + (fun _ dp ind_list -> + let aux = + D.indexes_list dp |> List.map (fun (ind, _, l) -> (ind, l)) + in + List.fold_left + (fun l (ind, ind_l) -> + ind :: List.rev_append (List.map fst ind_l) l) + ind_list aux) + [] payload + in dim_indexes - let fold_tens f init payload = - TidMap.fold f payload.tmap init + let fold_tens f init payload = TidMap.fold f payload.tmap init + + let check_vectorisation payload tensor = + match payload.vectorize_dim with + | Some dim -> + if List.mem dim (T.dims_list tensor) then + let inner_dim = T.inner_dim tensor in + match inner_dim with + | Single i_dim when Dim.equal dim i_dim -> () + | _ -> + let s1 = Dim.show_id_of_t dim in + let s2 = T.show_tid tensor in + let s3 = [%show: T.t_dims] inner_dim in + let error_msg = + Printf.sprintf + "Incorrect vectorisation of dim %s, whereas tensor %s \ + inner dim is %s." + s1 s2 s3 + in + failwith error_msg + else () + | None -> () + + let pack_tensor ({ dmap; loop_nest; _ } as payload) tensor size_list + is_readonly transpose_opt accesses = + let tmap, (tpayload, is_packed) = + TidMap.modify_map (T.id tensor) + (function + | Some tpld -> + let clone = Te.pack tpld size_list () in + (Some clone, (Te.current_tensor clone, true)) + | None -> + let def_tens_pld = Te.make tensor size_list in + (Some def_tens_pld, (Te.current_tensor def_tens_pld, false))) + payload.tmap + in + let payload, current_tensor = ({ payload with tmap }, tpayload) in + let current_tensor = + match transpose_opt with + | Some transposition -> + T.reorder_layout current_tensor (List.rev transposition) + | None -> current_tensor + in + + let _ = + match TidMap.find (T.id current_tensor) payload.tmap with + | None -> + check_vectorisation payload current_tensor + (* in case of same tensor packed twice only check for innermost packed array *) + | _ -> () + in - let pack_tensor ({dmap; loop_nest;_} as payload) tensor size_list - is_readonly transpose_opt accesses = - let tmap, (tpayload, is_packed) = TidMap.modify_map (T.id tensor) (function - Some tpld -> let clone = Te.pack tpld size_list () in - Some clone, (Te.current_tensor clone, true) - | None -> let def_tens_pld = Te.make tensor size_list in - Some def_tens_pld, (Te.current_tensor def_tens_pld, false) - ) - payload.tmap in - let payload, current_tensor = {payload with tmap}, tpayload in - let current_tensor = if Option.is_some transpose_opt then - let transposition = List.rev (Option.get transpose_opt) in - T.reorder_layout current_tensor transposition - else - current_tensor in let rec apply_pack_to_last = function | Concrete _ as c -> c | Lambda (dim, f) -> Lambda (dim, fun i -> apply_pack_to_last (f i)) | WaitTile (dim, f) -> WaitTile (dim, fun i -> apply_pack_to_last (f i)) | WaitPack (tensor', thunk) when T.equal tensor tensor' -> - thunk dmap current_tensor - | WaitPack (tensor', thunk) -> - WaitPack (tensor', fun t di -> apply_pack_to_last (thunk t di)) in + thunk dmap current_tensor + | WaitPack (tensor', thunk) -> + WaitPack (tensor', fun t di -> apply_pack_to_last (thunk t di)) + in (* Swap A for A0/A1 in every inst expression *) - let transform insts = - Inst.swap_tensor tensor current_tensor insts in + let transform insts = Inst.swap_tensor tensor current_tensor insts in let map_access dim access_expr new_accesses_map = - let access = Option.default (Fun.const Expr.zero) - (List.assoc_eq Dim.equal_id dim new_accesses_map) in - dim, access access_expr in + let access = + Option.default (Fun.const Expr.zero) + (List.assoc_eq Dim.equal_id dim new_accesses_map) + in + (dim, access access_expr) + in let swap_new_tensor_access insts = - Inst.swap_accesses (fun other_tensor old_accesses -> + Inst.swap_accesses + (fun other_tensor old_accesses -> if T.equal_id (T.id other_tensor) (T.id current_tensor) then - ( - List.map (fun (d, old_expr) -> map_access d old_expr accesses) old_accesses - ) - else (old_accesses)) - insts in + List.map + (fun (d, old_expr) -> map_access d old_expr accesses) + old_accesses + else old_accesses) + insts + in let pack kernel = let kernel = transform kernel in - if is_packed then kernel - else swap_new_tensor_access kernel in - let intro_pack = introduce_wait_pack tensor - (fun gdi gt -> - PI.build_pack dmap tensor current_tensor is_readonly gdi gt % (Zip.map_all_insts pack)) in - {payload with loop_nest = intro_pack @@ apply_pack_to_last loop_nest} - - + if is_packed then kernel else swap_new_tensor_access kernel + in + let intro_pack = + introduce_wait_pack tensor (fun gdi gt -> + PI.build_pack dmap tensor current_tensor is_readonly gdi gt + % Zip.map_all_insts pack) + in + { payload with loop_nest = intro_pack @@ apply_pack_to_last loop_nest } end type sym_size = Const_size of int | TimesVar of int + (* Get size of tile on this particular dimension *) let dim_tile_size tile_scheme dim = let aux current_size = function | V d when Dim.equal d dim -> Const_size Inst.A.vec_size - | External_call {fixed_size; max_range; dynamic_sizes;_} -> - ( match List.assoc_opt dim fixed_size with - Some s -> Const_size s - | None -> (match List.assoc_opt dim dynamic_sizes with - | Some s -> Const_size s - | None -> Const_size max_range - ) - ) - | U (size, d) | T (size, d) when Dim.equal d dim -> - (match current_size with Const_size c -> Const_size (c * size) - | TimesVar c -> TimesVar (c * size)) - | Tile_exact (size, d) when Dim.equal d dim -> - (match current_size with Const_size _ -> Const_size (size) - | TimesVar _ -> assert false) - | ULambda d when Dim.equal d dim -> - (match current_size with Const_size c -> TimesVar c - | TimesVar _ -> failwith "Multiple lambda on a single dimension") - | Lambda_apply (d, liter_arg) when Dim.equal d dim -> - (match current_size with Const_size _ -> failwith "lambda apply called without lambda" - | TimesVar c -> - Const_size - ( List.fold_left (fun acc (Iter i, Arg a) -> acc + i * a) 0 liter_arg * c) - ) - | _ -> current_size in - List.fold_left aux (Const_size 1) tile_scheme - |> function Const_size c -> c - | _ -> failwith @@ "Lambda without apply on dim " ^ (Dim.show_id_of_t dim) + | External_call { fixed_size; max_range; dynamic_sizes; _ } -> ( + match List.assoc_opt dim fixed_size with + | Some s -> Const_size s + | None -> ( + match List.assoc_opt dim dynamic_sizes with + | Some s -> Const_size s + | None -> Const_size max_range)) + | (U (size, d) | T (size, d)) when Dim.equal d dim -> ( + match current_size with + | Const_size c -> Const_size (c * size) + | TimesVar c -> TimesVar (c * size)) + | Tile_exact (size, d) when Dim.equal d dim -> ( + match current_size with + | Const_size _ -> Const_size size + | TimesVar _ -> assert false) + | ULambda d when Dim.equal d dim -> ( + match current_size with + | Const_size c -> TimesVar c + | TimesVar _ -> failwith "Multiple lambda on a single dimension") + | Lambda_apply (d, liter_arg) when Dim.equal d dim -> ( + match current_size with + | Const_size _ -> failwith "lambda apply called without lambda" + | TimesVar c -> + Const_size + (List.fold_left + (fun acc (Iter i, Arg a) -> acc + (i * a)) + 0 liter_arg + * c)) + | _ -> current_size + in + List.fold_left aux (Const_size 1) tile_scheme |> function + | Const_size c -> c + | _ -> failwith @@ "Lambda without apply on dim " ^ Dim.show_id_of_t dim (* Monstruous function - This thing is an absolute mess *) - let handle_pack payload tensor transpose_opt = + let handle_pack payload tensor transpose_opt = let module P = Payload in let t_dims = T.t_dims_list tensor in let dims = T.dims_list tensor in (* Retrieve current status of relevant dimensions - dimensions that appear in the tensor *) - let dims_payload_opt = List.map - (fun dim -> dim, P.(find_dim dim payload)) - dims in + let dims_payload_opt = + List.map (fun dim -> (dim, P.(find_dim dim payload))) dims + in (* Filtered dims payloads - that is, only dimensions that concretely * appears in a loop nest around the kernel appear there *) - let dims_payload = List.filter_map - (fun (d, opt) -> Option.map (fun pl -> d, pl) opt) - dims_payload_opt in + let dims_payload = + List.filter_map + (fun (d, opt) -> Option.map (fun pl -> (d, pl)) opt) + dims_payload_opt + in (* List of sizes for new tensor dimensions *) let size_list = - List.map (let open Tensor in function - Single dim as sd -> + List.map + (let open Tensor in + function + | Single dim as sd -> let d_opt = P.find_dim dim payload in let incr = Option.map_default P.D.incr Expr.zero d_opt in (sd, incr) - | Join (d1, d2, _, stride) as jd -> + | Join (d1, d2, _, stride) as jd -> ( let d1_opt = P.find_dim d1 payload in let d2_opt = P.find_dim d2 payload in - let mul_stride e = Option.map_default (fun s -> - Expr.(I.(const s * e))) e stride in - match Option.map P.D.incr d1_opt, Option.map P.D.incr d2_opt with - | Some inc1, (Some (Expr.SizeVar _ as size) - | Some (Expr.Const _ as size)) -> - (* TODO TEST TEST TEST *) - (jd, Expr.add (mul_stride inc1) size) - | Some _, Some (_) -> - (* TODO clean that, this is dirty *) - failwith "Tiling on small dim !!!" - | None, Some incr -> - jd, incr - | Some incr, None -> - jd, incr - | None, None -> - jd, Expr.zero - ) t_dims in - let current_pack_indexes_opt = List.map (fun (d, dp) -> - d, (P.D.current_pack dp)) - dims_payload in - let accesses = List.map (function - d, Some index -> Dim.id d, fun expr -> - Expr.alpha_replace (Index.from_dim d) (Expr.index index) expr - | d, None -> Dim.id d, fun expr -> - Expr.alpha_replace (Index.from_dim d) Expr.zero expr - ) - current_pack_indexes_opt in - (* TODO: find a cleaner way to do that ?*) - let is_readonly = Inst.is_tensor_readonly payload.kernel tensor in + let mul_stride e = + Option.map_default (fun s -> Expr.(I.(const s * e))) e stride + in + match (Option.map P.D.incr d1_opt, Option.map P.D.incr d2_opt) with + | ( Some inc1, + (Some (Expr.SizeVar _ as size) | Some (Expr.Const _ as size)) ) + -> + (* TODO TEST TEST TEST *) + (jd, Expr.add (mul_stride inc1) size) + | Some _, Some _ -> + (* TODO clean that, this is dirty *) + failwith "Tiling on small dim !!!" + | None, Some incr -> (jd, incr) + | Some incr, None -> (jd, incr) + | None, None -> (jd, Expr.zero))) + t_dims + in + let current_pack_indexes_opt = + List.map (fun (d, dp) -> (d, P.D.current_pack dp)) dims_payload + in + let accesses = + List.map + (function + | d, Some index -> + ( Dim.id d, + fun expr -> + Expr.alpha_replace (Index.from_dim d) (Expr.index index) expr + ) + | d, None -> + ( Dim.id d, + fun expr -> Expr.alpha_replace (Index.from_dim d) Expr.zero expr + )) + current_pack_indexes_opt + in + (* TODO: find a cleaner way to do that ?*) + let is_readonly = Inst.is_tensor_readonly payload.kernel tensor in P.pack_tensor payload tensor size_list is_readonly transpose_opt accesses + (** check that: + * - all tiles after T_par or Fused_T_pars are on red dims + * - no red dimension is on T_par or Fused_T_pars + *) + let check_parallel leftover_tiles red_dims = + if List.is_empty leftover_tiles then () + else + let is_red_dim dim = List.mem dim red_dims in + let error_msg_red_dim op dim = + let s_rd = [%show: Dim.id] (Dim.id dim) in + Printf.sprintf "Red dims like %s are not allowed in %s." s_rd op + in + let to_check = List.drop 1 leftover_tiles in + let tile_par = List.first leftover_tiles in + let _ = + match tile_par with + | T_par (_, dim) when is_red_dim dim -> + failwith (error_msg_red_dim "T_par" dim) + | Fused_T_pars list -> + let dim_list = List.map snd list in + let opt_red_dim = List.find is_red_dim dim_list in + let res = + match opt_red_dim with + | Some dim -> failwith (error_msg_red_dim "Fused_T_pars" dim) + | None -> () + in + res + | _ -> () + in + List.iter + (fun tile -> + match tile with + | T (_, dim) when is_red_dim dim -> () + | _ -> + let s = [%show: loop_type] tile in + let s2 = [%show: loop_type] tile_par in + let error_msg = + Printf.sprintf + "%s encountered after %s, only T(_, dim) on red dims are \ + allowed." + s s2 + in + failwith error_msg) + to_check + + (** Generate a parallel loop merged from all the T_par loops *) + let gen_parallel_utils n_iters fused = + if n_iters = 1 then ("", "", "") + else + let thread_id = "tid" in + let thread_it = "thread_it" in + let n_threads = "n_threads" in + let ind_min = "thread_it_min" in + let ind_max = "thread_it_max" in + let base_job = "base_job" in + let residual_job = "residual" in + let header = + [ + "#pragma omp parallel"; + "{\n" (* important to block the entire following code *); + ] + in + let declarations = + [ + (thread_id, "omp_get_thread_num()"); + (n_threads, "omp_get_num_threads()"); + (base_job, Printf.sprintf "%d / %s" n_iters n_threads); + (residual_job, Printf.sprintf "%d %% %s" n_iters n_threads); + (ind_min, "0"); + (ind_max, "0"); + ] + in + let declarations = + if fused then (thread_it, "0") :: declarations else declarations + in + let declarations = + List.map + (fun (name, value) -> Printf.sprintf "int %s = %s;" name value) + declarations + in + let balance_threads = + [ + Printf.sprintf "if (%s < %s) {" thread_id residual_job; + Printf.sprintf "\t%s = %s * (%s + 1);" ind_min thread_id base_job; + Printf.sprintf "\t%s = %s + (%s + 1);" ind_max ind_min base_job; + "} else {"; + Printf.sprintf "\t%s = %s * (%s + 1) + (%s - %s) * %s;" ind_min + residual_job base_job thread_id residual_job base_job; + Printf.sprintf "\t%s = %s + %s;" ind_max ind_min base_job; + "}\n"; + ] + in + let footer = "} // pragma omp parallel\n" in + let header = String.concat "\n" header in + let declarations = + String.concat "\n" + (List.map (String.concat "\n") [ declarations; balance_threads ]) + in + (header, declarations, footer) - let gen_loop ?(const_dim_sizes=[]) dim_list inst loop_tiles = + let gen_loop ?(const_dim_sizes = []) dim_list red_dim_list inst loop_tiles = let module P = Payload in let sanitize final_payload = - List.iter (fun d -> + (* catch info to future correctness checking for vectorization dimension *) + let tensors_list = + let open Inst in + match inst with + | Write (Add (_, Mul (Read (a, _), Read (b, _))), c, _) -> [ a; b; c ] + | _ -> [] + in + let transposed_tensors = + List.filter_map + (fun loop_t -> + match loop_t with + | Pack_trans (tensor, _) -> Some tensor + | _ -> None) + loop_tiles + in + let to_check_vec_tensors = + List.filter + (fun tensor -> not (List.mem tensor transposed_tensors)) + tensors_list + in + List.iter + (fun tensor -> P.check_vectorisation final_payload tensor) + to_check_vec_tensors; + (* end of the correctness checking on vectorization *) + List.iter + (fun d -> match P.find_dim d final_payload with - | Some dim_p -> - begin match P.D.is_closed dim_p, - P.D.incr dim_p, - List.assoc_eq Dim.equal d const_dim_sizes with - | true, _, _ -> () - (* TODO : to_int_opt is supposed to take a size_id * size list, - * we give it an empty list. This should be enough but is a bit dirty *) - | false, inc, Some dim_size when (Expr.to_int_opt [] inc = Some dim_size) -> () - | false, _, None -> - failwith (Printf.sprintf "Dim %s is not closed and no dim size was given" - (Dim.show_id_of_t d)) - | false, Expr.Const inc, Some dim_size -> - failwith (Printf.sprintf "Dim %s is not closed and current tile size is %d while given size is %d" - (Dim.show_id_of_t d) inc dim_size) - | is_closed, incr, expected_incr -> - failwith - (Printf.sprintf "Inconsistent state with Dim %s: closed %b, incr : %s, expected incr : %s" - (Dim.show_id_of_t d) is_closed (Expr.show incr) ([%show:int option] expected_incr) - ) - end - | None -> - match List.assoc_eq Dim.equal d const_dim_sizes with - | Some size when size = 1 -> () - | Some size -> failwith @@ - Printf.sprintf "Dim %s does not appear in tile scheme but was expected of size %d" - (Dim.show_id_of_t d) size - | None -> failwith @@ - Printf.sprintf "Dim %s does not appear in tile scheme " - (Dim.show_id_of_t d) - ) - dim_list in - let finalize final_payload = + | Some dim_p -> ( + match + ( P.D.is_closed dim_p, + P.D.incr dim_p, + List.assoc_eq Dim.equal d const_dim_sizes ) + with + | true, _, _ -> () + (* TODO : to_int_opt is supposed to take a size_id * size list, + * we give it an empty list. This should be enough but is a bit dirty *) + | false, inc, Some dim_size + when Expr.to_int_opt [] inc = Some dim_size -> + () + | false, _, None -> + failwith + (Printf.sprintf + "Dim %s is not closed and no dim size was given" + (Dim.show_id_of_t d)) + | false, Expr.Const inc, Some dim_size -> + failwith + (Printf.sprintf + "Dim %s is not closed and current tile size is %d while \ + given size is %d" + (Dim.show_id_of_t d) inc dim_size) + | is_closed, incr, expected_incr -> + failwith + (Printf.sprintf + "Inconsistent state with Dim %s: closed %b, incr : %s, \ + expected incr : %s" + (Dim.show_id_of_t d) is_closed (Expr.show incr) + ([%show: int option] expected_incr))) + | None -> ( + match List.assoc_eq Dim.equal d const_dim_sizes with + | Some size when size = 1 -> () + | Some size -> + failwith + @@ Printf.sprintf + "Dim %s does not appear in tile scheme but was expected \ + of size %d" + (Dim.show_id_of_t d) size + | None -> + failwith + @@ Printf.sprintf "Dim %s does not appear in tile scheme " + (Dim.show_id_of_t d))) + dim_list + in + let finalize final_payload = sanitize final_payload; - let init_unclosed_dims = List.fold_left (fun buf d -> - match P.find_dim d final_payload with - Some dp -> - if P.D.is_closed dp then buf - else let last_index = P.D.current_index dp in - let init = Printf.sprintf "int %s = 0;\n" - @@ Index.show_id_of_t last_index in - buf ^ init - | None -> - let init = Printf.sprintf "int %s = 0;\n" - @@ Index.show_id_of_t @@ Index.from_dim d in - buf ^ init - ) "" dim_list in + let init_unclosed_dims = + List.fold_left + (fun buf d -> + match P.find_dim d final_payload with + | Some dp -> + if P.D.is_closed dp then buf + else + let last_index = P.D.current_index dp in + let init = + Printf.sprintf "int %s = 0;\n" + @@ Index.show_id_of_t last_index + in + buf ^ init + | None -> + let init = + Printf.sprintf "int %s = 0;\n" + @@ Index.show_id_of_t @@ Index.from_dim d + in + buf ^ init) + "" dim_list + in let tensor_decl, tensor_free = - P.fold_tens (fun _ tp (list_decl, list_free) -> + P.fold_tens + (fun _ tp (list_decl, list_free) -> let decl, free = P.Te.decl_alloc_free_tens tp in - decl::list_decl, free::list_free) ([], []) final_payload in - let tensor_decl = Printf.sprintf "%s\n" (String.concat "\n" tensor_decl) in - let tensor_free = Printf.sprintf "\n%s\n" (String.concat "\n" tensor_free) in - let guards = List.fold_left (fun buf d -> - match P.find_dim d final_payload with - | Some dp -> - let dim_size = P.D.dim dp |> Dim.size_id in - let last_incr = Option.default Expr.zero (P.D.last_tile_size dp) in - let div_guard = if Expr.equal Expr.zero last_incr then "" else - Printf.sprintf" && (%s %% %s == 0)" (Size.show_id dim_size) - (Expr.show last_incr) in - (match List.assoc_eq Dim.equal d const_dim_sizes with - Some size -> - let dim_size_s =(Size.show_id dim_size) in - Printf.sprintf "%sassert((%s == %d)%s);\n" buf - dim_size_s size div_guard - | None -> - Printf.sprintf "%sassert((%s <= %s)%s);\n" buf - (Expr.show last_incr) (Size.show_id dim_size) div_guard - ) - | None -> - (match List.assoc_eq Dim.equal d const_dim_sizes with - Some _ -> - let dim_size = Dim.size_id d in - let dim_size_s =(Size.show_id dim_size) in - Printf.sprintf "%sassert((%s == 1));\n" buf - dim_size_s - | None -> failwith "What ?" - ) - ) - "" dim_list in - tensor_decl, tensor_free, guards, init_unclosed_dims in + (decl :: list_decl, free :: list_free)) + ([], []) final_payload + in + let tensor_decl = + Printf.sprintf "%s\n" (String.concat "\n" tensor_decl) + in + let tensor_free = + Printf.sprintf "\n%s\n" (String.concat "\n" tensor_free) + in + let guards = + List.fold_left + (fun buf d -> + match P.find_dim d final_payload with + | Some dp -> ( + let dim_size = P.D.dim dp |> Dim.size_id in + let last_incr = + Option.default Expr.zero (P.D.last_tile_size dp) + in + let div_guard = + if Expr.equal Expr.zero last_incr then "" + else + Printf.sprintf " && (%s %% %s == 0)" (Size.show_id dim_size) + (Expr.show last_incr) + in + match List.assoc_eq Dim.equal d const_dim_sizes with + | Some size -> + let dim_size_s = Size.show_id dim_size in + Printf.sprintf "%sassert((%s == %d)%s);\n" buf dim_size_s + size div_guard + | None -> + Printf.sprintf "%sassert((%s <= %s)%s);\n" buf + (Expr.show last_incr) (Size.show_id dim_size) div_guard) + | None -> ( + match List.assoc_eq Dim.equal d const_dim_sizes with + | Some _ -> + let dim_size = Dim.size_id d in + let dim_size_s = Size.show_id dim_size in + Printf.sprintf "%sassert((%s == 1));\n" buf dim_size_s + | None -> failwith "What ?")) + "" dim_list + in + (tensor_decl, tensor_free, guards, init_unclosed_dims) + in (* We are going to have an helper payload structure that * specifies the index to use, the increment, etc.*) - let gen_loop_fold payload = - function - | External_call {fixed_size; var_dim; max_range; tensors_list; - name; dynamic_sizes;} -> - P.external_call payload name var_dim max_range tensors_list fixed_size - dynamic_sizes - | Lambda_apply (dim, iter_args) -> - P.lambda_apply payload dim iter_args + let gen_loop_fold payload = function + | External_call + { fixed_size; var_dim; max_range; tensors_list; name; dynamic_sizes } + -> + P.external_call payload name var_dim max_range tensors_list fixed_size + dynamic_sizes + | Lambda_apply (dim, iter_args) -> P.lambda_apply payload dim iter_args | V dim -> - P.vectorize payload dim - | U (unroll_size, dim) -> - P.unroll payload dim (Some unroll_size) - | ULambda dim -> - P.unroll payload dim None + let payload = { payload with vectorize_dim = Some dim } in + P.vectorize payload dim + | U (unroll_size, dim) -> P.unroll payload dim (Some unroll_size) + | ULambda dim -> P.unroll payload dim None | T (tile_size, dim) -> P.tile_on_dim dim_list payload dim (Some tile_size) | T_par (tile_size, dim) -> P.tile_on_dim ~par:true dim_list payload dim (Some tile_size) - | TLambda dim -> - P.tile_on_dim dim_list payload dim None - | Hoist_vars dim_list -> - P.hoist_var payload dim_list + | Fused_T_pars size_dim_list -> + P.fuse_parallel_loops size_dim_list payload + | TLambda dim -> P.tile_on_dim dim_list payload dim None + | Hoist_vars dim_list -> P.hoist_var payload dim_list | Pack_trans (tensor, transpose_list) -> - handle_pack payload tensor (Some transpose_list) - | Pack_tens tensor -> - handle_pack payload tensor None + handle_pack payload tensor (Some transpose_list) + | Pack_tens tensor -> handle_pack payload tensor None | Tile_partial (size, dim) -> P.tile_generalized_exact_partial dim_list payload true false dim size | Tile_gexact (n_iter, dim) -> - P.tile_generalized_exact_partial dim_list payload true true dim n_iter + P.tile_generalized_exact_partial dim_list payload true true dim n_iter | Tile_exact (n_iter, dim) -> - P.tile_generalized_exact_partial dim_list payload false true dim n_iter + P.tile_generalized_exact_partial dim_list payload false true dim + n_iter | R dim -> - let payload, (cur_index, div_const, boundvar, incr) = P.modify_map_dim dim (function - | Some dp -> - Some (P.D.close dp), - (P.D.current_index dp, P.D.div_constraint dp,P.D.cur_bound_var dp, P.D.incr dp) - | None -> - Some (P.D.close @@ P.D.default dim), - (Index.from_dim dim, Dim_info.Div, "", Expr.one) - ) - payload in - let size_dim = (Expr.size (Dim.size_id dim)) in - let start = Expr.zero - and halt = size_dim in - let loop_comment = Printf.sprintf "R %s (%s / %s)" (Dim.show_id_of_t dim) - (Expr.show size_dim) (Expr.show incr) - and state = Payload.show_current_state dim_list payload in - let comment = [state; loop_comment] in - let new_loop =let open Dim_info in match div_const with - Div -> Zip.new_seq ~comment cur_index start halt incr - | Flexible _ -> let vars_decl = [boundvar, - Expr.(I.(min (incr) - (size ( Dim.size_id dim) - - index cur_index)))] in - Zip.new_seq ~comment ~vars_decl cur_index start halt incr - in - { payload with loop_nest = P.eliminate_wait_tile dim (Expr.SizeVar (Dim.size_id dim)) - @@ P.compose_lnest new_loop payload.loop_nest} in - (*End aux function *) + let payload, (cur_index, div_const, boundvar, incr) = + P.modify_map_dim dim + (function + | Some dp -> + ( Some (P.D.close dp), + ( P.D.current_index dp, + P.D.div_constraint dp, + P.D.cur_bound_var dp, + P.D.incr dp ) ) + | None -> + ( Some (P.D.close @@ P.D.default dim), + (Index.from_dim dim, Dim_info.Div, "", Expr.one) )) + payload + in + let size_dim = Expr.size (Dim.size_id dim) in + let start = Expr.zero and halt = size_dim in + let loop_comment = + Printf.sprintf "R %s (%s / %s)" (Dim.show_id_of_t dim) + (Expr.show size_dim) (Expr.show incr) + and state = Payload.show_current_state dim_list payload in + let comment = [ state; loop_comment ] in + let new_loop = + let open Dim_info in + match div_const with + | Div -> Zip.new_seq ~comment cur_index start halt incr + | Flexible _ -> + let vars_decl = + [ + ( boundvar, + Expr.( + I.(min incr (size (Dim.size_id dim) - index cur_index))) + ); + ] + in + Zip.new_seq ~comment ~vars_decl cur_index start halt incr + in + { + payload with + loop_nest = + P.eliminate_wait_tile dim (Expr.SizeVar (Dim.size_id dim)) + @@ P.compose_lnest new_loop payload.loop_nest; + } + in - let payload = P.init dim_list [inst] in - let payload = List.fold_left gen_loop_fold payload loop_tiles in - let _par_dims = P.fold_dims (fun d dp l -> if P.D.par dp then d::l else l) [] payload in - let tensor_decl, tensor_free, guards, init_unclosed_dims = finalize payload in + (*End aux function *) + let payload = P.init dim_list [ inst ] in + let payload = List.fold_left gen_loop_fold payload loop_tiles in + let _par_dims = + P.fold_dims (fun d dp l -> if P.D.par dp then d :: l else l) [] payload + in + + let tiles_with_parallel = + List.drop_while + (fun tile -> + match tile with T_par _ | Fused_T_pars _ -> false | _ -> true) + loop_tiles + in + check_parallel tiles_with_parallel red_dim_list; + let n_par_iters = + List.fold + (fun acc tile -> + acc + * + match tile with + | T_par (x, _) -> x + | Fused_T_pars list -> List.fold (fun prod (i, _) -> prod * i) 1 list + | _ -> 1) + 1 tiles_with_parallel + in + let is_fused = + Option.is_some + (List.find + (function Fused_T_pars _ -> true | _ -> false) + tiles_with_parallel) + in + let par_header, par_decl, par_footer = + gen_parallel_utils n_par_iters is_fused + in + let tensor_decl, tensor_free, guards, init_unclosed_dims = + finalize payload + in let indexes, loop_nest = P.finalize_loop_nest payload in - let list_indexes = (indexes @ P.all_indexes payload) |> List.sort_unique Index.compare in - let indexes_decl = String.concat ", " @@ List.map (Index.show_id % Index.id) list_indexes - |> Printf.sprintf "int %s;\n" in - let final_loop = P.fold_dims (fun _ dp loop_nest -> - let aux = P.D.indexes_list dp - |> List.map (fun (ind, _, l) -> ind, l) in - List.fold_left (fun ln (ind, l) -> LN.set_aux ind l ln) loop_nest aux - ) - (Zip.to_tree loop_nest) payload in - (Printf.sprintf "/*\n%s\n*/\n" ([%show: loop_type list] loop_tiles)), - indexes_decl, tensor_decl, init_unclosed_dims, tensor_free, guards, - final_loop - - module Conv_build(Conv_args: Conv_args_t) = struct + let list_indexes = + indexes @ P.all_indexes payload |> List.sort_unique Index.compare + in + let indexes_decl = + String.concat ", " @@ List.map (Index.show_id % Index.id) list_indexes + |> Printf.sprintf "int %s;\n" + in + let final_loop = + P.fold_dims + (fun _ dp loop_nest -> + let aux = + P.D.indexes_list dp |> List.map (fun (ind, _, l) -> (ind, l)) + in + List.fold_left (fun ln (ind, l) -> LN.set_aux ind l ln) loop_nest aux) + (Zip.to_tree loop_nest) payload + in + ( Printf.sprintf "/*\n%s\n*/\n" ([%show: loop_type list] loop_tiles), + indexes_decl, + tensor_decl, + init_unclosed_dims, + tensor_free, + guards, + final_loop, + par_header, + par_decl, + par_footer ) + + module Conv_build (Conv_args : Conv_args_t) = struct include Conv_args (* * Output[i,j,f] += Input[i + w, j + h, c] * Params[w, h, c, f] * *) - let loop_body = let open Inst in - let from_dim dim = Dim.id dim, Expr.index @@ Index.from_dim dim in - let read_out = Read (output, [from_dim x_dim; from_dim y_dim; from_dim f_dim]) - and read_1 = Read (input, [from_dim x_dim ; from_dim w_dim; from_dim y_dim ; from_dim h_dim; from_dim c_dim]) - and read_2 = Read (params, [from_dim c_dim; from_dim f_dim; from_dim h_dim; from_dim w_dim]) in + let loop_body = + let open Inst in + let from_dim dim = (Dim.id dim, Expr.index @@ Index.from_dim dim) in + let read_out = + Read (output, [ from_dim x_dim; from_dim y_dim; from_dim f_dim ]) + and read_1 = + Read + ( input, + [ + from_dim x_dim; + from_dim w_dim; + from_dim y_dim; + from_dim h_dim; + from_dim c_dim; + ] ) + and read_2 = + Read + ( params, + [ from_dim c_dim; from_dim f_dim; from_dim h_dim; from_dim w_dim ] + ) + in let contract = Add (read_out, Mul (read_1, read_2)) in - Write (contract, output, [from_dim x_dim; from_dim y_dim; from_dim f_dim]) + Write + (contract, output, [ from_dim x_dim; from_dim y_dim; from_dim f_dim ]) - let dim_list = [y_dim; x_dim; h_dim; w_dim; c_dim; f_dim;] + let dim_list = [ y_dim; x_dim; h_dim; w_dim; c_dim; f_dim ] + let red_dim_list = [ h_dim; w_dim; c_dim ] - let gen_code ?(dim_sizes=[]) tile_scheme = + let gen_code ?(dim_sizes = []) tile_scheme = reset (); - let tile_scheme, indexes_decl, tensor_decl, init_dims, tensor_free, - guards, loop = - gen_loop ~const_dim_sizes:dim_sizes dim_list loop_body tile_scheme in - let loop_code = LN.gen_code [] loop in + let ( tile_scheme, + indexes_decl, + tensor_decl, + init_dims, + tensor_free, + guards, + loop, + par_header, + par_decl, + par_footer ) = + gen_loop ~const_dim_sizes:dim_sizes dim_list red_dim_list loop_body + tile_scheme + in + let loop_code = LN.gen_code loop in let buffer = Buffer.create 2000 in - let () = List.iter (Buffer.add_string buffer) - [tile_scheme; indexes_decl; tensor_decl; guards; init_dims; loop_code; tensor_free] in + let () = + List.iter (Buffer.add_string buffer) + [ + tile_scheme; + par_header; + par_decl; + indexes_decl; + tensor_decl; + guards; + init_dims; + loop_code; + tensor_free; + par_footer; + ] + in Buffer.contents buffer end - - module MM_build(MM_args: MM_args_t) = struct + module MM_build (MM_args : MM_args_t) = struct include MM_args - let loop_body = let open Inst in - let from_dim dim = Dim.id dim, Expr.index @@ Index.from_dim dim in - let read_out = Read (c, [from_dim i_dim; from_dim j_dim]) - and read_1 = Read (a, [from_dim i_dim; from_dim k_dim]) - and read_2 = Read (b, [from_dim k_dim; from_dim j_dim]) in - let contract = Add (read_out, Mul (read_1, read_2)) in - Write (contract, c, [from_dim i_dim; from_dim j_dim]) + let loop_body = + let open Inst in + let from_dim dim = (Dim.id dim, Expr.index @@ Index.from_dim dim) in + let read_out = Read (c, [ from_dim i_dim; from_dim j_dim ]) + and read_1 = Read (a, [ from_dim i_dim; from_dim k_dim ]) + and read_2 = Read (b, [ from_dim k_dim; from_dim j_dim ]) in + let contract = Add (read_out, Mul (read_1, read_2)) in + Write (contract, c, [ from_dim i_dim; from_dim j_dim ]) - let dim_list = [i_dim; j_dim; k_dim] + let dim_list = [ i_dim; j_dim; k_dim ] + let red_dim_list = [ k_dim ] - let gen_code ?(dim_sizes=[]) tile_scheme = + let gen_code ?(dim_sizes = []) tile_scheme = reset (); - let tile_scheme, indexes_decl, tensor_decl, init_dims, tensor_free, - guards, loop = - gen_loop ~const_dim_sizes:dim_sizes dim_list loop_body tile_scheme in - let loop_code = LN.gen_code [] loop in + let ( tile_scheme, + indexes_decl, + tensor_decl, + init_dims, + tensor_free, + guards, + loop, + par_header, + par_decl, + par_footer ) = + gen_loop ~const_dim_sizes:dim_sizes dim_list red_dim_list loop_body + tile_scheme + in + let loop_code = LN.gen_code loop in let buffer = Buffer.create 2000 in - let () = List.iter (Buffer.add_string buffer) - [tile_scheme; indexes_decl; tensor_decl; guards; init_dims; loop_code; tensor_free] in + let () = + List.iter (Buffer.add_string buffer) + [ + tile_scheme; + par_header; + par_decl; + indexes_decl; + tensor_decl; + guards; + init_dims; + loop_code; + tensor_free; + par_footer; + ] + in Buffer.contents buffer end - module TC_build(TC_args: TC_args_t) = struct + module TC_build (TC_args : TC_args_t) = struct include TC_args - let loop_body = let open Inst in - let from_dim dim = Dim.id dim, Expr.index @@ Index.from_dim dim in + let loop_body = + let open Inst in + let from_dim dim = (Dim.id dim, Expr.index @@ Index.from_dim dim) in let out_access = List.map from_dim left @ List.map from_dim right in - let read_out = Read (c, out_access) - and read_1 = Read (a, List.map from_dim left @ List.map from_dim red) - and read_2 = Read (b, List.map from_dim red @ List.map from_dim right) in - let contract = Add (read_out, Mul (read_1, read_2)) in - Write (contract, c, out_access) + let read_out = Read (c, out_access) + and read_1 = Read (a, List.map from_dim left @ List.map from_dim red) + and read_2 = Read (b, List.map from_dim red @ List.map from_dim right) in + let contract = Add (read_out, Mul (read_1, read_2)) in + Write (contract, c, out_access) let dim_list = left @ right @ red - let gen_code ?(dim_sizes=[]) tile_scheme = + let gen_code ?(dim_sizes = []) tile_scheme = reset (); - let tile_scheme, indexes_decl, tensor_decl, init_dims, tensor_free, - guards, loop = - gen_loop ~const_dim_sizes:dim_sizes dim_list loop_body tile_scheme in - let loop_code = LN.gen_code [] loop in + let ( tile_scheme, + indexes_decl, + tensor_decl, + init_dims, + tensor_free, + guards, + loop, + par_header, + par_decl, + par_footer ) = + gen_loop ~const_dim_sizes:dim_sizes dim_list red loop_body tile_scheme + in + let loop_code = LN.gen_code loop in let buffer = Buffer.create 2000 in - let () = List.iter (Buffer.add_string buffer) - [tile_scheme; indexes_decl; tensor_decl; guards; init_dims; loop_code; tensor_free] in + let () = + List.iter (Buffer.add_string buffer) + [ + tile_scheme; + par_header; + par_decl; + indexes_decl; + tensor_decl; + guards; + init_dims; + loop_code; + tensor_free; + par_footer; + ] + in Buffer.contents buffer end end diff --git a/ml/lib/tensor_transform.mli b/ml/lib/tensor_transform.mli index 0a768385b844943c09dbc1285cbbcbeaa87dc5c0..725af937c3c1b090057a39cc3a7a919b927c58e5 100644 --- a/ml/lib/tensor_transform.mli +++ b/ml/lib/tensor_transform.mli @@ -5,57 +5,67 @@ open Kernels_sign type iter = Iter of int [@@deriving show] type arg = Arg of int [@@deriving show] -module Tensor_tile(A: Arch.Vec_arch_t): sig - module Inst: Inst_t with module A = A - module LN: Loop_nest.Loopnest_t with module Inst := Inst +module Tensor_tile (A : Arch.Vec_arch_t) : sig + module Inst : Inst_t with module A = A + module LN : Loop_nest.Loopnest_t with module Inst := Inst - type ext_call_arg = {name: string; - fixed_size : (Dim.t * int) list; - var_dim: Dim.t; - dynamic_sizes : (Dim.t * int) list; - tensors_list : Tensor.t list; - max_range: int} [@@deriving show] + type ext_call_arg = { + name : string; + fixed_size : (Dim.t * int) list; + var_dim : Dim.t; + dynamic_sizes : (Dim.t * int) list; + tensors_list : Tensor.t list; + max_range : int; + } + [@@deriving show] (* Invariant : Some transformers impose divisibility, such as V/U/Ulambda * Tile_exact n is legal if and only if n is dividible by the current * divisibilty constraint *) - type loop_type = U of int * Dim.t - | V of Dim.t - | External_call of ext_call_arg - (* Tile_partial supports a dynamic bound *) - | Tile_partial of int * Dim.t - | Tile_gexact of int * Dim.t - (* Tile_exact declares a dynamic variable that can be used as a - * dynamic bound *) - | Tile_exact of int * Dim.t (* Tile_Exact (n, d) does exactly n - iterations - no multiplication *) - | ULambda of Dim.t - | TLambda of Dim.t - | Lambda_apply of Dim.t * (iter * arg) list - | T of int * Dim.t - | T_par of int * Dim.t - | Pack_tens of Tensor.t - | Pack_trans of Tensor.t * Dim.t list - | Hoist_vars of Dim.t list - | R of Dim.t [@@deriving show] + type loop_type = + | U of int * Dim.t + | V of Dim.t + | External_call of ext_call_arg + (* Tile_partial supports a dynamic bound *) + | Tile_partial of int * Dim.t + | Tile_gexact of int * Dim.t + (* Tile_exact declares a dynamic variable that can be used as a + * dynamic bound *) + | Tile_exact of int * Dim.t + (* Tile_Exact (n, d) does exactly n + iterations - no multiplication *) + | ULambda of Dim.t + | TLambda of Dim.t + | Lambda_apply of Dim.t * (iter * arg) list + | T of int * Dim.t + | T_par of int * Dim.t + | Fused_T_pars of (int * Dim.t) list + | Pack_tens of Tensor.t + | Pack_trans of Tensor.t * Dim.t list + | Hoist_vars of Dim.t list + | R of Dim.t + [@@deriving show] type tile_scheme = loop_type list - val tile_to_string: tile_scheme -> string - val dim_tile_size: tile_scheme -> Dim.t -> int + val tile_to_string : tile_scheme -> string + val dim_tile_size : tile_scheme -> Dim.t -> int - module MM_build(MM_args: MM_args_t) : sig + module MM_build (MM_args : MM_args_t) : sig include MM_args_t - val gen_code: ?dim_sizes:((Dim.t * int) list) -> tile_scheme -> string + + val gen_code : ?dim_sizes:(Dim.t * int) list -> tile_scheme -> string end - module Conv_build(Conv_args: Conv_args_t) : sig + module Conv_build (Conv_args : Conv_args_t) : sig include Conv_args_t - val gen_code: ?dim_sizes:((Dim.t * int) list) -> tile_scheme -> string + + val gen_code : ?dim_sizes:(Dim.t * int) list -> tile_scheme -> string end - module TC_build(TC_args: TC_args_t) : sig + module TC_build (TC_args : TC_args_t) : sig include TC_args_t - val gen_code: ?dim_sizes:((Dim.t * int) list) -> tile_scheme -> string + + val gen_code : ?dim_sizes:(Dim.t * int) list -> tile_scheme -> string end end diff --git a/ml/lib/utils.ml b/ml/lib/utils.ml index af1185c18282d9f92dc795ba36105f726dea0661..4b16aebf6b9ac05085973370186033f3bf20a724 100644 --- a/ml/lib/utils.ml +++ b/ml/lib/utils.ml @@ -1,166 +1,189 @@ include Batteries module Opt_syntax = struct - let (let+) = fun opt f -> Option.map f opt + let ( let+ ) opt f = Option.map f opt - let (and+) = fun opt1 opt2 -> match opt1, opt2 with - | Some x, Some y -> Some (x, y) - | _ -> None + let ( and+ ) opt1 opt2 = + match (opt1, opt2) with Some x, Some y -> Some (x, y) | _ -> None - let (let*) = Option.bind - let (and*) = (and+) + let ( let* ) = Option.bind + let ( and* ) = ( and+ ) end let bimap (f, g) (a, b) = (f a, g b) -let map_fst f = bimap (f, Fun.id) -let map_snd f = bimap (Fun.id, f) +let map_fst f = bimap (f, Fun.id) +let map_snd f = bimap (Fun.id, f) module Either = struct type ('a, 'b) t = Left of 'a | Right of 'b [@@deriving show] let left x = Left x - let right x = Right x - - let map_left f = function - | Left a -> Left (f a) - | r -> r + let map_left f = function Left a -> Left (f a) | r -> r let map_left_or_fail f = function | Left a -> Left (f a) | _ -> failwith "Supposed to be called on left" - let apply_left_exn: exn -> ('a -> 'b -> 'c) -> ('a, 'd) t -> 'b -> ('c, 'd) t = - fun error f either arg -> match either with - | Left a -> Left (f a arg) - | Right _ -> raise error - - let apply_left2_exn: exn -> ('a -> 'arg1 -> 'arg2 -> 'c) -> ('a, 'd) t -> 'arg1 -> - 'arg2 -> ('c, 'd) t = - fun error f either arg1 arg2 -> match either with - | Left a -> Left (f a arg1 arg2) - | Right _ -> raise error - - let apply_left3_exn: exn -> ('a -> 'arg1 -> 'arg2-> 'arg3 -> 'c) - -> ('a, 'd) t -> 'arg1 -> 'arg2 -> 'arg3 -> ('c, 'd) t = - fun error f either arg1 arg2 arg3 -> match either with - | Left a -> Left (f a arg1 arg2 arg3) - | Right _ -> raise error - - let apply_left4_exn: exn -> ('a -> 'arg1 -> 'arg2-> 'arg3 -> 'arg4 -> 'c) - -> ('a, 'd) t -> 'arg1 -> 'arg2 -> 'arg3 -> 'arg4 -> ('c, 'd) t = - fun error f either arg1 arg2 arg3 arg4 -> match either with - | Left a -> Left (f a arg1 arg2 arg3 arg4) - | Right _ -> raise error - - let map_right f = function - | Right b -> Left (f b) - | l -> l + let apply_left_exn : exn -> ('a -> 'b -> 'c) -> ('a, 'd) t -> 'b -> ('c, 'd) t + = + fun error f either arg -> + match either with Left a -> Left (f a arg) | Right _ -> raise error + + let apply_left2_exn : + exn -> + ('a -> 'arg1 -> 'arg2 -> 'c) -> + ('a, 'd) t -> + 'arg1 -> + 'arg2 -> + ('c, 'd) t = + fun error f either arg1 arg2 -> + match either with Left a -> Left (f a arg1 arg2) | Right _ -> raise error + + let apply_left3_exn : + exn -> + ('a -> 'arg1 -> 'arg2 -> 'arg3 -> 'c) -> + ('a, 'd) t -> + 'arg1 -> + 'arg2 -> + 'arg3 -> + ('c, 'd) t = + fun error f either arg1 arg2 arg3 -> + match either with + | Left a -> Left (f a arg1 arg2 arg3) + | Right _ -> raise error + + let apply_left4_exn : + exn -> + ('a -> 'arg1 -> 'arg2 -> 'arg3 -> 'arg4 -> 'c) -> + ('a, 'd) t -> + 'arg1 -> + 'arg2 -> + 'arg3 -> + 'arg4 -> + ('c, 'd) t = + fun error f either arg1 arg2 arg3 arg4 -> + match either with + | Left a -> Left (f a arg1 arg2 arg3 arg4) + | Right _ -> raise error + + let map_right f = function Right b -> Left (f b) | l -> l let map_right_or_fail f = function | Right b -> Right (f b) | _ -> failwith "Supposed to be called on right" - let apply_right_exn: exn -> ('a -> 'b -> 'c) -> ('d, 'a) t -> 'b -> ('d, 'c) t = - fun error f either arg -> match either with - | Right a -> Right (f a arg) - | Left _ -> raise error - - let bimap f g = function - | Left a -> Left (f a) - | Right b -> Right (g b) + let apply_right_exn : + exn -> ('a -> 'b -> 'c) -> ('d, 'a) t -> 'b -> ('d, 'c) t = + fun error f either arg -> + match either with Right a -> Right (f a arg) | Left _ -> raise error + let bimap f g = function Left a -> Left (f a) | Right b -> Right (g b) (* ('a -> ('b, 'c) Either.t) -> 'a list -> ('b list, 'c list) Either.t list *) let partition_consecutive f seg_list = let rec segregate current_list acc_list proceeding = - match current_list, proceeding with + match (current_list, proceeding) with | Left current_list, [] -> - List.rev (Left (List.rev current_list)::acc_list) + List.rev (Left (List.rev current_list) :: acc_list) | Right current_list, [] -> - List.rev (Right (List.rev current_list)::acc_list) - | Left current_list, head::tail -> begin match f head with - | Left head -> segregate (Left (head::current_list)) acc_list tail - | Right head -> segregate (Right [head]) (Left (List.rev current_list)::acc_list) tail - end - | Right current_list, head::tail -> begin match f head with - | Right head -> segregate (Right (head::current_list)) acc_list tail - | Left head -> segregate (Left [head]) (Right (List.rev current_list)::acc_list) tail - end in + List.rev (Right (List.rev current_list) :: acc_list) + | Left current_list, head :: tail -> ( + match f head with + | Left head -> segregate (Left (head :: current_list)) acc_list tail + | Right head -> + segregate (Right [ head ]) + (Left (List.rev current_list) :: acc_list) + tail) + | Right current_list, head :: tail -> ( + match f head with + | Right head -> segregate (Right (head :: current_list)) acc_list tail + | Left head -> + segregate (Left [ head ]) + (Right (List.rev current_list) :: acc_list) + tail) + in match seg_list with | [] -> [] - | head::tail -> match f head with - | Left head -> segregate (Left [head]) [] tail - | Right head -> segregate (Right [head]) [] tail + | head :: tail -> ( + match f head with + | Left head -> segregate (Left [ head ]) [] tail + | Right head -> segregate (Right [ head ]) [] tail) end -type ('a, 'b) either = ('a, 'b) Either.t = Left of 'a | Right of 'b +type ('a, 'b) either = ('a, 'b) Either.t = Left of 'a | Right of 'b module Option = struct include Option - let from_cond_val cond x = - if cond then Some x else None + + let from_cond_val cond x = if cond then Some x else None end module List = struct include List include List.Exceptionless + let insert_at_end l elt = let rec insert acc = function - | [] -> List.rev (elt::acc) - | head::tail -> insert (head::acc) tail in + | [] -> List.rev (elt :: acc) + | head :: tail -> insert (head :: acc) tail + in insert [] l - let rec assoc_eq eq a list_a = match list_a with + let rec assoc_eq eq a list_a = + match list_a with | [] -> None - | (a', ret)::_ when eq a' a -> Some ret - | _::tail -> assoc_eq eq a tail + | (a', ret) :: _ when eq a' a -> Some ret + | _ :: tail -> assoc_eq eq a tail - let rec mem_eq eq a list_a = match list_a with + let rec mem_eq eq a list_a = + match list_a with | [] -> false - | head::tail -> eq a head || mem_eq eq a tail + | head :: tail -> eq a head || mem_eq eq a tail - let remove_assoc eq a list_a: 'b option * ('a * 'b) list = + let remove_assoc eq a list_a : 'b option * ('a * 'b) list = let rec remove accum a = function - | [] -> None, List.rev accum - | (a', ret)::tail when eq a' a -> Some ret, List.rev_append accum tail - | head::tail -> remove (head::accum) a tail - in remove [] a list_a + | [] -> (None, List.rev accum) + | (a', ret) :: tail when eq a' a -> (Some ret, List.rev_append accum tail) + | head :: tail -> remove (head :: accum) a tail + in + remove [] a list_a - let modify_assoc eq f elt l = + let modify_assoc eq f elt l = let rec modify acc = function - | [] -> (match f None with - Some value -> List.rev ((elt, value)::acc) + | [] -> ( + match f None with + | Some value -> List.rev ((elt, value) :: acc) | None -> List.rev acc) - | (key, value)::tail when eq key elt -> - List.rev_append acc - (match f @@ Some value with - Some value -> (key, value)::tail - | None -> tail - ) - | (key, value)::tail -> modify ((key, value)::acc) tail in + | (key, value) :: tail when eq key elt -> + List.rev_append acc + (match f @@ Some value with + | Some value -> (key, value) :: tail + | None -> tail) + | (key, value) :: tail -> modify ((key, value) :: acc) tail + in modify [] l let split3 l = let rec split l1 l2 l3 = function - | [] -> List.rev l1, List.rev l2, List.rev l3 - | (x1, x2, x3)::tail -> split (x1::l1) (x2::l2) (x3::l3) tail in + | [] -> (List.rev l1, List.rev l2, List.rev l3) + | (x1, x2, x3) :: tail -> split (x1 :: l1) (x2 :: l2) (x3 :: l3) tail + in split [] [] [] l end module Array = struct include Array - let find_opt p arr = - try Some (Array.find p arr) with - Not_found -> None + let find_opt p arr = try Some (Array.find p arr) with Not_found -> None end let span p l = let rec loop pre_list = function - | [] -> List.rev pre_list, None, [] - | h::t when p h -> List.rev pre_list, Some h, t - | h::t -> loop (h::pre_list) t in + | [] -> (List.rev pre_list, None, []) + | h :: t when p h -> (List.rev pre_list, Some h, t) + | h :: t -> loop (h :: pre_list) t + in loop [] l (* fold_apply: 'a -> ('a -> 'a) list -> 'a *) @@ -169,41 +192,52 @@ let fold_apply init f_list = module type MAPPABLE = sig type 'a t - val return: 'a -> 'a t - val map: 'a t -> ('a -> 'b) -> 'b t + + val return : 'a -> 'a t + val map : 'a t -> ('a -> 'b) -> 'b t end + module type MONAD = sig type 'a t - val return: 'a -> 'a t - val bind: 'a t -> ('a -> 'b t) -> 'b t + + val return : 'a -> 'a t + val bind : 'a t -> ('a -> 'b t) -> 'b t end -module List_monad: MONAD with type 'a t = 'a list = struct +module List_monad : MONAD with type 'a t = 'a list = struct type 'a t = 'a list - let return x = [x] + + let return x = [ x ] let bind x f = List.concat_map f x end -module Mappable(M: MONAD): MAPPABLE with type 'a t = 'a M.t = struct +module Mappable (M : MONAD) : MAPPABLE with type 'a t = 'a M.t = struct include M + let map m f = bind m (fun a -> return (f a)) end module type STATE = sig type state + include MONAD + val get : state t val put : state -> unit t val runState : 'a t -> init:state -> state * 'a end -module State (S: sig type t end): STATE with type state = S.t = struct +module State (S : sig + type t +end) : STATE with type state = S.t = struct type state = S.t type 'a t = state -> state * 'a - let return v = fun s -> s, v + + let return v s = (s, v) let bind m k s = - let s', v = m s in k v s' + let s', v = m s in + k v s' let get s = (s, s) let put s _ = (s, ()) @@ -212,34 +246,43 @@ end module type SYNTAX = sig include MONAD - val ( let* ): 'a t -> ('a -> 'b t) -> 'b t - val ( let+ ): 'a t -> ('a -> 'b) -> 'b t + + val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t + val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t end -module MSyntax(M: sig include MONAD val map: 'a t -> ('a -> 'b) -> 'b t end): - SYNTAX with type 'a t = 'a M.t = struct +module MSyntax (M : sig + include MONAD + + val map : 'a t -> ('a -> 'b) -> 'b t +end) : SYNTAX with type 'a t = 'a M.t = struct include M + let ( let* ) = bind let ( let+ ) = map end module type CUSTOM_MAP = sig type ord + include Map.S with type key = ord - val find: key -> 'a t -> 'a option - val choose: 'a t -> (key * 'a) option - val any: 'a t -> (key * 'a) option - val modify_map: key -> ('a option -> 'a option * 'b) -> 'a t -> 'a t * 'b + + val find : key -> 'a t -> 'a option + val choose : 'a t -> (key * 'a) option + val any : 'a t -> (key * 'a) option + val modify_map : key -> ('a option -> 'a option * 'b) -> 'a t -> 'a t * 'b end -module Make_map(Ord: sig type t val compare: t -> t -> int end): - CUSTOM_MAP with type ord := Ord.t -= struct - module Map_exc = Map.Make(Ord) + +module Make_map (Ord : sig + type t + + val compare : t -> t -> int +end) : CUSTOM_MAP with type ord := Ord.t = struct + module Map_exc = Map.Make (Ord) include Map_exc include Map_exc.Exceptionless let modify_map key f dict = let new_dict_opt, value = f (find key dict) in - modify_opt key (Fun.const new_dict_opt) dict, value + (modify_opt key (Fun.const new_dict_opt) dict, value) end -