diff --git a/ml/bin/main.ml b/ml/bin/main.ml index 065301997eae9dff5d90f6a8e3daed131fa293c5..1716a16db3c75a6607168f48dd01498540c4eff0 100644 --- a/ml/bin/main.ml +++ b/ml/bin/main.ml @@ -11,7 +11,7 @@ module Conv = T_tile.Conv_build(struct let c_dim = dim_gen ~name:"c" () let w_dim = dim_gen ~name:"w" () let h_dim = dim_gen ~name:"h" () - let input = Tensor.(make_join ~name:"I" [|join_dims x_dim w_dim; + let input = Tensor.(make_join ~name:"I" [|join_dims_stride x_dim w_dim 2; join_dims y_dim h_dim; single c_dim|]) let params = Tensor.make ~name:"P" [|w_dim;h_dim;c_dim;f_dim|] @@ -65,7 +65,7 @@ let trans_tile = let open T_tile in let open MM in [R i_dim; R k_dim; Pack_trans (a, [k_dim; i_dim]); R j_dim] let dull_tile_conv2 = let open T_tile in let open Conv in - [R f_dim; R x_dim; R y_dim; R w_dim; R c_dim; R h_dim] + [V f_dim; R f_dim; R x_dim; R y_dim; R w_dim; R c_dim; R h_dim] let unroll_tile = let open T_tile in let open MM in [U (2, j_dim); R i_dim; R j_dim; R k_dim] @@ -109,7 +109,6 @@ let tile_9x3= let open T_tile in let open MM in [V j_dim; U (3, j_dim); U (9, i_dim); T (148, k_dim); Hoist_vars [k_dim]; ] let () = - let loop_code = let open MM in - let dim_sizes = [i_dim, 9; j_dim, 48; k_dim, 148] in - MM.gen_code ~dim_sizes tile_9x3 in + let loop_code = let open Conv in + gen_code dull_tile_conv2 in print_endline loop_code diff --git a/ml/lib/pack_info.ml b/ml/lib/pack_info.ml index 9150b6d2ebf3b704fc0f44082cf85c4c523dc571..794a5720344ee3a5cf9c68eb8da1644cf435d613 100644 --- a/ml/lib/pack_info.ml +++ b/ml/lib/pack_info.ml @@ -56,7 +56,8 @@ module PI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct | Some key, Some dim_loop -> Some key, snd (RL.enclose loop_nest dim_loop) | _, None -> base_key, loop_nest end - | Join (main_dim, aux_dim, iter_dim) -> + (* TODO we may want to change something here - Or do we ? *) + | Join (main_dim, aux_dim, iter_dim, _) -> let main_id = Dim.id main_dim and aux_id = Dim.id aux_dim in let get_frozen snapshot d = Option.get @@ List.assoc_eq @@ -97,7 +98,7 @@ module PI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct let vectorize_dim id = let inner_dim = Tensor.(inner_dim pack_tensor |> function Single d -> d - | Join (d,_, _) -> d) in + | Join (d,_, _, _) -> d) in Dim.equal_id (Dim.id inner_dim) id in List.fold_left (fold_dims vectorize_dim small_snapshot big_snapshot) @@ -142,7 +143,7 @@ module PI(Inst: Inst_t)(LN: Loopnest_t with module Inst := Inst) = struct (* is it ok to vectorize on a join dim ? *) let get_inner_dim = Tensor.inner_dim %> function Single d -> d - | Join (d,_, _) -> d in + | Join (d,_, _, _) -> d in let inner_dim = get_inner_dim tensor in let to_vectorize = DimMap.find inner_dim dim_map |> function Some dmap -> begin match D.incr dmap with diff --git a/ml/lib/tensor.ml b/ml/lib/tensor.ml index ae51b06a5cd2bbef3eb8cf9f43c9529c10ca669e..1dba696e159acdae8746db34ae3141d3a9f4ffe1 100644 --- a/ml/lib/tensor.ml +++ b/ml/lib/tensor.ml @@ -13,7 +13,7 @@ type id = T_id.t [@@deriving show, ord, eq] (* A join dimension means that two distinct dimensions - meaning distinct loops * iterate on the same elements. Represents i + w for example *) -type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t [@@deriving eq, show] +type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t * int option [@@deriving eq, show] let single dim = Single dim @@ -22,7 +22,14 @@ let join_dims d1 d2 = and aux_name = Dim.show_id_of_t d2 in let name = main_name ^ aux_name in let iter_dim = Dim.fresh_gen ~name () () in - Join (d1, d2, iter_dim) + Join (d1, d2, iter_dim, None) + +let join_dims_stride d1 d2 str = + let main_name = Dim.show_id_of_t d1 + and aux_name = Dim.show_id_of_t d2 in + let name = main_name ^ aux_name in + let iter_dim = Dim.fresh_gen ~name () () in + Join (d1, d2, iter_dim, Some str) type t = {id: id; dims: t_dims array; sizes:(t_dims * CE.t) list; strides: (t_dims * CE.t) list} [@@deriving eq, show] @@ -34,7 +41,7 @@ type accesses = (Dim.id * VarExpr.t) list [@@deriving eq, show] let get_dim tensor dim = Array.find_opt (function Single d when Dim.equal d dim -> true - | Join (d1, d2, _) when Dim.equal dim d1 || Dim.equal dim d2 -> true + | Join (d1, d2, _, _) when Dim.equal dim d1 || Dim.equal dim d2 -> true | _ -> false) tensor.dims @@ -67,7 +74,7 @@ let t_dims {dims;_} = dims let t_dims_list {dims;_} = Array.to_list dims let dims_list {dims;_} = Array.fold_left (fun l -> function Single d -> d::l - | Join (d1, d2, _) -> d1::d2::l) + | Join (d1, d2, _, _) -> d1::d2::l) [] dims (* Should take dims from smaller stride to bigger @@ -109,7 +116,7 @@ let make_join ?name ?strides ?sizes (dims: t_dims array) = match sizes with | None -> List.map ( function Single d as sd -> sd, csize d - | Join (d1, d2, _) as jd -> + | Join (d1, d2, _, _) as jd -> jd, CE.(I.(csize d1 + csize d2 - const 1)) ) dims_list | Some sizes -> @@ -117,7 +124,7 @@ let make_join ?name ?strides ?sizes (dims: t_dims array) = Single d as sd -> sd, Option.default (csize d) @@ List.assoc sd sizes - | Join (d1, d2, _) as jd -> + | Join (d1, d2, _, _) as jd -> jd, Option.default (CE.(I.(csize d1 + csize d2 - const 1))) @@ List.assoc jd sizes ) @@ -143,7 +150,7 @@ let tens_does_access tensor dim = Array.exists (let check_dim d = Dim.equal_id (Dim.id d) dim in function Single d -> check_dim d - | Join (d1, d2, _) -> check_dim d1 || check_dim d2) + | Join (d1, d2, _, _) -> check_dim d1 || check_dim d2) tensor.dims (* && (List.assoc_opt dim accesses |> Option.map VE.is_constant @@ -167,7 +174,7 @@ let sort_dim dim_list d1 d2 = sort_dim dim_list let get_main_dim = function - | Join (d, _, _) -> d + | Join (d, _, _, _) -> d | Single d -> d let sort_tdim dim_list d1 d2 = @@ -184,7 +191,7 @@ let reorder_dim dim_list size_list = let reorder_layout tensor dim_list = let map_t_dim dim = match List.find - (function Single d | Join (d, _, _) -> + (function Single d | Join (d, _, _, _) -> Dim.equal dim d) (t_dims_list tensor) with | Some t_dim -> t_dim @@ -231,13 +238,15 @@ let gen_access tensor accesses = List.map (function Single d, stride -> Option.get @@ get_expr d accesses, stride - | Join (d1, d2, iter_dim), stride -> + | Join (d1, d2, iter_dim, str), stride -> let main = get_expr d1 accesses and aux = get_expr d2 accesses and iter = get_expr iter_dim accesses in + let mul_stride e = Option.map_default + (fun s -> CE.(I.(const s * e))) e str in match main, aux, iter with | Some main, Some aux, None -> - VE.add main aux, stride + VE.add (mul_stride main) aux, stride | None, None, Some iter -> iter, stride | Some _, _, Some _ | _, Some _, Some _ -> @@ -257,13 +266,13 @@ let compare_dims_id tensor dim1 dim2 = let rec cmp = function | (Single d)::_ when Dim.equal_id (Dim.id d) dim1 -> 1 | (Single d)::_ when Dim.equal_id (Dim.id d) dim2 -> -1 - | Join (d1, d2, _)::_ when (Dim.equal_id (Dim.id d1) dim1 + | Join (d1, d2, _, _)::_ when (Dim.equal_id (Dim.id d1) dim1 && Dim.equal_id (Dim.id d2) dim2) || (Dim.equal_id (Dim.id d2) dim1 && Dim.equal_id (Dim.id d1) dim2)-> 0 - | Join (d1, d2, _)::_ when Dim.equal_id (Dim.id d1) dim1 + | Join (d1, d2, _, _)::_ when Dim.equal_id (Dim.id d1) dim1 || Dim.equal_id (Dim.id d2) dim1 -> 1 - | Join (d1, d2, _)::_ when Dim.equal_id (Dim.id d1) dim2 + | Join (d1, d2, _, _)::_ when Dim.equal_id (Dim.id d1) dim2 || Dim.equal_id (Dim.id d2) dim2 -> -1 | _::t -> cmp t | [] -> failwith "Dim1 and dim2 are not present in tensor dims" in diff --git a/ml/lib/tensor.mli b/ml/lib/tensor.mli index a7988a01d0f4685e45320fc875dce00c2911e377..4174275e6766345dfee5d948980b3124f3565b2e 100644 --- a/ml/lib/tensor.mli +++ b/ml/lib/tensor.mli @@ -3,12 +3,13 @@ open Exprs type t [@@deriving eq, show] type id [@@deriving eq, ord, show] -type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t [@@deriving eq, show] +type t_dims = Single of Dim.t | Join of Dim.t * Dim.t * Dim.t * int option [@@deriving eq, show] type gen = unit -> t type accesses = (Dim.id * VarExpr.t) list [@@deriving eq, show] val single: Dim.t -> t_dims val join_dims: Dim.t -> Dim.t -> t_dims +val join_dims_stride: Dim.t -> Dim.t -> int -> t_dims val make: ?name:string -> ?strides:(Dim.t * ConstExpr.t) list -> ?sizes:(Dim.t * ConstExpr.t) list -> Dim.t array -> t val make_join: ?name:string -> ?strides:(t_dims * ConstExpr.t) list -> ?sizes:(t_dims * ConstExpr.t) list diff --git a/ml/lib/tensor_transform.ml b/ml/lib/tensor_transform.ml index e1e16f2a3262ef3cdd3b1023ad74dd762f75f1f0..9929d14a3d5cf3e03fb54599874eeb97b6efd3a3 100644 --- a/ml/lib/tensor_transform.ml +++ b/ml/lib/tensor_transform.ml @@ -323,13 +323,16 @@ module Tensor_tile(A: Arch.Vec_arch_t) = struct let d_opt = P.find_dim dim payload in let incr = Option.map_default P.D.incr CE.zero d_opt in (sd, incr) - | Join (d1, d2, _) as jd -> + | Join (d1, d2, _, stride) as jd -> let d1_opt = P.find_dim d1 payload in let d2_opt = P.find_dim d2 payload in + let mul_stride e = Option.map_default (fun s -> + CE.(I.(const s * e))) e stride in match Option.map P.D.incr d1_opt, Option.map P.D.incr d2_opt with | Some inc1, (Some (CE.SizeVar _ as size) | Some (CE.Const _ as size)) -> - (jd, CE.add inc1 size) + (* TODO TEST TEST TEST *) + (jd, CE.add (mul_stride inc1) size) | Some _, Some (_) -> (* TODO clean that, this is dirty *) failwith "Tiling on small dim !!!"