diff --git a/ml/lib/front.ml b/ml/lib/front.ml new file mode 100644 index 0000000000000000000000000000000000000000..b5c83638fd5d91e5240012e0d304ad68998d5b3b --- /dev/null +++ b/ml/lib/front.ml @@ -0,0 +1,47 @@ +open Ids +open Utils +module T = Tensor +type operator = Add | Mul + + +type fold = + Tens of T.t + | Val of float + | Red of operator * DimSet.t * t * t + | Join of operator * DimSet.t * t * t + +and t = DimSet.t * fold + +let tensor tens = + DimSet.of_list @@ T.dims_list tens, Tens tens + +let make_val dims value = + let dim_set = DimSet.of_list dims in + dim_set, Val value + +let join operator dims ((dims1, _) as a1) ((dims2, _) as a2) = + if not ( DimSet.subset dims dims1 && DimSet.subset dims dims2) + then None + else Some + (DimSet.union dims1 dims2, Join (operator, dims, a1, a2) ) + +let reduce operator dims ((base_dims, _) as base) ((arg_dims, _) as exp) = + if not ( DimSet.subset dims arg_dims) && (DimSet.equal base_dims (DimSet.diff arg_dims dims)) + then None + else Some + (base_dims, Red (operator, dims, base, exp) ) + +let dims (dims, _) = dims + +let matmul a b i j k = let open Option.Infix in + let k_set = DimSet.singleton k in + join Mul k_set (tensor a) (tensor b) + >>= reduce Add k_set (make_val [i; j] 0.) + +type 'a foldF = + TensF of T.t + | ValF of float + | RedF of operator * DimSet.t * 'a * 'a + | JoinF of operator * DimSet.t * 'a * 'a [@@deriving map] + +module FuncAlgebra = Algebra.A(struct type 'a t = 'a foldF let map = map_foldF end) diff --git a/ml/lib/front.mli b/ml/lib/front.mli new file mode 100644 index 0000000000000000000000000000000000000000..e0d9b3a153c83c4e13d250e48d8248e5c48cac77 --- /dev/null +++ b/ml/lib/front.mli @@ -0,0 +1,20 @@ +open Ids +type t + +type operator = Add | Mul + +val tensor: Tensor.t -> t +val make_val: Dim.t list -> float -> t +val join: operator -> DimSet.t -> t -> t -> t option +val reduce: operator -> DimSet.t -> t -> t -> t option +val dims: t -> DimSet.t + +type 'a foldF = + TensF of Tensor.t + | ValF of float + | RedF of operator * DimSet.t * 'a * 'a + | JoinF of operator * DimSet.t * 'a * 'a + +val matmul: Tensor.t -> Tensor.t -> Dim.t -> Dim.t -> Dim.t -> t option + +module FuncAlgebra : Algebra.Algebra_t with type 'a base := 'a foldF